# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class implementing a multi-worker parameter server tf.distribute strategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy


from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_run
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import numpy_dataset
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import values
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import device_setter
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

_LOCAL_CPU = "/device:CPU:0"


@tf_export(v1=["distribute.experimental.ParameterServerStrategy"])  # pylint: disable=missing-docstring
class ParameterServerStrategyV1(distribute_lib.StrategyV1):
  """An asynchronous multi-worker parameter server tf.distribute strategy.

  This strategy requires two roles: workers and parameter servers. Variables,
  and updates to those variables, are assigned to parameter servers; other
  operations are assigned to workers.

  When a worker has more than one GPU, operations are replicated across all of
  its GPUs. Even though operations may be replicated, variables are not, and
  all workers share a common view of which parameter server each variable is
  assigned to.

  By default this strategy uses `TFConfigClusterResolver` to detect
  configurations for multi-worker training. This requires a 'TF_CONFIG'
  environment variable, and the 'TF_CONFIG' must contain a cluster spec.

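  For example, a minimal 'TF_CONFIG' for a cluster with one chief, two workers
  and one parameter server might look like the following sketch (using the
  standard `os` and `json` modules; the addresses are placeholders):

  ```
  os.environ["TF_CONFIG"] = json.dumps({
      "cluster": {
          "chief": ["chief.example.com:2222"],
          "worker": ["worker0.example.com:2222", "worker1.example.com:2222"],
          "ps": ["ps0.example.com:2222"]
      },
      "task": {"type": "worker", "index": 1}
  })
  ```
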
  This class assumes each worker is running the same code independently, but
  parameter servers are running a standard server. This means that while each
  worker will synchronously compute a single gradient update across all GPUs,
  updates between workers proceed asynchronously. Operations that occur only on
  the first replica (such as incrementing the global step) will occur on the
  first replica *of every worker*.

  You are expected to call `call_for_each_replica(fn, ...)` for any operations
  which can potentially be replicated across replicas (i.e. multiple GPUs),
  even if there is only a CPU or a single GPU. When defining the `fn`, extra
  caution needs to be taken:

  1) It is generally not recommended to open a device scope under the strategy's
  scope. A device scope (i.e. calling `tf.device`) will be merged with or
  override the device for operations but will not change the device for
  variables.

  2) It is also not recommended to open a colocation scope (i.e. calling
  `tf.compat.v1.colocate_with`) under the strategy's scope. For colocating
  variables, use `strategy.extended.colocate_vars_with` instead (see the sketch
  below). Colocation of ops may create device assignment conflicts.
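
  For example, a minimal sketch of placing one variable on the same parameter
  server as another (the variable values here are arbitrary):

  ```
  with strategy.scope():
    var1 = tf.Variable(1.0)
    # `var2` is created on the same device as `var1`.
    with strategy.extended.colocate_vars_with(var1):
      var2 = tf.Variable(2.0)
  ```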

  Note: This strategy only works with the Estimator API. Pass an instance of
  this strategy to the `train_distribute` argument when you create the
  `RunConfig`. This instance of `RunConfig` should then be passed to the
  `Estimator` instance on which `train_and_evaluate` is called.

  For example:
  ```
  strategy = tf.distribute.experimental.ParameterServerStrategy()
  run_config = tf.estimator.RunConfig(train_distribute=strategy)
  estimator = tf.estimator.Estimator(config=run_config)
  tf.estimator.train_and_evaluate(estimator, ...)
  ```
  """

  def __init__(self, cluster_resolver=None):
    """Initializes this strategy with an optional `cluster_resolver`.

    Args:
      cluster_resolver: Optional
        `tf.distribute.cluster_resolver.ClusterResolver` object. Defaults to a
        `tf.distribute.cluster_resolver.TFConfigClusterResolver`.
    """
    if cluster_resolver is None:
      cluster_resolver = TFConfigClusterResolver()
    super(ParameterServerStrategyV1, self).__init__(
        ParameterServerStrategyExtended(
            self, cluster_resolver=cluster_resolver))
    distribute_lib.distribution_strategy_gauge.get_cell("V1").set(
        "ParameterServerStrategy")

  def experimental_distribute_dataset(self, dataset, options=None):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function`."
      )
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1,
                 self).experimental_distribute_dataset(dataset=dataset,
                                                       options=options)

  def distribute_datasets_from_function(self, dataset_fn, options=None):
    if (options and options.experimental_replication_mode ==
        distribute_lib.InputReplicationMode.PER_REPLICA):
      raise NotImplementedError(
          "InputReplicationMode.PER_REPLICA "
          "is only supported in "
          "`experimental_distribute_datasets_from_function` "
          "of tf.distribute.MirroredStrategy")
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1,
                 self).distribute_datasets_from_function(
                     dataset_fn=dataset_fn, options=options)

  def run(self, fn, args=(), kwargs=None, options=None):
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1, self).run(
        fn, args=args, kwargs=kwargs, options=options)

  def scope(self):
    self._raise_pss_error_if_eager()
    return super(ParameterServerStrategyV1, self).scope()

  def _raise_pss_error_if_eager(self):
    if context.executing_eagerly():
      raise NotImplementedError(
          "`tf.compat.v1.distribute.experimental.ParameterServerStrategy` "
          "currently only works with the tf.Estimator API")


# TODO(josh11b): Switch to V2 when we no longer need to support tf.compat.v1.
class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
  """Implementation of ParameterServerStrategy and CentralStorageStrategy."""

  def __init__(self,
               container_strategy,
               cluster_resolver=None,
               compute_devices=None,
               parameter_device=None):
    super(ParameterServerStrategyExtended, self).__init__(container_strategy)
    self._initialize_strategy(
        cluster_resolver=cluster_resolver,
        compute_devices=compute_devices,
        parameter_device=parameter_device)

    # We typically don't need to do all-reduce in this strategy.
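    # Per-replica values (e.g. gradients) are instead reduced onto a single
    # device, the local CPU, by the `ReductionToOneDevice` instance below.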
    self._cross_device_ops = (
        cross_device_ops_lib.ReductionToOneDevice(reduce_to_device=_LOCAL_CPU))

  def _initialize_strategy(self,
                           cluster_resolver=None,
                           compute_devices=None,
                           parameter_device=None):
    if cluster_resolver and cluster_resolver.cluster_spec():
      self._initialize_multi_worker(cluster_resolver)
    else:
      self._initialize_local(
          compute_devices, parameter_device, cluster_resolver=cluster_resolver)

  def _initialize_multi_worker(self, cluster_resolver):
    """Initialize devices for multiple workers.

    It creates variable devices and compute devices; variables and operations
    will be assigned to them respectively. There is one compute device per
    replica. The variable device is a device function or device string. The
    default variable device assigns variables to parameter servers in a
    round-robin fashion.
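
    For example, with two ps tasks, the first variable created under the
    strategy is typically placed on /job:ps/task:0, the second on
    /job:ps/task:1, the third back on /job:ps/task:0, and so on.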

    Args:
      cluster_resolver: a descendant of `ClusterResolver` object.

    Raises:
      ValueError: if the cluster doesn't have ps jobs.
    """
    # TODO(b/126786766): TFConfigClusterResolver returns wrong number of GPUs in
    # some cases.
    if isinstance(cluster_resolver, TFConfigClusterResolver):
      num_gpus = context.num_gpus()
    else:
      num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)

    # Save the num_gpus_per_worker for the configure method.
    self._num_gpus_per_worker = num_gpus

    cluster_spec = cluster_resolver.cluster_spec()
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    if not task_type or task_id is None:
      raise ValueError("When `cluster_spec` is given, you must also specify "
                       "`task_type` and `task_id`")
    cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
    assert cluster_spec.as_dict()

    self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
    self._input_host_device = numpy_dataset.SingleDevice(self._worker_device)

    # Define compute devices, which is a list of device strings, one for each
    # replica. When there are GPUs, replicate operations on these GPUs.
    # Otherwise, place operations on the CPU.
    if num_gpus > 0:
      compute_devices = tuple(
          "%s/device:GPU:%d" % (self._worker_device, i)
          for i in range(num_gpus))
    else:
      compute_devices = (self._worker_device,)

    self._compute_devices = [
        device_util.canonicalize(d) for d in compute_devices]

    # In distributed mode, place variables on ps jobs in a round-robin fashion.
    # Note that devices returned from `replica_device_setter` are not
    # canonical and therefore we don't canonicalize all variable devices to
    # make them consistent.
    # TODO(yuefengz): support passing a strategy object to control variable
    # assignment.
    # TODO(yuefengz): merge the logic of replica_device_setter into this
    # class.
    num_ps_replicas = len(cluster_spec.as_dict().get("ps", []))
    if num_ps_replicas == 0:
      raise ValueError("The cluster spec needs to have `ps` jobs.")
    self._variable_device = device_setter.replica_device_setter(
        ps_tasks=num_ps_replicas,
        worker_device=self._worker_device,
        merge_devices=True,
        cluster=cluster_spec)

    # The `_parameter_devices` is needed for the `parameter_devices` property
    # and is a list of all variable devices. Here parameter devices are all
    # tasks of the "ps" job.
    self._parameter_devices = tuple(map("/job:ps/task:{}".format,
                                        range(num_ps_replicas)))

    # Add a default device so that ops without specified devices will not end up
    # on other workers.
    self._default_device = self._worker_device

    self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
                                                task_id)
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id

    logging.info(
        "Multi-worker ParameterServerStrategy with "
        "cluster_spec = %r, task_type = %r, task_id = %r, "
        "num_ps_replicas = %r, is_chief = %r, compute_devices = %r, "
        "variable_device = %r", cluster_spec.as_dict(), task_type, task_id,
        num_ps_replicas, self._is_chief, self._compute_devices,
        self._variable_device)

  # TODO(yuefengz): get rid of cluster_resolver argument when contrib's
  # version no longer depends on this class.
  def _initialize_local(self,
                        compute_devices,
                        parameter_device,
                        cluster_resolver=None):
    """Initialize local devices for training."""
    self._worker_device = device_util.canonicalize("/device:CPU:0")
    self._input_host_device = numpy_dataset.SingleDevice(self._worker_device)

    if compute_devices is None:
      if not cluster_resolver:
        num_gpus = context.num_gpus()
      else:
        num_gpus = cluster_resolver.num_accelerators().get("GPU", 0)
      # Save the num_gpus_per_worker for configure method which is used by the
      # contrib version.
      self._num_gpus_per_worker = num_gpus

      compute_devices = device_util.local_devices_from_num_gpus(num_gpus)

    compute_devices = [device_util.canonicalize(d) for d in compute_devices]

    if parameter_device is None:
      # If there is only one GPU, put everything on that GPU. Otherwise, place
      # variables on CPU.
      if len(compute_devices) == 1:
        parameter_device = compute_devices[0]
      else:
        parameter_device = _LOCAL_CPU

    self._variable_device = parameter_device
    self._compute_devices = compute_devices
    self._parameter_devices = (parameter_device,)
    self._is_chief = True
    self._cluster_spec = None
    self._task_type = None
    self._task_id = None

    logging.info(
        "ParameterServerStrategy (CentralStorageStrategy if you are using a "
        "single machine) with compute_devices = %r, variable_device = %r",
        compute_devices, self._variable_device)

  def _input_workers_with_options(self, options=None):
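    # If prefetching to devices is not explicitly disabled via
    # `experimental_fetch_to_device`, input is associated with the per-replica
    # compute devices; otherwise it stays on the worker's host device.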
    if not options or options.experimental_fetch_to_device:
      return input_lib.InputWorkers(
          [(self._worker_device, self._compute_devices)])
    else:
      return input_lib.InputWorkers(
          [(self._worker_device,
            (self._worker_device,) * len(self._compute_devices))])

  @property
  def _input_workers(self):
    return self._input_workers_with_options()

  def _validate_colocate_with_variable(self, colocate_with_variable):
    distribute_utils.validate_colocate(colocate_with_variable, self)

  def _experimental_distribute_dataset(self, dataset, options):
    return input_lib.get_distributed_dataset(
        dataset,
        self._input_workers_with_options(options),
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync,
        options=options)

  def _make_dataset_iterator(self, dataset):
    return input_lib.DatasetIterator(
        dataset,
        self._input_workers,
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync)

  def _make_input_fn_iterator(
      self,
      input_fn,
      replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
    """Distributes the dataset to each local GPU."""
    if self._cluster_spec:
      input_pipeline_id = multi_worker_util.id_in_cluster(
          self._cluster_spec, self._task_type, self._task_id)
      num_input_pipelines = multi_worker_util.worker_count(
          self._cluster_spec, self._task_type)
    else:
      input_pipeline_id = 0
      num_input_pipelines = 1
    input_context = distribute_lib.InputContext(
        num_input_pipelines=num_input_pipelines,
        input_pipeline_id=input_pipeline_id,
        num_replicas_in_sync=self._num_replicas_in_sync)
    return input_lib.InputFunctionIterator(input_fn, self._input_workers,
                                           [input_context],
                                           self._container_strategy())

  def _experimental_make_numpy_dataset(self, numpy_input, session):
    return numpy_dataset.one_host_numpy_dataset(
        numpy_input, self._input_host_device, session)

  def _distribute_datasets_from_function(self, dataset_fn, options):
    if self._cluster_spec:
      input_pipeline_id = multi_worker_util.id_in_cluster(
          self._cluster_spec, self._task_type, self._task_id)
      num_input_pipelines = multi_worker_util.worker_count(
          self._cluster_spec, self._task_type)
    else:
      input_pipeline_id = 0
      num_input_pipelines = 1

    input_context = distribute_lib.InputContext(
        num_input_pipelines=num_input_pipelines,
        input_pipeline_id=input_pipeline_id,
        num_replicas_in_sync=self._num_replicas_in_sync)

    return input_lib.get_distributed_datasets_from_function(
        dataset_fn,
        self._input_workers_with_options(options), [input_context],
        self._container_strategy(),
        options=options)

  def _experimental_distribute_values_from_function(self, value_fn):
    per_replica_values = []
    for replica_id in range(self._num_replicas_in_sync):
      per_replica_values.append(
          value_fn(distribute_lib.ValueContext(replica_id,
                                               self._num_replicas_in_sync)))
    return distribute_utils.regroup(per_replica_values, always_wrap=True)

  def _broadcast_to(self, tensor, destinations):
    # This is both a fast path for Python constants, and a way to delay
    # converting Python values to a tensor until we know what type it
    # should be converted to. Otherwise we have trouble with:
    #   global_step.assign_add(1)
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
    if not cross_device_ops_lib.check_destinations(destinations):
      # TODO(josh11b): Use current logical device instead of 0 here.
      destinations = self._compute_devices
    return self._cross_device_ops.broadcast(tensor, destinations)

  def _allow_variable_partition(self):
    return not context.executing_eagerly()

  def _create_var_creator(self, next_creator, **kwargs):
    if self._num_replicas_in_sync > 1:
      aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
      if aggregation not in (
          vs.VariableAggregation.NONE,
          vs.VariableAggregation.SUM,
          vs.VariableAggregation.MEAN,
          vs.VariableAggregation.ONLY_FIRST_REPLICA
      ):
        raise ValueError(
            "Invalid variable aggregation mode: %s for variable: %s" %
            (aggregation, kwargs["name"]))

      def var_creator(**kwargs):
        """Create an AggregatingVariable and fix up collections."""
        # Record what collections this variable should be added to.
        collections = kwargs.pop("collections", None)
        if collections is None:
          collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        kwargs["collections"] = []

        # Create and wrap the variable.
        v = next_creator(**kwargs)
        wrapped = ps_values.AggregatingVariable(self._container_strategy(), v,
                                                aggregation)

        # Add the wrapped variable to the requested collections.
        # The handling of eager mode and the global step matches
        # ResourceVariable._init_from_args().
        if not context.executing_eagerly():
          g = ops.get_default_graph()
          # If "trainable" is True, next_creator() will add the contained
          # variable to the TRAINABLE_VARIABLES collection, so we manually
          # remove it and replace with the wrapper. We can't set "trainable"
          # to False for next_creator() since that causes functions like
          # implicit_gradients to skip those variables.
          if kwargs.get("trainable", True):
            collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            if v in l:
              l.remove(v)
          g.add_to_collections(collections, wrapped)
        elif ops.GraphKeys.GLOBAL_STEP in collections:
          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped)

        return wrapped
      return var_creator
    else:
      return next_creator

  # TODO(yuefengz): Not all ops in device_setter.STANDARD_PS_OPS will go through
  # this creator, such as "MutableHashTable".
  def _create_variable(self, next_creator, **kwargs):
    var_creator = self._create_var_creator(next_creator, **kwargs)

    if "colocate_with" in kwargs:
      colocate_with = kwargs["colocate_with"]
      if isinstance(colocate_with, numpy_dataset.SingleDevice):
        with ops.device(colocate_with.device):
          return var_creator(**kwargs)
      with ops.device(None):
        with ops.colocate_with(colocate_with):
          return var_creator(**kwargs)

    with ops.colocate_with(None, ignore_existing=True):
      with ops.device(self._variable_device):
        return var_creator(**kwargs)

  def _call_for_each_replica(self, fn, args, kwargs):
    return mirrored_run.call_for_each_replica(self._container_strategy(), fn,
                                              args, kwargs)

  def _verify_destinations_not_different_worker(self, destinations):
    if not self._cluster_spec:
      return
    if destinations is None:
      return
    for d in cross_device_ops_lib.get_devices_from(destinations):
      d_spec = tf_device.DeviceSpec.from_string(d)
      if d_spec.job == self._task_type and d_spec.task != self._task_id:
        raise ValueError(
            "Cannot reduce to another worker: %r, current worker is %r" %
            (d, self._worker_device))

  def _gather_to_implementation(self, value, destinations, axis,
                                options):
    self._verify_destinations_not_different_worker(destinations)
    if not isinstance(value, values.DistributedValues):
      return value
    return self._cross_device_ops._gather(  # pylint: disable=protected-access
        value,
        destinations=destinations,
        axis=axis,
        options=options)

  def _reduce_to(self, reduce_op, value, destinations, options):
    self._verify_destinations_not_different_worker(destinations)
    if not isinstance(value, values.DistributedValues):
      # pylint: disable=protected-access
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, value, destinations, self._num_replicas_in_sync)
    return self._cross_device_ops.reduce(
        reduce_op, value, destinations=destinations, options=options)

  def _batch_reduce_to(self, reduce_op, value_destination_pairs, options):
    for _, destinations in value_destination_pairs:
      self._verify_destinations_not_different_worker(destinations)
    return self._cross_device_ops.batch_reduce(reduce_op,
                                               value_destination_pairs, options)

  def _select_single_value(self, structured):
    """Select any single value in `structured`."""

    def _select_fn(x):  # pylint: disable=g-missing-docstring
      if isinstance(x, (values.Mirrored, values.PerReplica)):
        return x._primary  # pylint: disable=protected-access
      else:
        return x

    return nest.map_structure(_select_fn, structured)

  def _update(self, var, fn, args, kwargs, group):
    if isinstance(var, ps_values.AggregatingVariable):
      var = var.get()
    if not resource_variable_ops.is_resource_variable(var):
      raise ValueError(
          "You cannot update `var` %r. It must be a Variable." % var)
    with ops.colocate_with(var), distribute_lib.UpdateContext(var.device):
      result = fn(var, *self._select_single_value(args),
                  **self._select_single_value(kwargs))
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  # TODO(yuefengz): does it need to call _select_single_value?
  def _update_non_slot(self, colocate_with, fn, args, kwargs, group):
    with ops.device(
        colocate_with.device), distribute_lib.UpdateContext(colocate_with):
      result = fn(*args, **kwargs)
      if group:
        return result
      else:
        return nest.map_structure(self._local_results, result)

  def value_container(self, val):
    if (hasattr(val, "_aggregating_container") and
        not isinstance(val, ps_values.AggregatingVariable)):
      wrapper = val._aggregating_container()  # pylint: disable=protected-access
      if wrapper is not None:
        return wrapper
    return val

  def read_var(self, var):
    # No need to distinguish between normal variables and replica-local
    # variables.
    return array_ops.identity(var)

  def _configure(self,
                 session_config=None,
                 cluster_spec=None,
                 task_type=None,
                 task_id=None):
    """Configures the strategy class with `cluster_spec`.

    The strategy object will be re-initialized if `cluster_spec` is passed to
    `configure` but was not passed when instantiating the strategy.

    Args:
      session_config: Session config object.
      cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
        cluster configurations.
      task_type: the current task type.
      task_id: the current task id.

    Raises:
      ValueError: if `cluster_spec` is given but `task_type` or `task_id` is
        not.
    """
    if cluster_spec:
      # Use the num_gpus_per_worker recorded in constructor since _configure
      # doesn't take num_gpus.
      cluster_resolver = SimpleClusterResolver(
          cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
          task_type=task_type,
          task_id=task_id,
          num_accelerators={"GPU": self._num_gpus_per_worker})
      self._initialize_multi_worker(cluster_resolver)

    if session_config:
      session_config.CopyFrom(self._update_config_proto(session_config))

  def _update_config_proto(self, config_proto):
    updated_config = copy.deepcopy(config_proto)
    if not self._cluster_spec:
      updated_config.isolate_session_state = True
      return updated_config

    updated_config.isolate_session_state = False

    assert self._task_type
    assert self._task_id is not None

    # The device filters prevent communication between workers.
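    # For example, a worker with task_id 1 ends up with the device filters
    # ["/job:worker/task:1", "/job:ps"], so it can only see itself and the ps
    # jobs.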
    del updated_config.device_filters[:]
    if self._task_type in ["chief", "worker"]:
      updated_config.device_filters.extend(
          ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"])
    elif self._task_type == "evaluator":
      updated_config.device_filters.append(
          "/job:%s/task:%d" % (self._task_type, self._task_id))
    return updated_config

  def _in_multi_worker_mode(self):
    """Whether this strategy indicates working in multi-worker settings."""
    return self._cluster_spec is not None

  @property
  def _num_replicas_in_sync(self):
    return len(self._compute_devices)

  @property
  def worker_devices(self):
    return self._compute_devices

  @property
  def worker_devices_by_replica(self):
    return [[d] for d in self._compute_devices]

  @property
  def parameter_devices(self):
    return self._parameter_devices

  def non_slot_devices(self, var_list):
    return min(var_list, key=lambda x: x.name)

  @property
  def experimental_between_graph(self):
    # TODO(yuefengz): Should this return False in the local case?
    return True

  @property
  def experimental_should_init(self):
    return self._is_chief

  @property
  def should_checkpoint(self):
    return self._is_chief

  @property
  def should_save_summary(self):
    return self._is_chief

  # TODO(priyag): Delete this once all strategies use global batch size.
  @property
  def _global_batch_size(self):
    """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.

    `make_input_fn_iterator` assumes per-replica batching.

    Returns:
      Boolean.
    """
    return True

  def _get_local_replica_id(self, replica_id_in_sync_group):
    return replica_id_in_sync_group

  def _get_replica_id_in_sync_group(self, replica_id):
    return replica_id
