# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class implementing a multi-worker parameter server tf.distribute strategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy


from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_run
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import device_setter
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

_LOCAL_CPU = "/device:CPU:0"


@tf_export(v1=["distribute.experimental.ParameterServerStrategy"])  # pylint: disable=missing-docstring
class ParameterServerStrategyV1(distribute_lib.StrategyV1):
  """An asynchronous multi-worker parameter server tf.distribute strategy.

  This strategy requires two roles: workers and parameter servers. Variables
  and updates to those variables are assigned to parameter servers; other
  operations are assigned to workers.

  When each worker has more than one GPU, operations are replicated across
  these GPUs. Even though operations may be replicated, variables are not:
  every worker shares a common view of which parameter server each variable
  is assigned to.

  By default it uses `TFConfigClusterResolver` to detect configurations for
  multi-worker training. This requires a 'TF_CONFIG' environment variable,
  and the 'TF_CONFIG' must have a cluster spec.
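
  For example, a 'TF_CONFIG' for a cluster with two workers and one parameter
  server might look like the following (addresses are illustrative):

  ```
  os.environ["TF_CONFIG"] = json.dumps({
      "cluster": {
          "worker": ["host1:2222", "host2:2222"],
          "ps": ["host3:2222"]
      },
      "task": {"type": "worker", "index": 0}
  })
  ```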

  This class assumes each worker is running the same code independently, but
  parameter servers are running a standard server. This means that while each
  worker will synchronously compute a single gradient update across all GPUs,
  updates between workers proceed asynchronously. Operations that occur only
  on the first replica (such as incrementing the global step) will occur on
  the first replica *of every worker*.

  Users are expected to call `call_for_each_replica(fn, ...)` for any
  operations which potentially can be replicated across replicas (i.e.
  multiple GPUs), even if there is only a CPU or one GPU. When defining the
  `fn`, extra caution needs to be taken:

  1) It is generally not recommended to open a device scope under the
  strategy's scope. A device scope (i.e. calling `tf.device`) will be merged
  with or override the device for operations but will not change the device
  for variables.

  2) It is also not recommended to open a colocation scope (i.e. calling
  `tf.compat.v1.colocate_with`) under the strategy's scope. For colocating
  variables, use `strategy.extended.colocate_vars_with` instead. Colocation
  of ops may create device assignment conflicts.

  Note: This strategy only works with the Estimator API. Pass an instance of
  this strategy to the `train_distribute` argument when you create the
  `RunConfig`. This instance of `RunConfig` should then be passed to the
  `Estimator` instance on which `train_and_evaluate` is called.

  For example:
  ```
  strategy = tf.distribute.experimental.ParameterServerStrategy()
  run_config = tf.estimator.RunConfig(train_distribute=strategy)
  estimator = tf.estimator.Estimator(config=run_config)
  tf.estimator.train_and_evaluate(estimator, ...)
  ```
  """

  def __init__(self, cluster_resolver=None):
    """Initializes this strategy with an optional `cluster_resolver`.

    Args:
      cluster_resolver: Optional
        `tf.distribute.cluster_resolver.ClusterResolver` object. Defaults to a
        `tf.distribute.cluster_resolver.TFConfigClusterResolver`.
    """
    if cluster_resolver is None:
      cluster_resolver = TFConfigClusterResolver()
    super(ParameterServerStrategyV1, self).__init__(
        ParameterServerStrategyExtended(
            self, cluster_resolver=cluster_resolver))
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
        "ParameterServerStrategy")
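
  # A minimal sketch of constructing this strategy with an explicit cluster
  # resolver instead of the `TF_CONFIG` environment variable (the cluster
  # addresses below are hypothetical):
  #
  #   cluster = tf.train.ClusterSpec(
  #       {"worker": ["host1:2222"], "ps": ["host2:2222"]})
  #   resolver = SimpleClusterResolver(cluster, task_type="worker", task_id=0)
  #   strategy = tf.compat.v1.distribute.experimental.ParameterServerStrategy(
  #       resolver)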

  def experimental_distribute_dataset(self, dataset, options=None):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`.")
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1,
                 self).experimental_distribute_dataset(dataset=dataset,
                                                       options=options)

  def distribute_datasets_from_function(self, dataset_fn, options=None):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1,
                 self).distribute_datasets_from_function(
                     dataset_fn=dataset_fn, options=options)

  def run(self, fn, args=(), kwargs=None, options=None):
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1, self).run(
        fn, args=args, kwargs=kwargs, options=options)

  def scope(self):
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1, self).scope()

  def _raise_pss_error_if_eager(self):
    if context.executing_eagerly():
      raise NotImplementedError(
          "`tf.compat.v1.distribute.experimental.ParameterServerStrategy` "
          "currently only works with the tf.Estimator API")


# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of ParameterServerStrategy and CentralStorageStrategy."""

  def __init__(self,
               container_strategy,
               cluster_resolver=None,
               compute_devices=None,
               parameter_device=None):
    super(ParameterServerStrategyExtended, self).__init__(container_strategy)
    self._initialize_strategy(
        cluster_resolver=cluster_resolver,
        compute_devices=compute_devices,
        parameter_device=parameter_device)

    # We typically don't need to do all-reduce in this strategy.
    self._cross_device_ops = (
        cross_device_ops_lib.ReductionToOneDevice(reduce_to_device=_LOCAL_CPU))

  def _initialize_strategy(self,
                           cluster_resolver=None,
                           compute_devices=None,
                           parameter_device=None):
    if cluster_resolver and cluster_resolver.cluster_spec():
      self._initialize_multi_worker(cluster_resolver)
    else:
      self._initialize_local(
          compute_devices, parameter_device, cluster_resolver=cluster_resolver)

  def _initialize_multi_worker(self, cluster_resolver):
    """Initializes devices for multiple workers.

    It creates variable devices and compute devices. Variables and operations
    will be assigned to them respectively. We have one compute device per
    replica. The variable device is a device function or device string. The
    default variable device assigns variables to parameter servers in a
    round-robin fashion.

    Args:
      cluster_resolver: a descendant of `ClusterResolver` object.

    Raises:
      ValueError: if the cluster doesn't have ps jobs.
    """
    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs
    # in some cases.
    if isinstance(cluster_resolver, TFConfigClusterResolver):
      num_gpus = context.num_gpus()
    else:
      num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

    # Save the num_gpus_per_worker for the configure method.
    self._num_gpus_per_worker = num_gpus

    cluster_spec = cluster_resolver.cluster_spec()
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    if not task_type or task_id is None:
      raise ValueError("When `cluster_spec` is given, you must also specify "
                       "`task_type` and `task_id`")
    cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
    assert cluster_spec.as_dict()

    self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
    self._input_host_device = numpy_dataset.SingleDevice(self._worker_device)

    # Define compute devices, which is a list of device strings with one
    # entry per replica. When there are GPUs, replicate operations on these
    # GPUs. Otherwise, place operations on CPU.
    if num_gpus > 0:
      compute_devices = tuple(
          "%s/device:GPU:%d" % (self._worker_device, i)
          for i in range(num_gpus))
    else:
      compute_devices = (self._worker_device,)

    self._compute_devices = [
        device_util.canonicalize(d) for d in compute_devices]

    # In distributed mode, place variables on ps jobs in a round-robin
    # fashion. Note that devices returned from `replica_device_setter` are
    # not canonical and therefore we don't canonicalize all variable devices
    # to make them consistent.
    # TODO(yuefengz): support passing a strategy object to control variable
    # assignment.
    # TODO(yuefengz): merge the logic of replica_device_setter into this
    # class.
    num_ps_replicas = len(cluster_spec.as_dict().get("ps", []))
    if num_ps_replicas == 0:
      raise ValueError("The cluster spec needs to have `ps` jobs.")
    self._variable_device = device_setter.replica_device_setter(
        ps_tasks=num_ps_replicas,
        worker_device=self._worker_device,
        merge_devices=True,
        cluster=cluster_spec)

    # The `_parameter_devices` is needed for the `parameter_devices` property
    # and is a list of all variable devices. Here parameter devices are all
    # tasks of the "ps" job.
    self._parameter_devices = tuple(map("/job:ps/task:{}".format,
                                        range(num_ps_replicas)))

    # Add a default device so that ops without specified devices will not
    # end up on other workers.
    self._default_device = self._worker_device

    self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                task_id)
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id

    logging.info(
        "Multi-worker ParameterServerStrategy with "
        "cluster_spec = %r, task_type = %r, task_id = %r, "
        "num_ps_replicas = %r, is_chief = %r, compute_devices = %r, "
        "variable_device = %r", cluster_spec.as_dict(), task_type, task_id,
        num_ps_replicas, self._is_chief, self._compute_devices,
        self._variable_device)
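
  # A minimal sketch of the default round-robin placement performed by
  # `replica_device_setter` (device strings are illustrative, assuming three
  # ps tasks):
  #
  #   with ops.device(self._variable_device):
  #     v0 = tf.Variable(...)  # placed on /job:ps/task:0
  #     v1 = tf.Variable(...)  # placed on /job:ps/task:1
  #     v2 = tf.Variable(...)  # placed on /job:ps/task:2
  #     v3 = tf.Variable(...)  # placed on /job:ps/task:0 (wraps around)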

  # TODO(yuefengz): get rid of cluster_resolver argument when contrib's
  # version no longer depends on this class.
  def _initialize_local(self,
                        compute_devices,
                        parameter_device,
                        cluster_resolver=None):
    """Initialize local devices for training."""
    self._worker_device = device_util.canonicalize("/device:CPU:0")
    self._input_host_device = numpy_dataset.SingleDevice(self._worker_device)

    if compute_devices is None:
      if not cluster_resolver:
        num_gpus = context.num_gpus()
      else:
        num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)
      # Save the num_gpus_per_worker for the configure method which is used
      # by the contrib version.
      self._num_gpus_per_worker = num_gpus

      compute_devices = device_util.local_devices_from_num_gpus(num_gpus)

    compute_devices = [device_util.canonicalize(d) for d in compute_devices]

    if parameter_device is None:
      # If there is only one GPU, put everything on that GPU. Otherwise, place
      # variables on CPU.
      if len(compute_devices) == 1:
        parameter_device = compute_devices[0]
      else:
        parameter_device = _LOCAL_CPU

    self._variable_device = parameter_device
    self._compute_devices = compute_devices
    self._parameter_devices = (parameter_device,)
    self._is_chief = True
    self._cluster_spec = None
    self._task_type = None
    self._task_id = None

    logging.info(
        "ParameterServerStrategy (CentralStorageStrategy if you are using a "
        "single machine) with compute_devices = %r, variable_device = %r",
        compute_devices, self._variable_device)
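
  # A minimal sketch of the resulting local placement (illustrative):
  #
  #   - 1 GPU:  compute_devices = ("/device:GPU:0",) and that single GPU also
  #             holds the variables.
  #   - 2 GPUs: compute_devices = ("/device:GPU:0", "/device:GPU:1") and
  #             variables live on "/device:CPU:0".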

  def _input_workers_with_options(self, options=None):
    if not options or options.experimental_prefetch_to_device:
      return input_lib.InputWorkers(
          [(self._worker_device, self._compute_devices)])
    else:
      return input_lib.InputWorkers(
          [(self._worker_device,
            (self._worker_device,) * len(self._compute_devices))])

  @property
  def _input_workers(self):
    return self._input_workers_with_options()

  def _validate_colocate_with_variable(self, colocate_with_variable):
    distribute_utils.validate_colocate(colocate_with_variable, self)

  def _experimental_distribute_dataset(self, dataset, options):
    return input_lib.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync)

  def _make_dataset_iterator(self, dataset):
    return input_lib.DatasetIterator(
        dataset,
        self._input_workers,
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    """Distributes the dataset to each local GPU."""
    if self._cluster_spec:
      input_pipeline_id = multi_worker_util.id_in_cluster(
          self._cluster_spec, self._task_type, self._task_id)
      num_input_pipelines = multi_worker_util.worker_count(
          self._cluster_spec, self._task_type)
    else:
      input_pipeline_id = 0
      num_input_pipelines = 1
    input_context = distribute_lib.InputContext(
        num_input_pipelines=num_input_pipelines,
        input_pipeline_id=input_pipeline_id,
        num_replicas_in_sync=self._num_replicas_in_sync)
    return input_lib.InputFunctionIterator(input_fn, self._input_workers,
                                           [input_context],
                                           self._container_strategy())
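
  # A minimal sketch of sharding an input pipeline with the `InputContext`
  # built above (the dataset below is hypothetical):
  #
  #   def input_fn(input_context):
  #     d = tf.data.Dataset.range(100).batch(4)
  #     return d.shard(input_context.num_input_pipelines,
  #                    input_context.input_pipeline_id)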

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, self._input_host_device, session)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    if self._cluster_spec:
      input_pipeline_id = multi_worker_util.id_in_cluster(
          self._cluster_spec, self._task_type, self._task_id)
      num_input_pipelines = multi_worker_util.worker_count(
          self._cluster_spec, self._task_type)
    else:
      input_pipeline_id = 0
      num_input_pipelines = 1

    input_context = distribute_lib.InputContext(
        num_input_pipelines=num_input_pipelines,
        input_pipeline_id=input_pipeline_id,
        num_replicas_in_sync=self._num_replicas_in_sync)

    return input_lib.get_distributed_datasets_from_function(
        dataset_fn,
        self._input_workers_with_options(options),
        [input_context],
        self._container_strategy())

  def _experimental_distribute_values_from_function(self, value_fn):
    per_replica_values = []
    for replica_id in range(self._num_replicas_in_sync):
      per_replica_values.append(
          value_fn(distribute_lib.ValueContext(replica_id,
                                               self._num_replicas_in_sync)))
    return distribute_utils.regroup(per_replica_values, always_wrap=True)

  def _broadcast_to(self, tensor, destinations):
    # This is both a fast path for Python constants, and a way to delay
    # converting Python values to a tensor until we know what type it
    # should be converted to. Otherwise we have trouble with:
    #   global_step.assign_add(1)
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
    if not cross_device_ops_lib.check_destinations(destinations):
      # TODO(josh11b): Use current logical device instead of 0 here.
      destinations = self._compute_devices
    return self._cross_device_ops.broadcast(tensor, destinations)

  def _allow_variable_partition(self):
    return not context.executing_eagerly()

  # TODO(yuefengz): Not all ops in device_setter.STANDARD_PS_OPS will go
  # through this creator, such as "MutableHashTable".
  def _create_variable(self, next_creator, **kwargs):
    if self._num_replicas_in_sync > 1:
      aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
      if aggregation not in (
          vs.VariableAggregation.NONE,
          vs.VariableAggregation.SUM,
          vs.VariableAggregation.MEAN,
          vs.VariableAggregation.ONLY_FIRST_REPLICA
      ):
        raise ValueError("Invalid variable aggregation mode: %s "
                         "for variable: %s" % (aggregation, kwargs["name"]))

      def var_creator(**kwargs):
        """Creates an AggregatingVariable and fixes up collections."""
        # Record what collections this variable should be added to.
        collections = kwargs.pop("collections", None)
        if collections is None:
          collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        kwargs["collections"] = []

        # Create and wrap the variable.
        v = next_creator(**kwargs)
        wrapped = ps_values.AggregatingVariable(self._container_strategy(), v,
                                                aggregation)

        # Add the wrapped variable to the requested collections.
        # The handling of eager mode and the global step matches
        # ResourceVariable._init_from_args().
        if not context.executing_eagerly():
          g = ops.get_default_graph()
          # If "trainable" is True, next_creator() will add the contained
          # variable to the TRAINABLE_VARIABLES collection, so we manually
          # remove it and replace it with the wrapper. We can't set
          # "trainable" to False for next_creator() since that causes
          # functions like implicit_gradients to skip those variables.
          if kwargs.get("trainable", True):
            collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            if v in l:
              l.remove(v)
          g.add_to_collections(collections, wrapped)
        elif ops.GraphKeys.GLOBAL_STEP in collections:
          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped)

        return wrapped
    else:
      var_creator = next_creator

    if "colocate_with" in kwargs:
      colocate_with = kwargs["colocate_with"]
      if isinstance(colocate_with, numpy_dataset.SingleDevice):
        with ops.device(colocate_with.device):
          return var_creator(**kwargs)
      with ops.device(None):
        with ops.colocate_with(colocate_with):
          return var_creator(**kwargs)

    with ops.colocate_with(None, ignore_existing=True):
      with ops.device(self._variable_device):
        return var_creator(**kwargs)
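
  # A minimal sketch of how aggregation interacts with variable creation
  # under the strategy's scope (the variable name is hypothetical):
  #
  #   with strategy.scope():
  #     v = tf.compat.v1.get_variable(
  #         "v", initializer=1.0,
  #         aggregation=tf.compat.v1.VariableAggregation.MEAN)
  #   # `v` is wrapped in an `AggregatingVariable`; cross-replica assignments
  #   # are reduced with MEAN before the single underlying variable on the
  #   # parameter device is updated.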

  def _call_for_each_replica(self, fn, args, kwargs):
    return mirrored_run.call_for_each_replica(self._container_strategy(), fn,
                                              args, kwargs)

  def _verify_destinations_not_different_worker(self, destinations):
    if not self._cluster_spec:
      return
    if destinations is None:
      return
    for d in cross_device_ops_lib.get_devices_from(destinations):
      d_spec = tf_device.DeviceSpec.from_string(d)
      if d_spec.job == self._task_type and d_spec.task != self._task_id:
        raise ValueError(
            "Cannot reduce to another worker: %r, current worker is %r" %
            (d, self._worker_device))

  def _gather_to_implementation(self, value, destinations, axis, options):
    self._verify_destinations_not_different_worker(destinations)
    if not isinstance(value, values.DistributedValues):
      return value
    return self._cross_device_ops._gather(  # pylint: disable=protected-access
        value,
        destinations=destinations,
        axis=axis,
        options=options)

  def _reduce_to(self, reduce_op, value, destinations, options):
    self._verify_destinations_not_different_worker(destinations)
    if not isinstance(value, values.DistributedValues):
      # pylint: disable=protected-access
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, value, destinations, self._num_replicas_in_sync)
    return self._cross_device_ops.reduce(
        reduce_op, value, destinations=destinations, options=options)

  def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
    for _, destinations in value_destination_pairs:
      self._verify_destinations_not_different_worker(destinations)
    return self._cross_device_ops.batch_reduce(reduce_op,
                                               value_destination_pairs,
                                               options)

  def _select_single_value(self, structured):
    """Selects any single value in `structured`."""

    def _select_fn(x):  # pylint: disable=g-missing-docstring
      if isinstance(x, values.Mirrored):
        if len(x._devices) == 1:  # pylint: disable=protected-access
          return x._primary  # pylint: disable=protected-access
        else:
          raise ValueError(
              "You cannot update a variable with a Mirrored object with "
              "multiple components %r when using ParameterServerStrategy. "
              "You must specify a single value or a Mirrored with a single "
              "value." % x)
      elif isinstance(x, values.PerReplica):
        raise ValueError(
            "You cannot update a variable with a PerReplica object %r when "
            "using ParameterServerStrategy. You must specify a single value "
            "or a Mirrored with a single value." % x)
      else:
        return x

    return nest.map_structure(_select_fn, structured)

  def _update(self, var, fn, args, kwargs, group):
    if isinstance(var, ps_values.AggregatingVariable):
      var = var.get()
    if not resource_variable_ops.is_resource_variable(var):
      raise ValueError(
          "You cannot update `var` %r. It must be a Variable." % var)
    with ops.colocate_with(var), distribute_lib.UpdateContext(var.device):
      result = fn(var, *self._select_single_value(args),
                  **self._select_single_value(kwargs))
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  # TODO(yuefengz): does it need to call _select_single_value?
  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    with ops.device(
        colocate_with.device), distribute_lib.UpdateContext(colocate_with):
      result = fn(*args, **kwargs)
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def _local_results(self, val):
    if isinstance(val, values.DistributedValues):
      return val.values
    return (val,)

  def value_container(self, val):
    if (hasattr(val, "_aggregating_container") and
        not isinstance(val, ps_values.AggregatingVariable)):
      wrapper = val._aggregating_container()  # pylint: disable=protected-access
      if wrapper is not None:
        return wrapper
    return val

  def read_var(self, var):
    # No need to distinguish between normal variables and replica-local
    # variables.
    return array_ops.identity(var)

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    """Configures the strategy class with `cluster_spec`.

    The strategy object will be re-initialized if `cluster_spec` is passed to
    `configure` but was not passed when instantiating the strategy.

    Args:
      session_config: Session config object.
      cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
        cluster configurations.
      task_type: the current task type.
      task_id: the current task id.

    Raises:
      ValueError: if `cluster_spec` is given but `task_type` or `task_id` is
        not.
    """
    if cluster_spec:
      # Use the num_gpus_per_worker recorded in constructor since _configure
      # doesn't take num_gpus.
      cluster_resolver = SimpleClusterResolver(
          cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
          task_type=task_type,
          task_id=task_id,
          num_accelerators={"GPU": self._num_gpus_per_worker})
      self._initialize_multi_worker(cluster_resolver)

    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    if not self._cluster_spec:
      updated_config.isolate_session_state = True
      return updated_config

    updated_config.isolate_session_state = False

    assert self._task_type
    assert self._task_id is not None

    # The device filters prevent communication between workers.
    del updated_config.device_filters[:]
    if self._task_type in ["chief", "worker"]:
      updated_config.device_filters.extend(
          ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"])
    elif self._task_type == "evaluator":
      updated_config.device_filters.append(
          "/job:%s/task:%d" % (self._task_type, self._task_id))
    return updated_config
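
  # A minimal sketch of the resulting device filters (illustrative): for
  # worker 1, `updated_config.device_filters` becomes
  # ["/job:worker/task:1", "/job:ps"], so this worker can see the parameter
  # servers and itself but not other workers.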

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return self._cluster_spec is not None

  @property
  def _num_replicas_in_sync(self):
    return len(self._compute_devices)

  @property
  def worker_devices(self):
    return self._compute_devices

  @property
  def worker_devices_by_replica(self):
    return [[d] for d in self._compute_devices]

  @property
  def parameter_devices(self):
    return self._parameter_devices

  def non_slot_devices(self, var_list):
    return min(var_list, key=lambda x: x.name)

  @property
  def experimental_between_graph(self):
    # TODO(yuefengz): Should this return False in the local case?
    return True

  @property
  def experimental_should_init(self):
    return self._is_chief

  @property
  def should_checkpoint(self):
    return self._is_chief

  @property
  def should_save_summary(self):
    return self._is_chief

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group

  def _get_replica_id_in_sync_group(self, replica_id):
    return replica_id