# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class implementing a multi-worker parameter server tf.distribute strategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy


from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_run
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import device_setter
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

_LOCAL_CPU = "/device:CPU:0"


@tf_export(v1=["distribute.experimental.ParameterServerStrategy"])  # pylint: disable=missing-docstring
class ParameterServerStrategyV1(distribute_lib.StrategyV1):
  """An asynchronous multi-worker parameter server tf.distribute strategy.

  This strategy requires two roles: workers and parameter servers. Variables,
  and updates to those variables, are assigned to parameter servers; other
  operations are assigned to workers.

  When each worker has more than one GPU, operations are replicated on all
  GPUs. Even though operations may be replicated, variables are not, and all
  workers share a common view of which parameter server each variable is
  assigned to.

  By default it uses `TFConfigClusterResolver` to detect configurations for
  multi-worker training. This requires a 'TF_CONFIG' environment variable and
  the 'TF_CONFIG' must have a cluster spec.
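
  For example, a minimal 'TF_CONFIG' for a hypothetical cluster with one chief,
  two workers and two parameter servers could look like the following sketch
  (host addresses are illustrative, and `os` and `json` are assumed to be
  imported):

  ```
  os.environ["TF_CONFIG"] = json.dumps({
      "cluster": {
          "chief": ["host0:2222"],
          "worker": ["host1:2222", "host2:2222"],
          "ps": ["host3:2222", "host4:2222"]
      },
      "task": {"type": "worker", "index": 0}
  })
  ```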

  This class assumes each worker is running the same code independently, but
  parameter servers are running a standard server. This means that while each
  worker will synchronously compute a single gradient update across all GPUs,
  updates between workers proceed asynchronously. Operations that occur only on
  the first replica (such as incrementing the global step) will occur on the
  first replica *of every worker*.

  It is expected to call `call_for_each_replica(fn, ...)` for any
  operations which potentially can be replicated across replicas (i.e. multiple
  GPUs), even if there is only a CPU or a single GPU. When defining the `fn`,
  extra caution needs to be taken:

  1) It is generally not recommended to open a device scope under the strategy's
  scope. A device scope (i.e. calling `tf.device`) will be merged with or
  override the device for operations but will not change the device for
  variables.

  2) It is also not recommended to open a colocation scope (i.e. calling
  `tf.compat.v1.colocate_with`) under the strategy's scope. For colocating
  variables, use `strategy.extended.colocate_vars_with` instead. Colocation of
  ops will possibly create device assignment conflicts.
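
  For example, the following sketch (variable names are illustrative) colocates
  `v2` with `v1` on the same parameter server:

  ```
  with strategy.scope():
    v1 = tf.Variable(1.0)
    with strategy.extended.colocate_vars_with(v1):
      v2 = tf.Variable(2.0)
  ```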

  Note: This strategy only works with the Estimator API. Pass an instance of
  this strategy to the `train_distribute` argument when you create the
  `RunConfig`. This instance of `RunConfig` should then be passed to the
  `Estimator` instance on which `train_and_evaluate` is called.

  For example:
  ```
  strategy = tf.distribute.experimental.ParameterServerStrategy()
  run_config = tf.estimator.RunConfig(train_distribute=strategy)
  estimator = tf.estimator.Estimator(config=run_config)
  tf.estimator.train_and_evaluate(estimator,...)
  ```
  """

  def __init__(self, cluster_resolver=None):
    """Initializes this strategy with an optional `cluster_resolver`.

    Args:
      cluster_resolver: Optional
        `tf.distribute.cluster_resolver.ClusterResolver` object. Defaults to a
        `tf.distribute.cluster_resolver.TFConfigClusterResolver`.
    """
    if cluster_resolver is None:
      cluster_resolver = TFConfigClusterResolver()
    super(ParameterServerStrategyV1, self).__init__(
        ParameterServerStrategyExtended(
            self, cluster_resolver=cluster_resolver))
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
        "ParameterServerStrategy")

  def experimental_distribute_dataset(self, dataset, options=None):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`.")
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1,
                 self).experimental_distribute_dataset(dataset=dataset,
                                                       options=options)

  def distribute_datasets_from_function(self, dataset_fn, options=None):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1,
                 self).distribute_datasets_from_function(
                     dataset_fn=dataset_fn, options=options)

  def run(self, fn, args=(), kwargs=None, options=None):
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1, self).run(
        fn, args=args, kwargs=kwargs, options=options)

  def scope(self):
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1, self).scope()

  def _raise_pss_error_if_eager(self):
    if context.executing_eagerly():
      raise NotImplementedError(
          "`tf.compat.v1.distribute.experimental.ParameterServerStrategy` "
          "currently only works with the tf.Estimator API")


# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of ParameterServerStrategy and CentralStorageStrategy."""

  def __init__(self,
               container_strategy,
               cluster_resolver=None,
               compute_devices=None,
               parameter_device=None):
    super(ParameterServerStrategyExtended, self).__init__(container_strategy)
    self._initialize_strategy(
        cluster_resolver=cluster_resolver,
        compute_devices=compute_devices,
        parameter_device=parameter_device)

    # We typically don't need to do all-reduce in this strategy.
    self._cross_device_ops = (
        cross_device_ops_lib.ReductionToOneDevice(reduce_to_device=_LOCAL_CPU))

  def _initialize_strategy(self,
                           cluster_resolver=None,
                           compute_devices=None,
                           parameter_device=None):
    if cluster_resolver and cluster_resolver.cluster_spec():
      self._initialize_multi_worker(cluster_resolver)
    else:
      self._initialize_local(
          compute_devices, parameter_device, cluster_resolver=cluster_resolver)

  def _initialize_multi_worker(self, cluster_resolver):
    """Initialize devices for multiple workers.

    It creates variable devices and compute devices. Variables and operations
    will be assigned to them respectively. We have one compute device per
    replica. The variable device is a device function or device string. The
    default variable device assigns variables to parameter servers in a
    round-robin fashion.

    Args:
      cluster_resolver: a descendant of `ClusterResolver` object.

    Raises:
      ValueError: if the cluster doesn't have ps jobs.
    """
    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs
    # in some cases.
    if isinstance(cluster_resolver, TFConfigClusterResolver):
      num_gpus = context.num_gpus()
    else:
      num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

    # Save the num_gpus_per_worker for configure method.
    self._num_gpus_per_worker = num_gpus

    cluster_spec = cluster_resolver.cluster_spec()
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    if not task_type or task_id is None:
      raise ValueError("When `cluster_spec` is given, you must also specify "
                       "`task_type` and `task_id`")
    cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
    assert cluster_spec.as_dict()

    self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
    self._input_host_device = numpy_dataset.SingleDevice(self._worker_device)

    # Define compute devices, which is a list of device strings with one entry
    # per replica. When there are GPUs, replicate operations on these GPUs.
    # Otherwise, place operations on CPU.
    if num_gpus > 0:
      compute_devices = tuple(
          "%s/device:GPU:%d" % (self._worker_device, i)
          for i in range(num_gpus))
    else:
      compute_devices = (self._worker_device,)

    self._compute_devices = [
        device_util.canonicalize(d) for d in compute_devices]

    # In distributed mode, place variables on ps jobs in a round-robin fashion.
    # Note that devices returned from `replica_device_setter` are not
    # canonical and therefore we don't canonicalize all variable devices to
    # make them consistent.
    # TODO(yuefengz): support passing a strategy object to control variable
    # assignment.
    # TODO(yuefengz): merge the logic of replica_device_setter into this
    # class.
    num_ps_replicas = len(cluster_spec.as_dict().get("ps", []))
    if num_ps_replicas == 0:
      raise ValueError("The cluster spec needs to have `ps` jobs.")
    self._variable_device = device_setter.replica_device_setter(
        ps_tasks=num_ps_replicas,
        worker_device=self._worker_device,
        merge_devices=True,
        cluster=cluster_spec)
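
    # Illustration (hypothetical values): with three "ps" tasks, successive
    # variables created under this strategy land on /job:ps/task:0,
    # /job:ps/task:1, /job:ps/task:2, /job:ps/task:0, ... via the round-robin
    # device setter above.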

    # The `_parameter_devices` is needed for the `parameter_devices` property
    # and is a list of all variable devices. Here parameter devices are all
    # tasks of the "ps" job.
    self._parameter_devices = tuple(map("/job:ps/task:{}".format,
                                        range(num_ps_replicas)))

    # Add a default device so that ops without specified devices will not end
    # up on other workers.
    self._default_device = self._worker_device

    self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                task_id)
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id

    logging.info(
        "Multi-worker ParameterServerStrategy with "
        "cluster_spec = %r, task_type = %r, task_id = %r, "
        "num_ps_replicas = %r, is_chief = %r, compute_devices = %r, "
        "variable_device = %r", cluster_spec.as_dict(), task_type, task_id,
        num_ps_replicas, self._is_chief, self._compute_devices,
        self._variable_device)

  # TODO(yuefengz): get rid of cluster_resolver argument when contrib's
  # version no longer depends on this class.
  def _initialize_local(self,
                        compute_devices,
                        parameter_device,
                        cluster_resolver=None):
    """Initialize local devices for training."""
    self._worker_device = device_util.canonicalize("/device:CPU:0")
    self._input_host_device = numpy_dataset.SingleDevice(self._worker_device)

    if compute_devices is None:
      if not cluster_resolver:
        num_gpus = context.num_gpus()
      else:
        num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)
      # Save the num_gpus_per_worker for configure method which is used by the
      # contrib version.
      self._num_gpus_per_worker = num_gpus

      compute_devices = device_util.local_devices_from_num_gpus(num_gpus)

    compute_devices = [device_util.canonicalize(d) for d in compute_devices]

    if parameter_device is None:
      # If there is only one GPU, put everything on that GPU. Otherwise, place
      # variables on CPU.
      if len(compute_devices) == 1:
        parameter_device = compute_devices[0]
      else:
        parameter_device = _LOCAL_CPU

    self._variable_device = parameter_device
    self._compute_devices = compute_devices
    self._parameter_devices = (parameter_device,)
    self._is_chief = True
    self._cluster_spec = None
    self._task_type = None
    self._task_id = None

    logging.info(
        "ParameterServerStrategy (CentralStorageStrategy if you are using a "
        "single machine) with compute_devices = %r, variable_device = %r",
        compute_devices, self._variable_device)

  def _input_workers_with_options(self, options=None):
    if not options or options.experimental_fetch_to_device:
      return input_lib.InputWorkers(
          [(self._worker_device, self._compute_devices)])
    else:
      return input_lib.InputWorkers(
          [(self._worker_device,
            (self._worker_device,) * len(self._compute_devices))])

  @property
  def _input_workers(self):
    return self._input_workers_with_options()

  def _validate_colocate_with_variable(self, colocate_with_variable):
    distribute_utils.validate_colocate(colocate_with_variable, self)

  def _experimental_distribute_dataset(self, dataset, options):
    return input_lib.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync,
        options=options)

  def _make_dataset_iterator(self, dataset):
    return input_lib.DatasetIterator(
        dataset,
        self._input_workers,
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    """Distributes the dataset to each local GPU."""
    if self._cluster_spec:
      input_pipeline_id = multi_worker_util.id_in_cluster(
          self._cluster_spec, self._task_type, self._task_id)
      num_input_pipelines = multi_worker_util.worker_count(
          self._cluster_spec, self._task_type)
    else:
      input_pipeline_id = 0
      num_input_pipelines = 1
    input_context = distribute_lib.InputContext(
        num_input_pipelines=num_input_pipelines,
        input_pipeline_id=input_pipeline_id,
        num_replicas_in_sync=self._num_replicas_in_sync)
    return input_lib.InputFunctionIterator(input_fn, self._input_workers,
                                           [input_context],
                                           self._container_strategy())

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, self._input_host_device, session)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    if self._cluster_spec:
      input_pipeline_id = multi_worker_util.id_in_cluster(
          self._cluster_spec, self._task_type, self._task_id)
      num_input_pipelines = multi_worker_util.worker_count(
          self._cluster_spec, self._task_type)
    else:
      input_pipeline_id = 0
      num_input_pipelines = 1

    input_context = distribute_lib.InputContext(
        num_input_pipelines=num_input_pipelines,
        input_pipeline_id=input_pipeline_id,
        num_replicas_in_sync=self._num_replicas_in_sync)

    return input_lib.get_distributed_datasets_from_function(
        dataset_fn,
        self._input_workers_with_options(options), [input_context],
        self._container_strategy(),
        options=options)

  def _experimental_distribute_values_from_function(self, value_fn):
    per_replica_values = []
    for replica_id in range(self._num_replicas_in_sync):
      per_replica_values.append(
          value_fn(distribute_lib.ValueContext(replica_id,
                                               self._num_replicas_in_sync)))
    return distribute_utils.regroup(per_replica_values, always_wrap=True)

  def _broadcast_to(self, tensor, destinations):
    # This is both a fast path for Python constants, and a way to delay
    # converting Python values to a tensor until we know what type it
    # should be converted to. Otherwise we have trouble with:
    #   global_step.assign_add(1)
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
    if not cross_device_ops_lib.check_destinations(destinations):
      # TODO(josh11b): Use current logical device instead of 0 here.
      destinations = self._compute_devices
    return self._cross_device_ops.broadcast(tensor, destinations)

  def _allow_variable_partition(self):
    return not context.executing_eagerly()

  def _create_var_creator(self, next_creator, **kwargs):
    if self._num_replicas_in_sync > 1:
      aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
      if aggregation not in (
          vs.VariableAggregation.NONE,
          vs.VariableAggregation.SUM,
          vs.VariableAggregation.MEAN,
          vs.VariableAggregation.ONLY_FIRST_REPLICA
      ):
        raise ValueError("Invalid variable aggregation mode: %s for "
                         "variable: %s" % (aggregation, kwargs["name"]))

      def var_creator(**kwargs):
        """Create an AggregatingVariable and fix up collections."""
        # Record what collections this variable should be added to.
        collections = kwargs.pop("collections", None)
        if collections is None:
          collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        kwargs["collections"] = []

        # Create and wrap the variable.
        v = next_creator(**kwargs)
        wrapped = ps_values.AggregatingVariable(self._container_strategy(), v,
                                                aggregation)

        # Add the wrapped variable to the requested collections.
        # The handling of eager mode and the global step matches
        # ResourceVariable._init_from_args().
        if not context.executing_eagerly():
          g = ops.get_default_graph()
          # If "trainable" is True, next_creator() will add the contained
          # variable to the TRAINABLE_VARIABLES collection, so we manually
          # remove it and replace with the wrapper. We can't set "trainable"
          # to False for next_creator() since that causes functions like
          # implicit_gradients to skip those variables.
          if kwargs.get("trainable", True):
            collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            if v in l:
              l.remove(v)
          g.add_to_collections(collections, wrapped)
        elif ops.GraphKeys.GLOBAL_STEP in collections:
          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped)

        return wrapped

      return var_creator
    else:
      return next_creator
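
  # For illustration (a sketch; names are hypothetical): with more than one
  # replica, a variable created inside `strategy.scope()` as
  #   v = tf.compat.v1.get_variable(
  #       "v", shape=[], aggregation=tf.VariableAggregation.MEAN)
  # comes back wrapped in a `ps_values.AggregatingVariable`, so assignments
  # made in replica context are aggregated across replicas with the requested
  # mode before being applied to the underlying variable.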

  # TODO(yuefengz): Not all ops in device_setter.STANDARD_PS_OPS will go
  # through this creator, such as "MutableHashTable".
  def _create_variable(self, next_creator, **kwargs):
    var_creator = self._create_var_creator(next_creator, **kwargs)

    if "colocate_with" in kwargs:
      colocate_with = kwargs["colocate_with"]
      if isinstance(colocate_with, numpy_dataset.SingleDevice):
        with ops.device(colocate_with.device):
          return var_creator(**kwargs)
      with ops.device(None):
        with ops.colocate_with(colocate_with):
          return var_creator(**kwargs)

    with ops.colocate_with(None, ignore_existing=True):
      with ops.device(self._variable_device):
        return var_creator(**kwargs)

  def _call_for_each_replica(self, fn, args, kwargs):
    return mirrored_run.call_for_each_replica(self._container_strategy(), fn,
                                              args, kwargs)

  def _verify_destinations_not_different_worker(self, destinations):
    if not self._cluster_spec:
      return
    if destinations is None:
      return
    for d in cross_device_ops_lib.get_devices_from(destinations):
      d_spec = tf_device.DeviceSpec.from_string(d)
      if d_spec.job == self._task_type and d_spec.task != self._task_id:
        raise ValueError(
            "Cannot reduce to another worker: %r, current worker is %r" %
            (d, self._worker_device))

  def _gather_to_implementation(self, value, destinations, axis, options):
    self._verify_destinations_not_different_worker(destinations)
    if not isinstance(value, values.DistributedValues):
      return value
    return self._cross_device_ops._gather(  # pylint: disable=protected-access
        value,
        destinations=destinations,
        axis=axis,
        options=options)

  def _reduce_to(self, reduce_op, value, destinations, options):
    self._verify_destinations_not_different_worker(destinations)
    if not isinstance(value, values.DistributedValues):
      # pylint: disable=protected-access
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, value, destinations, self._num_replicas_in_sync)
    return self._cross_device_ops.reduce(
        reduce_op, value, destinations=destinations, options=options)

  def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
    for _, destinations in value_destination_pairs:
      self._verify_destinations_not_different_worker(destinations)
    return self._cross_device_ops.batch_reduce(reduce_op,
                                               value_destination_pairs,
                                               options)

  def _select_single_value(self, structured):
    """Select any single value in `structured`."""

    def _select_fn(x):  # pylint: disable=g-missing-docstring
      if isinstance(x, values.Mirrored) or isinstance(x, values.PerReplica):
        return x._primary  # pylint: disable=protected-access
      else:
        return x

    return nest.map_structure(_select_fn, structured)

  def _update(self, var, fn, args, kwargs, group):
    if isinstance(var, ps_values.AggregatingVariable):
      var = var.get()
    if not resource_variable_ops.is_resource_variable(var):
      raise ValueError(
          "You can not update `var` %r. It must be a Variable." % var)
    with ops.colocate_with(var), distribute_lib.UpdateContext(var.device):
      result = fn(var, *self._select_single_value(args),
                  **self._select_single_value(kwargs))
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  # TODO(yuefengz): does it need to call _select_single_value?
  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    with ops.device(
        colocate_with.device), distribute_lib.UpdateContext(colocate_with):
      result = fn(*args, **kwargs)
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def value_container(self, val):
    if (hasattr(val, "_aggregating_container") and
        not isinstance(val, ps_values.AggregatingVariable)):
      wrapper = val._aggregating_container()  # pylint: disable=protected-access
      if wrapper is not None:
        return wrapper
    return val

  def read_var(self, var):
    # No need to distinguish between normal variables and replica-local
    # variables.
    return array_ops.identity(var)

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    """Configures the strategy class with `cluster_spec`.

    The strategy object will be re-initialized if `cluster_spec` is passed to
    `configure` but was not passed when instantiating the strategy.

    Args:
      session_config: Session config object.
      cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
        cluster configurations.
      task_type: the current task type.
      task_id: the current task id.

    Raises:
      ValueError: if `cluster_spec` is given but `task_type` or `task_id` is
        not.
    """
    if cluster_spec:
      # Use the num_gpus_per_worker recorded in constructor since _configure
      # doesn't take num_gpus.
      cluster_resolver = SimpleClusterResolver(
          cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
          task_type=task_type,
          task_id=task_id,
          num_accelerators={"GPU": self._num_gpus_per_worker})
      self._initialize_multi_worker(cluster_resolver)

    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    if not self._cluster_spec:
      updated_config.isolate_session_state = True
      return updated_config

    updated_config.isolate_session_state = False

    assert self._task_type
    assert self._task_id is not None

    # The device filters prevent communication between workers.
    del updated_config.device_filters[:]
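    # For example (illustrative values): a "worker" task with task_id 0 ends
    # up with device_filters ["/job:worker/task:0", "/job:ps"], so its session
    # only sees itself and the parameter servers, not other workers; an
    # "evaluator" task is restricted to itself.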
    if self._task_type in ["chief", "worker"]:
      updated_config.device_filters.extend(
          ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"])
    elif self._task_type == "evaluator":
      updated_config.device_filters.append(
          "/job:%s/task:%d" % (self._task_type, self._task_id))
    return updated_config

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return self._cluster_spec is not None

  @property
  def _num_replicas_in_sync(self):
    return len(self._compute_devices)

  @property
  def worker_devices(self):
    return self._compute_devices

  @property
  def worker_devices_by_replica(self):
    return [[d] for d in self._compute_devices]

  @property
  def parameter_devices(self):
    return self._parameter_devices

  def non_slot_devices(self, var_list):
    return min(var_list, key=lambda x: x.name)

  @property
  def experimental_between_graph(self):
    # TODO(yuefengz): Should this return False in the local case?
    return True

  @property
  def experimental_should_init(self):
    return self._is_chief

  @property
  def should_checkpoint(self):
    return self._is_chief

  @property
  def should_save_summary(self):
    return self._is_chief

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group

  def _get_replica_id_in_sync_group(self, replica_id):
    return replica_id