# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class MirroredStrategy implementing tf.distribute.Strategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy

from tensorflow.python.distribute import collective_util
from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_run
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

# TODO(josh11b): Replace asserts in this file with if ...: raise ...


def _is_device_list_single_worker(devices):
  """Checks whether the devices list is for single or multi-worker.

  Args:
    devices: a list of device strings or tf.config.LogicalDevice objects,
      for either local or remote devices.

  Returns:
    a boolean indicating whether these device strings are for local or for
    remote devices.

  Raises:
    ValueError: if device strings are not consistent.
  """
  specs = []
  for d in devices:
    name = d.name if isinstance(d, context.LogicalDevice) else d
    specs.append(tf_device.DeviceSpec.from_string(name))
  num_workers = len({(d.job, d.task, d.replica) for d in specs})
  all_local = all(d.job in (None, "localhost") for d in specs)
  any_local = any(d.job in (None, "localhost") for d in specs)

  if any_local and not all_local:
    raise ValueError("Local device string cannot have job specified other "
                     "than 'localhost'")

  if num_workers == 1 and not all_local:
    if any(d.task is None for d in specs):
      raise ValueError("Remote device string must have task specified.")

  return num_workers == 1


def _cluster_spec_to_device_list(cluster_spec, num_gpus_per_worker):
  """Returns a device list given a cluster spec."""
  cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
  devices = []
  for task_type in ("chief", "worker"):
    for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
      if num_gpus_per_worker == 0:
        devices.append("/job:%s/task:%d/device:CPU:0" % (task_type, task_id))
      else:
        devices.extend([
            "/job:%s/task:%d/device:GPU:%i" % (task_type, task_id, gpu_id)
            for gpu_id in range(num_gpus_per_worker)
        ])
  return devices


def _group_device_list(devices):
  """Groups the devices list by task_type and task_id.

  Args:
    devices: a list of device strings for remote devices.

  Returns:
    a dict mapping from task_type to a list of device-string lists, one list
    per task, in ascending order of task_id.
  """
  assert not _is_device_list_single_worker(devices)
  device_dict = {}

  for d in devices:
    d_spec = tf_device.DeviceSpec.from_string(d)

    # Create an entry for the task_type.
    if d_spec.job not in device_dict:
      device_dict[d_spec.job] = []

    # Fill the device list for task_type until it covers the task_id.
    while len(device_dict[d_spec.job]) <= d_spec.task:
      device_dict[d_spec.job].append([])

    device_dict[d_spec.job][d_spec.task].append(d)

  return device_dict


def _is_gpu_device(device):
  return tf_device.DeviceSpec.from_string(device).device_type == "GPU"


def _infer_num_gpus_per_worker(devices):
  """Infers the number of GPUs on each worker.

  Currently to make multi-worker cross device ops work, we need all workers to
  have the same number of GPUs.

  Args:
    devices: a list of device strings, can be either local devices or remote
      devices.

  Returns:
    number of GPUs per worker.

  Raises:
    ValueError if workers have different numbers of GPUs or GPU indices are
    not consecutive and starting from 0.
  """
  if _is_device_list_single_worker(devices):
    return sum(1 for d in devices if _is_gpu_device(d))
  else:
    device_dict = _group_device_list(devices)
    num_gpus = None
    for _, devices_in_task in device_dict.items():
      for device_in_task in devices_in_task:
        if num_gpus is None:
          num_gpus = sum(1 for d in device_in_task if _is_gpu_device(d))

        # Verify other workers have the same number of GPUs.
        elif num_gpus != sum(1 for d in device_in_task if _is_gpu_device(d)):
          raise ValueError("All workers should have the same number of GPUs.")

        for d in device_in_task:
          d_spec = tf_device.DeviceSpec.from_string(d)
          if (d_spec.device_type == "GPU" and
              d_spec.device_index >= num_gpus):
            raise ValueError("GPU `device_index` on a worker should be "
                             "consecutive and start from 0.")
    return num_gpus


def all_local_devices(num_gpus=None):
  devices = config.list_logical_devices("GPU")
  if num_gpus is not None:
    devices = devices[:num_gpus]
  return devices or config.list_logical_devices("CPU")


def all_devices():
  devices = []
  tfconfig = TFConfigClusterResolver()
  if tfconfig.cluster_spec().as_dict():
    devices = _cluster_spec_to_device_list(tfconfig.cluster_spec(),
                                           context.num_gpus())
  return devices if devices else all_local_devices()


@tf_export("distribute.MirroredStrategy", v1=[])  # pylint: disable=g-classes-have-attributes
class MirroredStrategy(distribute_lib.Strategy):
  """Synchronous training across multiple replicas on one machine.

  This strategy is typically used for training on one
  machine with multiple GPUs. For TPUs, use
  `tf.distribute.TPUStrategy`. To use `MirroredStrategy` with multiple workers,
  please refer to `tf.distribute.experimental.MultiWorkerMirroredStrategy`.

  For example, a variable created under a `MirroredStrategy` is a
  `MirroredVariable`. If no devices are specified in the constructor argument
  of the strategy then it will use all the available GPUs. If no GPUs are
  found, it will use the available CPUs. Note that TensorFlow treats all CPUs
  on a machine as a single device, and uses threads internally for parallelism.

  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> with strategy.scope():
  ...   x = tf.Variable(1.)
  >>> x
  MirroredVariable:{
    0: <tf.Variable ... shape=() dtype=float32, numpy=1.0>,
    1: <tf.Variable ... shape=() dtype=float32, numpy=1.0>
  }

  While using distribution strategies, all the variable creation should be done
  within the strategy's scope. This will replicate the variables across all the
  replicas and keep them in sync using an all-reduce algorithm.

  Variables created inside a `MirroredStrategy` which is wrapped with a
  `tf.function` are still `MirroredVariables`.

  >>> x = []
  >>> @tf.function  # Wrap the function with tf.function.
  ... def create_variable():
  ...   if not x:
  ...     x.append(tf.Variable(1.))
  ...   return x[0]
  >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
  >>> with strategy.scope():
  ...   _ = create_variable()
  ...   print(x[0])
  MirroredVariable:{
    0: <tf.Variable ... shape=() dtype=float32, numpy=1.0>,
    1: <tf.Variable ... shape=() dtype=float32, numpy=1.0>
  }

  `experimental_distribute_dataset` can be used to distribute the dataset
  across the replicas when writing your own training loop. If you are using
  `.fit` and `.compile` methods available in `tf.keras`, then `tf.keras` will
  handle the distribution for you.

  For example:

  ```python
  my_strategy = tf.distribute.MirroredStrategy()
  with my_strategy.scope():
    @tf.function
    def distribute_train_epoch(dataset):
      def replica_fn(input):
        # process input and return result
        return result

      total_result = 0
      for x in dataset:
        per_replica_result = my_strategy.run(replica_fn, args=(x,))
        total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM,
                                           per_replica_result, axis=None)
      return total_result

    dist_dataset = my_strategy.experimental_distribute_dataset(dataset)
    for _ in range(EPOCHS):
      train_result = distribute_train_epoch(dist_dataset)
  ```

  Args:
    devices: a list of device strings such as `['/gpu:0', '/gpu:1']`. If
      `None`, all available GPUs are used. If no GPUs are found, CPU is used.
    cross_device_ops: optional, a descendant of `CrossDeviceOps`. If this is
      not set, `NcclAllReduce()` will be used by default. One would customize
      this if NCCL isn't available or if a special implementation that
      exploits the particular hardware is available.
  """

  # Only set this in tests.
  _collective_key_base = 0

  def __init__(self, devices=None, cross_device_ops=None):
    extended = MirroredExtended(
        self, devices=devices, cross_device_ops=cross_device_ops)
    super(MirroredStrategy, self).__init__(extended)
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
        "MirroredStrategy")


@tf_export(v1=["distribute.MirroredStrategy"])
class MirroredStrategyV1(distribute_lib.StrategyV1):  # pylint: disable=g-missing-docstring

  __doc__ = MirroredStrategy.__doc__

  # Only set this in tests.
  _collective_key_base = 0

  def __init__(self, devices=None, cross_device_ops=None):
    extended = MirroredExtended(
        self, devices=devices, cross_device_ops=cross_device_ops)
    super(MirroredStrategyV1, self).__init__(extended)
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
        "MirroredStrategy")


# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
class MirroredExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of MirroredStrategy."""

  # If this is set to True, use NCCL collective ops instead of NCCL cross
  # device ops.
  _prefer_collective_ops = False

  def __init__(self, container_strategy, devices=None, cross_device_ops=None):
    super(MirroredExtended, self).__init__(container_strategy)
    if context.executing_eagerly():
      if devices and not _is_device_list_single_worker(devices):
        raise RuntimeError("In-graph multi-worker training with "
                           "`MirroredStrategy` is not supported in eager mode.")
      else:
        if TFConfigClusterResolver().cluster_spec().as_dict():
          # If you are executing in eager mode, only the single machine code
          # path is supported.
          logging.info("Initializing local devices since in-graph multi-worker "
                       "training with `MirroredStrategy` is not supported in "
                       "eager mode. TF_CONFIG will be ignored when "
                       "initializing `MirroredStrategy`.")
        devices = devices or all_local_devices()
    else:
      devices = devices or all_devices()

    assert devices, ("Got an empty `devices` list and unable to recognize "
                     "any local devices.")
    self._cross_device_ops = cross_device_ops
    if self._prefer_collective_ops:
      self._communication_options = collective_util.Options(
          implementation=collective_util.CommunicationImplementation.NCCL)
    else:
      self._communication_options = collective_util.Options()
    self._collective_ops_in_use = False
    self._collective_key_base = container_strategy._collective_key_base
    self._initialize_strategy(devices)

    # TODO(b/128995245): Enable last partial batch support in graph mode.
    if ops.executing_eagerly_outside_functions():
      self.experimental_enable_get_next_as_optional = True

    # Flag to turn on VariablePolicy.
    self._use_var_policy = False

  def _initialize_strategy(self, devices):
    # The _initialize_strategy method is intended to be used by distribute
    # coordinator as well.
    assert devices, "Must specify at least one device."
    devices = tuple(device_util.resolve(d) for d in devices)
    assert len(set(devices)) == len(devices), (
        "No duplicates allowed in `devices` argument: %s" % (devices,))
    if _is_device_list_single_worker(devices):
      self._initialize_single_worker(devices)
      if self._prefer_collective_ops and (
          isinstance(self._cross_device_ops, cross_device_ops_lib.NcclAllReduce)
          or isinstance(self._inferred_cross_device_ops,
                        cross_device_ops_lib.NcclAllReduce)):
        self._use_collective_ops(devices)
        self._inferred_cross_device_ops = None
      logging.info("Using MirroredStrategy with devices %r", devices)
    else:
      self._initialize_multi_worker(devices)

  def _use_collective_ops(self, devices):
    if ops.executing_eagerly_outside_functions():
      try:
        context.context().configure_collective_ops(
            scoped_allocator_enabled_ops=("CollectiveReduce",))
      except RuntimeError:
        logging.warning("Collective ops are not configured at program startup."
                        " Some performance features may not be enabled.")

    self._collective_keys = cross_device_utils.CollectiveKeys(
        group_key_start=1 + self._collective_key_base)  # pylint: disable=protected-access
    self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
        devices=self._devices,
        group_size=len(self._devices),
        collective_keys=self._collective_keys)
    self._collective_ops_in_use = True

  def _initialize_single_worker(self, devices):
    """Initializes the object for single-worker training."""
    self._devices = tuple(device_util.canonicalize(d) for d in devices)
    self._input_workers_devices = (
        (device_util.canonicalize("/device:CPU:0", devices[0]), devices),)

    self._inferred_cross_device_ops = None if self._cross_device_ops else (
        cross_device_ops_lib.select_cross_device_ops(devices))
    self._host_input_device = numpy_dataset.SingleDevice(
        self._input_workers_devices[0][0])
    self._is_multi_worker_training = False
    device_spec = tf_device.DeviceSpec.from_string(
        self._input_workers_devices[0][0])
    # Ensures when we enter strategy.scope() we use the correct default device.
    if device_spec.job is not None and device_spec.job != "localhost":
      self._default_device = "/job:%s/replica:%d/task:%d" % (
          device_spec.job, device_spec.replica, device_spec.task)

  def _initialize_multi_worker(self, devices):
    """Initializes the object for multi-worker training."""
    device_dict = _group_device_list(devices)
    workers = []
    worker_devices = []
    for job in ("chief", "worker"):
      for task in range(len(device_dict.get(job, []))):
        worker = "/job:%s/task:%d" % (job, task)
        workers.append(worker)
        worker_devices.append((worker, device_dict[job][task]))

    # Setting `_default_device` will add a device scope in the
    # distribution.scope. We set the default device to the first worker. When
    # users specify a device under distribution.scope by
    #   with tf.device("/cpu:0"):
    #     ...
    # their ops will end up on the CPU device of the first worker, e.g.
    # "/job:worker/task:0/device:CPU:0". Note this is not used in replica mode.
    self._default_device = workers[0]
    self._host_input_device = numpy_dataset.SingleDevice(workers[0])

    self._devices = tuple(devices)
    self._input_workers_devices = worker_devices
    self._is_multi_worker_training = True

    if len(workers) > 1:
      # Grandfather usage in the legacy tests if they're configured properly.
      if (not isinstance(self._cross_device_ops,
                         cross_device_ops_lib.ReductionToOneDevice) or
          self._cross_device_ops._num_between_graph_workers > 1):  # pylint: disable=protected-access
        raise ValueError(
            "In-graph multi-worker training with `MirroredStrategy` is not "
            "supported.")
      self._inferred_cross_device_ops = self._cross_device_ops
    else:
      # TODO(yuefengz): make `select_cross_device_ops` work with device strings
      # containing job names.
      self._inferred_cross_device_ops = cross_device_ops_lib.NcclAllReduce()

    logging.info("Using MirroredStrategy with remote devices %r", devices)

  def _input_workers_with_options(self, options=None):
    if not options:
      return input_lib.InputWorkers(self._input_workers_devices)
    if (options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      if options.experimental_place_dataset_on_device:
        self._input_workers_devices = (
            tuple(
                (device_util.canonicalize(d, d), (d,)) for d in self._devices))
      else:
        self._input_workers_devices = (
            tuple((device_util.canonicalize("/device:CPU:0", d), (d,))
                  for d in self._devices))
      return input_lib.InputWorkers(self._input_workers_devices)
    else:
      if not options.experimental_prefetch_to_device:
        return input_lib.InputWorkers([
            (host_device, (host_device,) * len(compute_devices))
            for host_device, compute_devices in self._input_workers_devices
        ])
      else:
        return input_lib.InputWorkers(self._input_workers_devices)

  @property
  def _input_workers(self):
    return self._input_workers_with_options()

  def _get_variable_creator_initial_value(self,
                                          replica_id,
                                          device,
                                          primary_var,
                                          **kwargs):
    """Return the initial value for variables on a replica."""
    if replica_id == 0:
      return kwargs["initial_value"]
    else:
      assert primary_var is not None
      assert device is not None
      assert kwargs is not None

      def initial_value_fn():
        if context.executing_eagerly() or ops.inside_function():
          init_value = primary_var.value()
          return array_ops.identity(init_value)
        else:
          with ops.device(device):
            init_value = primary_var.initial_value
            return array_ops.identity(init_value)

      return initial_value_fn

  def _create_variable(self, next_creator, **kwargs):
    """Create a mirrored variable. See `DistributionStrategy.scope`."""
    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      devices = self._devices
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      with ops.device(colocate_with.device):
        return next_creator(**kwargs)
    else:
      devices = colocate_with._devices  # pylint: disable=protected-access

    def _real_mirrored_creator(**kwargs):  # pylint: disable=g-missing-docstring
      value_list = []
      for i, d in enumerate(devices):
        with ops.device(d):
          kwargs["initial_value"] = self._get_variable_creator_initial_value(
              replica_id=i,
              device=d,
              primary_var=value_list[0] if value_list else None,
              **kwargs)
          if i > 0:
            # Give replicas meaningful distinct names:
            var0name = value_list[0].name.split(":")[0]
            # We append a / to variable names created on replicas with id > 0
            # to ensure that we ignore the name scope and instead use the given
            # name as the absolute name of the variable.
            kwargs["name"] = "%s/replica_%d/" % (var0name, i)
          with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
            # Don't record operations (e.g. other variable reads) during
            # variable creation.
            with tape.stop_recording():
              v = next_creator(**kwargs)
          assert not isinstance(v, values.DistributedVariable)
          value_list.append(v)
      return value_list

    return distribute_utils.create_mirrored_variable(
        self._container_strategy(), _real_mirrored_creator,
        distribute_utils.VARIABLE_CLASS_MAPPING,
        distribute_utils.VARIABLE_POLICY_MAPPING, **kwargs)

  def _validate_colocate_with_variable(self, colocate_with_variable):
    distribute_utils.validate_colocate_distributed_variable(
        colocate_with_variable, self)

  def _make_dataset_iterator(self, dataset):
    return input_lib.DatasetIterator(
        dataset,
        self._input_workers,
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    input_contexts = []
    num_workers = self._input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))
    return input_lib.InputFunctionIterator(input_fn, self._input_workers,
                                           input_contexts,
                                           self._container_strategy())

  def _experimental_distribute_dataset(self, dataset, options):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`."
      )
    return input_lib.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync)

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, self._host_input_device, session)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    input_workers = self._input_workers_with_options(options)
    input_contexts = []
    num_workers = input_workers.num_workers
    for i in range(num_workers):
      input_contexts.append(distribute_lib.InputContext(
          num_input_pipelines=num_workers,
          input_pipeline_id=i,
          num_replicas_in_sync=self._num_replicas_in_sync))

    return input_lib.get_distributed_datasets_from_function(
        dataset_fn, input_workers, input_contexts, self._container_strategy(),
        options)

  def _experimental_distribute_values_from_function(self, value_fn):
    per_replica_values = []
    for replica_id in range(self._num_replicas_in_sync):
      per_replica_values.append(value_fn(
          distribute_lib.ValueContext(replica_id,
                                      self._num_replicas_in_sync)))
    return distribute_utils.regroup(per_replica_values, always_wrap=True)

  # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
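  # `_experimental_run_steps_on_iterator` below wraps `fn` in the body of a
  # `tf.while_loop` that pulls one element from `iterator` per step, runs for
  # `iterations` steps with `parallel_iterations=1`, and then repackages
  # `ctx.last_step_outputs` according to the reduce op recorded for each
  # output (unreduced outputs are regrouped into per-replica values).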
  def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                          initial_loop_values=None):
    if initial_loop_values is None:
      initial_loop_values = {}
    initial_loop_values = nest.flatten(initial_loop_values)

    ctx = input_lib.MultiStepContext()
    def body(i, *args):
      """A wrapper around `fn` to create the while loop body."""
      del args
      fn_result = fn(ctx, iterator.get_next())
      for (name, output) in ctx.last_step_outputs.items():
        # Convert all outputs to tensors, potentially from `DistributedValues`.
        ctx.last_step_outputs[name] = self._local_results(output)
      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
      with ops.control_dependencies([fn_result]):
        return [i + 1] + flat_last_step_outputs

    # We capture the control_flow_context at this point, before we run `fn`
    # inside a while_loop. This is useful in cases where we might need to exit
    # these contexts and get back to the outer context to do some things, for
    # e.g. create an op which should be evaluated only once at the end of the
    # loop on the host. One such usage is in creating metrics' value op.
    self._outer_control_flow_context = (
        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

    cond = lambda i, *args: i < iterations
    i = constant_op.constant(0)
    loop_result = control_flow_ops.while_loop(
        cond, body, [i] + initial_loop_values, name="",
        parallel_iterations=1, back_prop=False, swap_memory=False,
        return_same_structure=True)
    del self._outer_control_flow_context

    ctx.run_op = control_flow_ops.group(loop_result)

    # Convert the last_step_outputs from a list to the original dict structure
    # of last_step_outputs.
    last_step_tensor_outputs = loop_result[1:]
    last_step_tensor_outputs_dict = nest.pack_sequence_as(
        ctx.last_step_outputs, last_step_tensor_outputs)

    for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
      output = last_step_tensor_outputs_dict[name]
      # For outputs that have already been reduced, wrap them in a Mirrored
      # container, else in a PerReplica container.
      if reduce_op is None:
        last_step_tensor_outputs_dict[name] = distribute_utils.regroup(output)
      else:
        assert len(output) == 1
        last_step_tensor_outputs_dict[name] = output[0]

    ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
    return ctx

  def _broadcast_to(self, tensor, destinations):
    # This is both a fast path for Python constants, and a way to delay
    # converting Python values to a tensor until we know what type it
    # should be converted to. Otherwise we have trouble with:
    #   global_step.assign_add(1)
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
    # TODO(josh11b): In eager mode, use one thread per device, or async mode.
    if not destinations:
      # TODO(josh11b): Use current logical device instead of 0 here.
      destinations = self._devices
    return self._get_cross_device_ops(tensor).broadcast(tensor, destinations)

  def _call_for_each_replica(self, fn, args, kwargs):
    return mirrored_run.call_for_each_replica(
        self._container_strategy(), fn, args, kwargs)

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    del task_type, task_id

    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

    if cluster_spec:
      # TODO(yuefengz): remove the following code once cluster_resolver is
      # added.
      num_gpus_per_worker = _infer_num_gpus_per_worker(self._devices)
      multi_worker_devices = _cluster_spec_to_device_list(
          cluster_spec, num_gpus_per_worker)
      self._initialize_multi_worker(multi_worker_devices)

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    updated_config.isolate_session_state = True
    return updated_config

  def _get_cross_device_ops(self, value):
    if self._collective_ops_in_use:
      if isinstance(value, values.DistributedValues):
        value_int32 = True in {
            dtypes.as_dtype(v.dtype) == dtypes.int32 for v in value.values
        }
      else:
        value_int32 = dtypes.as_dtype(value.dtype) == dtypes.int32
      if value_int32:
        return cross_device_ops_lib.ReductionToOneDevice()

    return self._cross_device_ops or self._inferred_cross_device_ops

  def _gather_to_implementation(self, value, destinations, axis, options):
    if not isinstance(value, values.DistributedValues):
      # ReductionToOneDevice._gather accepts DistributedValues only.
      return value
    return self._get_cross_device_ops(value)._gather(  # pylint: disable=protected-access
        value,
        destinations=destinations,
        axis=axis,
        options=self._communication_options.merge(options))

  def _reduce_to(self, reduce_op, value, destinations, options):
    if (distribute_utils.is_mirrored(value) and
        reduce_op == reduce_util.ReduceOp.MEAN):
      return value
    assert not distribute_utils.is_mirrored(value)
    if not isinstance(value, values.DistributedValues):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas in which case `value` would be a single value or value could
      # be 0.
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, value, destinations, self._num_replicas_in_sync)
    if self._collective_ops_in_use and (
        (not cross_device_ops_lib._devices_match(value, destinations) or  # pylint: disable=protected-access
         any("cpu" in d.lower()
             for d in cross_device_ops_lib.get_devices_from(destinations)))):
      return cross_device_ops_lib.ReductionToOneDevice().reduce(
          reduce_op, value, destinations)
    return self._get_cross_device_ops(value).reduce(
        reduce_op,
        value,
        destinations=destinations,
        options=self._communication_options.merge(options))

  def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
    cross_device_ops = None
    for value, _ in value_destination_pairs:
      if cross_device_ops is None:
        cross_device_ops = self._get_cross_device_ops(value)
      elif cross_device_ops is not self._get_cross_device_ops(value):
        raise ValueError("inputs to batch_reduce_to must be either all on "
                         "the host or all on the compute devices")
    return cross_device_ops.batch_reduce(
        reduce_op,
        value_destination_pairs,
        options=self._communication_options.merge(options))

  def _update(self, var, fn, args, kwargs, group):
    # TODO(josh11b): In eager mode, use one thread per device.
    assert isinstance(var, values.DistributedVariable)
    if (var.synchronization != variables_lib.VariableSynchronization.ON_READ and
        var.aggregation != variables_lib.VariableAggregation.NONE):
      distribute_utils.assert_mirrored(args)
      distribute_utils.assert_mirrored(kwargs)
    updates = []
    for i, v in enumerate(var.values):
      name = "update_%d" % i
      with ops.device(v.device), \
           distribute_lib.UpdateContext(i), \
           ops.name_scope(name):
        # If args and kwargs are not mirrored, the value is returned as is.
        updates.append(
            fn(v, *distribute_utils.select_replica(i, args),
               **distribute_utils.select_replica(i, kwargs)))
    return distribute_utils.update_regroup(self, updates, group)

  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    assert isinstance(colocate_with, tuple)
    # TODO(josh11b): In eager mode, use one thread per device.
    updates = []
    for i, d in enumerate(colocate_with):
      name = "update_%d" % i
      with ops.device(d), distribute_lib.UpdateContext(i), ops.name_scope(name):
        updates.append(
            fn(*distribute_utils.select_replica_mirrored(i, args),
               **distribute_utils.select_replica_mirrored(i, kwargs)))
    return distribute_utils.update_regroup(self, updates, group)

  def read_var(self, replica_local_var):
    """Read the aggregate value of a replica-local variable."""
    # pylint: disable=protected-access
    if distribute_utils.is_sync_on_read(replica_local_var):
      return replica_local_var._get_cross_replica()
    assert distribute_utils.is_mirrored(replica_local_var)
    return array_ops.identity(replica_local_var._get())
    # pylint: enable=protected-access

  def _local_results(self, val):
    if isinstance(val, values.DistributedValues):
      return val._values  # pylint: disable=protected-access
    return (val,)

  def value_container(self, val):
    return distribute_utils.value_container(val)

  @property
  def _num_replicas_in_sync(self):
    return len(self._devices)

  @property
  def worker_devices(self):
    return self._devices

  @property
  def worker_devices_by_replica(self):
    return [[d] for d in self._devices]

  @property
  def parameter_devices(self):
    return self.worker_devices

  @property
  def experimental_between_graph(self):
    return False

  @property
  def experimental_should_init(self):
    return True

  @property
  def should_checkpoint(self):
    return True

  @property
  def should_save_summary(self):
    return True

  def non_slot_devices(self, var_list):
    del var_list
    # TODO(josh11b): Should this be the last logical device instead?
    return self._devices

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return False

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group

  def _get_replica_id_in_sync_group(self, replica_id):
    return replica_id
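
# A minimal usage sketch of the strategy implemented above, assuming a machine
# with two visible GPUs; the dataset shapes and the `step` function are
# illustrative only and not part of this module:
#
#   strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
#   dataset = tf.data.Dataset.from_tensor_slices(tf.ones([8, 2])).batch(4)
#   dist_dataset = strategy.experimental_distribute_dataset(dataset)
#   with strategy.scope():
#     w = tf.Variable(tf.zeros([2, 1]))  # becomes a MirroredVariable
#
#   @tf.function
#   def step(batch):
#     per_replica = strategy.run(lambda b: tf.matmul(b, w), args=(batch,))
#     # Sum the per-replica results into a single tensor.
#     return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)
#
#   for batch in dist_dataset:
#     _ = step(batch)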