# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class implementing a multi-worker parameter server tf.distribute strategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy


from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_run
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import device_setter
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

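# Local CPU device string. It is used as the destination for cross-device
# reductions and, in the single-machine case with more than one compute
# device, as the default parameter device.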
_LOCAL_CPU = "/device:CPU:0"


@tf_export(v1=["distribute.experimental.ParameterServerStrategy"])  # pylint: disable=missing-docstring
class ParameterServerStrategyV1(distribute_lib.StrategyV1):
52  """An asynchronous multi-worker parameter server tf.distribute strategy.
53
54  This strategy requires two roles: workers and parameter servers. Variables and
55  updates to those variables will be assigned to parameter servers and other
56  operations are assigned to workers.
57
58  When each worker has more than one GPU, operations will be replicated on all
59  GPUs. Even though operations may be replicated, variables are not and each
60  worker shares a common view for which parameter server a variable is assigned
61  to.
62
63  By default it uses `TFConfigClusterResolver` to detect configurations for
64  multi-worker training. This requires a 'TF_CONFIG' environment variable and
65  the 'TF_CONFIG' must have a cluster spec.
66
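  For example, a minimal 'TF_CONFIG' for a cluster with two workers and one
  parameter server might look like the following sketch (host addresses are
  illustrative):

  ```
  import json
  import os

  os.environ["TF_CONFIG"] = json.dumps({
      "cluster": {
          "worker": ["worker0.example.com:2222", "worker1.example.com:2222"],
          "ps": ["ps0.example.com:2222"]
      },
      "task": {"type": "worker", "index": 0}
  })
  ```
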
  This class assumes each worker is running the same code independently, while
  parameter servers run a standard server. This means that although each
  worker will synchronously compute a single gradient update across all of its
  GPUs, updates between workers proceed asynchronously. Operations that occur
  only on the first replica (such as incrementing the global step) will occur
  on the first replica *of every worker*.

  `call_for_each_replica(fn, ...)` is expected to be called for any operations
  that can potentially be replicated across replicas (i.e. multiple GPUs),
  even if there is only a CPU or a single GPU. When defining `fn`, extra
  caution needs to be taken:

  1) It is generally not recommended to open a device scope under the
  strategy's scope. A device scope (i.e. calling `tf.device`) will be merged
  with or override the device for operations but will not change the device
  for variables.

  2) It is also not recommended to open a colocation scope (i.e. calling
  `tf.compat.v1.colocate_with`) under the strategy's scope. For colocating
  variables, use `strategy.extended.colocate_vars_with` instead (see the
  sketch below). Colocation of ops may create device assignment conflicts.

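  For example, a minimal sketch of colocating one variable with another under
  this strategy's scope:

  ```
  with strategy.scope():
    v1 = tf.Variable(1.0)
    with strategy.extended.colocate_vars_with(v1):
      v2 = tf.Variable(2.0)  # v2 is placed on the same device as v1.
  ```
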
  Note: This strategy only works with the Estimator API. Pass an instance of
  this strategy to the `train_distribute` argument when you create the
  `RunConfig`. This instance of `RunConfig` should then be passed to the
  `Estimator` instance on which `train_and_evaluate` is called.

  For example:
  ```
  strategy = tf.distribute.experimental.ParameterServerStrategy()
  run_config = tf.estimator.RunConfig(train_distribute=strategy)
  estimator = tf.estimator.Estimator(config=run_config)
  tf.estimator.train_and_evaluate(estimator, ...)
  ```
  """

  def __init__(self, cluster_resolver=None):
    """Initializes this strategy with an optional `cluster_resolver`.

    Args:
      cluster_resolver: Optional
        `tf.distribute.cluster_resolver.ClusterResolver` object. Defaults to a
        `tf.distribute.cluster_resolver.TFConfigClusterResolver`.
    """
    if cluster_resolver is None:
      cluster_resolver = TFConfigClusterResolver()
    super(ParameterServerStrategyV1, self).__init__(
        ParameterServerStrategyExtended(
            self, cluster_resolver=cluster_resolver))
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
        "ParameterServerStrategy")

  def experimental_distribute_dataset(self, dataset, options=None):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`."
      )
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1,
                 self).experimental_distribute_dataset(dataset=dataset,
                                                       options=options)

  def distribute_datasets_from_function(self, dataset_fn, options=None):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    self._raise_pss_error_if_eager()
    return super(
        ParameterServerStrategyV1, self).distribute_datasets_from_function(
            dataset_fn=dataset_fn, options=options)

  def run(self, fn, args=(), kwargs=None, options=None):
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1, self).run(
        fn, args=args, kwargs=kwargs, options=options)

  def scope(self):
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1, self).scope()

  def _raise_pss_error_if_eager(self):
    if context.executing_eagerly():
      raise NotImplementedError(
          "`tf.compat.v1.distribute.experimental.ParameterServerStrategy` "
          "currently only works with the tf.Estimator API")


# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of ParameterServerStrategy and CentralStorageStrategy."""

  def __init__(self,
               container_strategy,
               cluster_resolver=None,
               compute_devices=None,
               parameter_device=None):
    super(ParameterServerStrategyExtended, self).__init__(container_strategy)
    self._initialize_strategy(
        cluster_resolver=cluster_resolver,
        compute_devices=compute_devices,
        parameter_device=parameter_device)

    # We typically don't need to do all-reduce in this strategy.
    self._cross_device_ops = (
        cross_device_ops_lib.ReductionToOneDevice(reduce_to_device=_LOCAL_CPU))

  def _initialize_strategy(self,
                           cluster_resolver=None,
                           compute_devices=None,
                           parameter_device=None):
    if cluster_resolver and cluster_resolver.cluster_spec():
      self._initialize_multi_worker(cluster_resolver)
    else:
      self._initialize_local(
          compute_devices, parameter_device, cluster_resolver=cluster_resolver)

  def _initialize_multi_worker(self, cluster_resolver):
191    """Initialize devices for multiple workers.
192
193    It creates variable devices and compute devices. Variables and operations
194    will be assigned to them respectively. We have one compute device per
195    replica. The variable device is a device function or device string. The
196    default variable device assigns variables to parameter servers in a
197    round-robin fashion.
198
199    Args:
200      cluster_resolver: a descendant of `ClusterResolver` object.
201
202    Raises:
203      ValueError: if the cluster doesn't have ps jobs.
204    """
    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
    # some cases.
    if isinstance(cluster_resolver, TFConfigClusterResolver):
      num_gpus = context.num_gpus()
    else:
      num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

    # Save the num_gpus_per_worker for configure method.
    self._num_gpus_per_worker = num_gpus

    cluster_spec = cluster_resolver.cluster_spec()
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    if not task_type or task_id is None:
      raise ValueError("When `cluster_spec` is given, you must also specify "
                       "`task_type` and `task_id`")
    cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
    assert cluster_spec.as_dict()

    self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
    self._input_host_device = numpy_dataset.SingleDevice(self._worker_device)

    # Define compute devices, which is a list of device strings with one
    # device per replica. When there are GPUs, replicate operations on these
    # GPUs. Otherwise, place operations on the CPU.
    if num_gpus > 0:
      compute_devices = tuple(
          "%s/device:GPU:%d" % (self._worker_device, i)
          for i in range(num_gpus))
    else:
      compute_devices = (self._worker_device,)

    self._compute_devices = [
        device_util.canonicalize(d) for d in compute_devices]

    # In distributed mode, place variables on ps jobs in a round-robin fashion.
    # Note that devices returned from `replica_device_setter` are not
    # canonical and therefore we don't canonicalize all variable devices to
    # make them consistent.
    # TODO(yuefengz): support passing a strategy object to control variable
    # assignment.
    # TODO(yuefengz): merge the logic of replica_device_setter into this
    # class.
    num_ps_replicas = len(cluster_spec.as_dict().get("ps", []))
    if num_ps_replicas == 0:
      raise ValueError("The cluster spec needs to have `ps` jobs.")
    self._variable_device = device_setter.replica_device_setter(
        ps_tasks=num_ps_replicas,
        worker_device=self._worker_device,
        merge_devices=True,
        cluster=cluster_spec)

    # The `_parameter_devices` is needed for the `parameter_devices` property
    # and is a list of all variable devices. Here parameter devices are all
    # tasks of the "ps" job.
    self._parameter_devices = tuple(map("/job:ps/task:{}".format,
                                        range(num_ps_replicas)))

    # Add a default device so that ops without specified devices will not end up
    # on other workers.
    self._default_device = self._worker_device

    self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                task_id)
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id

    logging.info(
        "Multi-worker ParameterServerStrategy with "
        "cluster_spec = %r, task_type = %r, task_id = %r, "
        "num_ps_replicas = %r, is_chief = %r, compute_devices = %r, "
        "variable_device = %r", cluster_spec.as_dict(), task_type, task_id,
        num_ps_replicas, self._is_chief, self._compute_devices,
        self._variable_device)

  # TODO(yuefengz): get rid of cluster_resolver argument when contrib's
  # version no longer depends on this class.
  def _initialize_local(self,
                        compute_devices,
                        parameter_device,
                        cluster_resolver=None):
    """Initialize local devices for training."""
    self._worker_device = device_util.canonicalize("/device:CPU:0")
    self._input_host_device = numpy_dataset.SingleDevice(self._worker_device)

    if compute_devices is None:
      if not cluster_resolver:
        num_gpus = context.num_gpus()
      else:
        num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)
      # Save the num_gpus_per_worker for configure method which is used by the
      # contrib version.
      self._num_gpus_per_worker = num_gpus

      compute_devices = device_util.local_devices_from_num_gpus(num_gpus)

    compute_devices = [device_util.canonicalize(d) for d in compute_devices]

    if parameter_device is None:
      # If there is only one GPU, put everything on that GPU. Otherwise, place
      # variables on CPU.
      if len(compute_devices) == 1:
        parameter_device = compute_devices[0]
      else:
        parameter_device = _LOCAL_CPU

    self._variable_device = parameter_device
    self._compute_devices = compute_devices
    self._parameter_devices = (parameter_device,)
    self._is_chief = True
    self._cluster_spec = None
    self._task_type = None
    self._task_id = None

    logging.info(
        "ParameterServerStrategy (CentralStorageStrategy if you are using a "
        "single machine) with compute_devices = %r, variable_device = %r",
        compute_devices, self._variable_device)

  def _input_workers_with_options(self, options=None):
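    # When prefetching to devices is enabled (the default), input elements are
    # prefetched from the worker host onto each replica's compute device;
    # otherwise every replica keeps its input on the worker's host device.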
    if not options or options.experimental_prefetch_to_device:
      return input_lib.InputWorkers(
          [(self._worker_device, self._compute_devices)])
    else:
      return input_lib.InputWorkers(
          [(self._worker_device,
            (self._worker_device,) * len(self._compute_devices))])

  @property
  def _input_workers(self):
    return self._input_workers_with_options()

  def _validate_colocate_with_variable(self, colocate_with_variable):
    distribute_utils.validate_colocate(colocate_with_variable, self)

  def _experimental_distribute_dataset(self, dataset, options):
    return input_lib.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync)

  def _make_dataset_iterator(self, dataset):
    return input_lib.DatasetIterator(
        dataset,
        self._input_workers,
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    """Distributes the dataset to each local GPU."""
    if self._cluster_spec:
      input_pipeline_id = multi_worker_util.id_in_cluster(
          self._cluster_spec, self._task_type, self._task_id)
      num_input_pipelines = multi_worker_util.worker_count(
          self._cluster_spec, self._task_type)
    else:
      input_pipeline_id = 0
      num_input_pipelines = 1
    input_context = distribute_lib.InputContext(
        num_input_pipelines=num_input_pipelines,
        input_pipeline_id=input_pipeline_id,
        num_replicas_in_sync=self._num_replicas_in_sync)
    return input_lib.InputFunctionIterator(input_fn, self._input_workers,
                                           [input_context],
                                           self._container_strategy())

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, self._input_host_device, session)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    if self._cluster_spec:
      input_pipeline_id = multi_worker_util.id_in_cluster(
          self._cluster_spec, self._task_type, self._task_id)
      num_input_pipelines = multi_worker_util.worker_count(
          self._cluster_spec, self._task_type)
    else:
      input_pipeline_id = 0
      num_input_pipelines = 1

    input_context = distribute_lib.InputContext(
        num_input_pipelines=num_input_pipelines,
        input_pipeline_id=input_pipeline_id,
        num_replicas_in_sync=self._num_replicas_in_sync)

    return input_lib.get_distributed_datasets_from_function(
        dataset_fn,
        self._input_workers_with_options(options),
        [input_context],
        self._container_strategy())

  def _experimental_distribute_values_from_function(self, value_fn):
    per_replica_values = []
    for replica_id in range(self._num_replicas_in_sync):
      per_replica_values.append(
          value_fn(distribute_lib.ValueContext(replica_id,
                                               self._num_replicas_in_sync)))
    return distribute_utils.regroup(per_replica_values, always_wrap=True)

  def _broadcast_to(self, tensor, destinations):
    # This is both a fast path for Python constants, and a way to delay
    # converting Python values to a tensor until we know what type it
    # should be converted to. Otherwise we have trouble with:
    #   global_step.assign_add(1)
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
    if not cross_device_ops_lib.check_destinations(destinations):
      # TODO(josh11b): Use current logical device instead of 0 here.
      destinations = self._compute_devices
    return self._cross_device_ops.broadcast(tensor, destinations)

  def _allow_variable_partition(self):
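    # Partitioned variables are only allowed when building a graph; they are
    # not supported under eager execution.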
    return not context.executing_eagerly()

  # TODO(yuefengz): Not all ops in device_setter.STANDARD_PS_OPS will go through
  # this creator, such as "MutableHashTable".
  def _create_variable(self, next_creator, **kwargs):
    if self._num_replicas_in_sync > 1:
      aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
      if aggregation not in (
          vs.VariableAggregation.NONE,
          vs.VariableAggregation.SUM,
          vs.VariableAggregation.MEAN,
          vs.VariableAggregation.ONLY_FIRST_REPLICA
      ):
        raise ValueError(
            "Invalid variable aggregation mode: %s for variable: %s" %
            (aggregation, kwargs["name"]))

      def var_creator(**kwargs):
        """Create an AggregatingVariable and fix up collections."""
        # Record what collections this variable should be added to.
        collections = kwargs.pop("collections", None)
        if collections is None:
          collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        kwargs["collections"] = []

        # Create and wrap the variable.
        v = next_creator(**kwargs)
        wrapped = ps_values.AggregatingVariable(self._container_strategy(), v,
                                                aggregation)

        # Add the wrapped variable to the requested collections.
        # The handling of eager mode and the global step matches
        # ResourceVariable._init_from_args().
        if not context.executing_eagerly():
          g = ops.get_default_graph()
          # If "trainable" is True, next_creator() will add the contained
          # variable to the TRAINABLE_VARIABLES collection, so we manually
          # remove it and replace with the wrapper. We can't set "trainable"
          # to False for next_creator() since that causes functions like
          # implicit_gradients to skip those variables.
          if kwargs.get("trainable", True):
            collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            if v in l:
              l.remove(v)
          g.add_to_collections(collections, wrapped)
        elif ops.GraphKeys.GLOBAL_STEP in collections:
          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped)

        return wrapped
    else:
      var_creator = next_creator

    if "colocate_with" in kwargs:
      colocate_with = kwargs["colocate_with"]
      if isinstance(colocate_with, numpy_dataset.SingleDevice):
        with ops.device(colocate_with.device):
          return var_creator(**kwargs)
      with ops.device(None):
        with ops.colocate_with(colocate_with):
          return var_creator(**kwargs)

    with ops.colocate_with(None, ignore_existing=True):
      with ops.device(self._variable_device):
        return var_creator(**kwargs)

  def _call_for_each_replica(self, fn, args, kwargs):
    return mirrored_run.call_for_each_replica(self._container_strategy(), fn,
                                              args, kwargs)

  def _verify_destinations_not_different_worker(self, destinations):
    if not self._cluster_spec:
      return
    if destinations is None:
      return
    for d in cross_device_ops_lib.get_devices_from(destinations):
      d_spec = tf_device.DeviceSpec.from_string(d)
      if d_spec.job == self._task_type and d_spec.task != self._task_id:
        raise ValueError(
            "Cannot reduce to another worker: %r, current worker is %r" %
            (d, self._worker_device))

  def _gather_to_implementation(self, value, destinations, axis,
                                options):
    self._verify_destinations_not_different_worker(destinations)
    if not isinstance(value, values.DistributedValues):
      return value
    return self._cross_device_ops._gather(  # pylint: disable=protected-access
        value,
        destinations=destinations,
        axis=axis,
        options=options)

  def _reduce_to(self, reduce_op, value, destinations, options):
    self._verify_destinations_not_different_worker(destinations)
    if not isinstance(value, values.DistributedValues):
      # pylint: disable=protected-access
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, value, destinations, self._num_replicas_in_sync)
    return self._cross_device_ops.reduce(
        reduce_op, value, destinations=destinations, options=options)

  def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
    for _, destinations in value_destination_pairs:
      self._verify_destinations_not_different_worker(destinations)
    return self._cross_device_ops.batch_reduce(reduce_op,
                                               value_destination_pairs, options)

  def _select_single_value(self, structured):
    """Select any single value in `structured`."""

    def _select_fn(x):  # pylint: disable=g-missing-docstring
      if isinstance(x, values.Mirrored):
        if len(x._devices) == 1:  # pylint: disable=protected-access
          return x._primary  # pylint: disable=protected-access
        else:
          raise ValueError(
              "You cannot update a variable with a Mirrored object with "
              "multiple components %r when using ParameterServerStrategy. "
              "You must specify a single value or a Mirrored with a single "
              "value." % x)
      elif isinstance(x, values.PerReplica):
        raise ValueError(
            "You cannot update a variable with a PerReplica object %r when "
            "using ParameterServerStrategy. You must specify a single value "
            "or a Mirrored with a single value." % x)
      else:
        return x

    return nest.map_structure(_select_fn, structured)

  def _update(self, var, fn, args, kwargs, group):
    if isinstance(var, ps_values.AggregatingVariable):
      var = var.get()
    if not resource_variable_ops.is_resource_variable(var):
      raise ValueError(
          "You cannot update `var` %r. It must be a Variable." % var)
    with ops.colocate_with(var), distribute_lib.UpdateContext(var.device):
      result = fn(var, *self._select_single_value(args),
                  **self._select_single_value(kwargs))
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  # TODO(yuefengz): does it need to call _select_single_value?
  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    with ops.device(
        colocate_with.device), distribute_lib.UpdateContext(colocate_with):
      result = fn(*args, **kwargs)
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def _local_results(self, val):
    if isinstance(val, values.DistributedValues):
      return val.values
    return (val,)

  def value_container(self, val):
    if (hasattr(val, "_aggregating_container") and
        not isinstance(val, ps_values.AggregatingVariable)):
      wrapper = val._aggregating_container()  # pylint: disable=protected-access
      if wrapper is not None:
        return wrapper
    return val

  def read_var(self, var):
    # No need to distinguish between normal variables and replica-local
    # variables.
    return array_ops.identity(var)

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    """Configures the strategy class with `cluster_spec`.

    The strategy object will be re-initialized if `cluster_spec` is passed to
    `configure` but was not passed when instantiating the strategy.

    Args:
      session_config: Session config object.
      cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
        cluster configurations.
      task_type: the current task type.
      task_id: the current task id.

    Raises:
      ValueError: if `cluster_spec` is given but `task_type` or `task_id` is
        not.
    """
    if cluster_spec:
      # Use the num_gpus_per_worker recorded in constructor since _configure
      # doesn't take num_gpus.
      cluster_resolver = SimpleClusterResolver(
          cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
          task_type=task_type,
          task_id=task_id,
          num_accelerators={"GPU": self._num_gpus_per_worker})
      self._initialize_multi_worker(cluster_resolver)

    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    if not self._cluster_spec:
      updated_config.isolate_session_state = True
      return updated_config

    updated_config.isolate_session_state = False

    assert self._task_type
    assert self._task_id is not None

    # The device filters prevent communication between workers.
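    # Each chief or worker task can only see itself and the ps tasks, and an
    # evaluator task can only see itself.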
    del updated_config.device_filters[:]
    if self._task_type in ["chief", "worker"]:
      updated_config.device_filters.extend(
          ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"])
    elif self._task_type == "evaluator":
      updated_config.device_filters.append(
          "/job:%s/task:%d" % (self._task_type, self._task_id))
    return updated_config

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return self._cluster_spec is not None

  @property
  def _num_replicas_in_sync(self):
    return len(self._compute_devices)

  @property
  def worker_devices(self):
    return self._compute_devices

  @property
  def worker_devices_by_replica(self):
    return [[d] for d in self._compute_devices]

  @property
  def parameter_devices(self):
    return self._parameter_devices

  def non_slot_devices(self, var_list):
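    # Colocate non-slot ops with a single, deterministically chosen variable
    # (the first one by name).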
    return min(var_list, key=lambda x: x.name)

  @property
  def experimental_between_graph(self):
    # TODO(yuefengz): Should this return False in the local case?
    return True

  @property
  def experimental_should_init(self):
    return self._is_chief

  @property
  def should_checkpoint(self):
    return self._is_chief

  @property
  def should_save_summary(self):
    return self._is_chief

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group

  def _get_replica_id_in_sync_group(self, replica_id):
    return replica_id
