# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Parameter server strategy V2 class.

This is currently under development and the API is subject to change.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import mirrored_run
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute import values
from tensorflow.python.eager import remote
from tensorflow.python.framework import config
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export

ALLOWED_TASK_TYPES = ("chief", "worker", "ps")


@tf_export("distribute.experimental.ParameterServerStrategy", v1=[])
class ParameterServerStrategyV2(distribute_lib.Strategy):
  """A multi-worker tf.distribute strategy with parameter servers.

  Parameter server training is a common data-parallel method to scale up a
  machine learning model on multiple machines. A parameter server training
  cluster consists of workers and parameter servers. Variables are created on
  parameter servers and they are read and updated by workers in each step.
  By default, workers read and update these variables independently without
  synchronizing with each other. This configuration is known as asynchronous
  training.

  In TensorFlow 2, we recommend an architecture based on central coordination
  for parameter server training. Each worker and parameter server runs a
  `tf.distribute.Server`, and on top of that, a coordinator task is responsible
  for creating resources on workers and parameter servers, dispatching
  functions, and coordinating the training. The coordinator uses a
  `tf.distribute.experimental.coordinator.ClusterCoordinator` to coordinate the
  cluster, and a `tf.distribute.experimental.ParameterServerStrategy` to define
  variables on parameter servers and computation on workers.

  For the training to work, the coordinator dispatches `tf.function`s to be
  executed on remote workers. Upon receiving requests from the coordinator, a
  worker executes the `tf.function` by reading the variables from parameter
  servers, executing the ops, and updating the variables on the parameter
  servers. Each worker only processes the requests from the coordinator, and
  communicates with parameter servers, without direct interactions with other
  workers in the cluster.

  As a result, failures of some workers do not prevent the cluster from
  continuing the work, and this allows the cluster to train with instances that
  can be occasionally unavailable (e.g. preemptible or spot instances). The
  coordinator and parameter servers, though, must be available at all times for
  the cluster to make progress.

  Note that the coordinator is not one of the training workers. Instead, it
  creates resources such as variables and datasets, dispatches `tf.function`s,
  saves checkpoints and so on. In addition to workers, parameter servers and
  the coordinator, an optional evaluator can be run on the side that
  periodically reads the checkpoints saved by the coordinator and runs
  evaluations against each checkpoint.
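
  Such a sidecar evaluator can, for example, be a separate program that keeps
  restoring the latest checkpoint and evaluating it. The following is a minimal
  sketch of that loop; `eval_model`, `eval_dataset`, and `checkpoint_dir` are
  assumed placeholders defined elsewhere:

  ```python
  checkpoint = tf.train.Checkpoint(model=eval_model)
  eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

  # Blocks until a checkpoint newer than the last seen one appears.
  for latest_checkpoint in tf.train.checkpoints_iterator(checkpoint_dir):
    checkpoint.restore(latest_checkpoint).expect_partial()
    eval_accuracy.reset_states()
    for batch_data, labels in eval_dataset:
      predictions = eval_model(batch_data, training=False)
      eval_accuracy.update_state(labels, predictions)
    print("Eval accuracy after %s: %f" %
          (latest_checkpoint, eval_accuracy.result()))
  ```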

  `ParameterServerStrategy` is supported with two training APIs: [Custom
  Training Loop (CTL)]
  (https://www.tensorflow.org/tutorials/distribute/custom_training)
  and [Keras Training API, also known as `Model.fit`]
  (https://www.tensorflow.org/tutorials/distribute/keras). CTL is recommended
  when users prefer to define the details of their training loop, and
  `Model.fit` is recommended when users prefer a high-level abstraction and
  handling of training.

  When using a CTL, `ParameterServerStrategy` has to work in conjunction with a
  `tf.distribute.experimental.coordinator.ClusterCoordinator` object.

  When using `Model.fit`, currently only the
  `tf.keras.utils.experimental.DatasetCreator` input type is supported.

  __Example code for coordinator__

  This section provides code snippets that are intended to be run on the single
  task that is designated as the coordinator. Note that `cluster_resolver`,
  `variable_partitioner`, and `dataset_fn` arguments are explained in the
  following "Cluster setup", "Variable partitioning", and "Dataset preparation"
  sections.

  With a CTL,

  ```python
  # Prepare a strategy to use with the cluster and variable partitioning info.
  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver=...,
      variable_partitioner=...)
  coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
      strategy=strategy)

  # Prepare a distributed dataset that will place datasets on the workers.
  distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn=...)

  with strategy.scope():
    model = ...
    optimizer, metrics = ...  # Keras optimizer/metrics are great choices
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint, checkpoint_dir, max_to_keep=2)
    # `load_checkpoint` infers initial epoch from `optimizer.iterations`.
    initial_epoch = load_checkpoint(checkpoint_manager) or 0

  @tf.function
  def worker_fn(iterator):

    def replica_fn(inputs):
      batch_data, labels = inputs
      # Calculate gradients, apply gradients, update metrics, etc.

    strategy.run(replica_fn, args=(next(iterator),))

  for epoch in range(initial_epoch, num_epoch):
    distributed_iterator = iter(distributed_dataset)  # Reset iterator state.
    for step in range(steps_per_epoch):

      # Asynchronously schedule the `worker_fn` to be executed on an arbitrary
      # worker. This call returns immediately.
      coordinator.schedule(worker_fn, args=(distributed_iterator,))

    # `join` blocks until all scheduled `worker_fn`s finish execution. Once it
    # returns, we can read the metrics and save checkpoints as needed.
    coordinator.join()
    logging.info('Metric result: %r', metrics.result())
    metrics.reset_states()
    checkpoint_manager.save()
  ```

  With `Model.fit`,

  ```python
  # Prepare a strategy to use with the cluster and variable partitioning info.
  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver=...,
      variable_partitioner=...)

  # A dataset function takes an `input_context` and returns a `Dataset`.
  def dataset_fn(input_context):
    dataset = tf.data.Dataset.from_tensors(...)
    return dataset.repeat().shard(...).batch(...).prefetch(...)

  # With `Model.fit`, a `DatasetCreator` needs to be used.
  input = tf.keras.utils.experimental.DatasetCreator(dataset_fn=...)

  with strategy.scope():
    model = ...  # Make sure the `Model` is created within scope.
  model.compile(optimizer="rmsprop", loss="mse", steps_per_execution=..., ...)

  # Optional callbacks to checkpoint the model, back up the progress, etc.
  callbacks = [tf.keras.callbacks.ModelCheckpoint(...), ...]

  # `steps_per_epoch` is required with `ParameterServerStrategy`.
  model.fit(input, epochs=..., steps_per_epoch=..., callbacks=callbacks)
  ```

  __Example code for worker and parameter servers__

  In addition to the coordinator, there should be tasks designated as
  "worker" or "ps". They should run the following code to start a TensorFlow
  server and wait for the coordinator's requests:

  ```python
  # Provide a `tf.distribute.cluster_resolver.ClusterResolver` that serves
  # the cluster information. See the "Cluster setup" section below.
  cluster_resolver = ...

  server = tf.distribute.Server(
      cluster_resolver.cluster_spec(),
      job_name=cluster_resolver.task_type,
      task_index=cluster_resolver.task_id,
      protocol="grpc")

  # Block the process that starts the server from exiting.
  server.join()
  ```

  __Cluster setup__

  In order for the tasks in the cluster to know other tasks' addresses,
  a `tf.distribute.cluster_resolver.ClusterResolver` must be used by the
  coordinator, workers, and parameter servers. The
  `tf.distribute.cluster_resolver.ClusterResolver` is responsible for providing
  the cluster information, as well as the task type and id of the current task.
  See `tf.distribute.cluster_resolver.ClusterResolver` for more information.

  If the `TF_CONFIG` environment variable is set, a
  `tf.distribute.cluster_resolver.TFConfigClusterResolver` should be used as
  well.

  Since there are assumptions in
  `tf.distribute.experimental.ParameterServerStrategy` around the naming of the
  task types, "chief", "ps", and "worker" should be used in the
  `tf.distribute.cluster_resolver.ClusterResolver` to refer to the coordinator,
  parameter servers, and workers, respectively.

  The following example demonstrates setting `TF_CONFIG` for the task designated
  as a parameter server (task type "ps") and index 1 (the second task), in a
  cluster with 1 chief, 2 parameter servers, and 3 workers. Note that it needs
  to be set before the use of
  `tf.distribute.cluster_resolver.TFConfigClusterResolver`.

  Example code for cluster setup:
  ```python
  os.environ['TF_CONFIG'] = '''
  {
    "cluster": {
      "chief": ["chief.example.com:2222"],
      "ps": ["ps0.example.com:2222", "ps1.example.com:2222"],
      "worker": ["worker0.example.com:2222", "worker1.example.com:2222",
                 "worker2.example.com:2222"]
    },
    "task": {
      "type": "ps",
      "index": 1
    }
  }
  '''
  ```
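
  With `TF_CONFIG` set as above, each task can then construct a
  `tf.distribute.cluster_resolver.TFConfigClusterResolver` to read the cluster
  and its own role. A minimal sketch:

  ```python
  cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
  assert cluster_resolver.task_type == "ps"
  assert cluster_resolver.task_id == 1
  ```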

  If you prefer to run the same binary for all tasks, you will need to let the
  binary branch into different roles at the beginning of the program:
  ```python
  # If coordinator, create a strategy and start the training program.
  if cluster_resolver.task_type == 'chief':
    strategy = tf.distribute.experimental.ParameterServerStrategy(
        cluster_resolver)
    ...

  # If worker/ps, create a server.
  elif cluster_resolver.task_type in ("worker", "ps"):
    server = tf.distribute.Server(...)
    ...
  ```
  Alternatively, you can also start the TensorFlow servers in advance and
  connect to them later. The coordinator can be in the same cluster or on any
  machine that has connectivity to workers and parameter servers. This is
  covered in our guide and tutorial.

  __Variable creation with `strategy.scope()`__

  `tf.distribute.experimental.ParameterServerStrategy` follows the
  `tf.distribute` API contract where variable creation is expected to be inside
  the context manager returned by `strategy.scope()`, in order to be correctly
  placed on parameter servers in a round-robin manner:

  ```python
  # This example assumes a cluster with 3 parameter servers.
  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver=...)
  coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
      strategy=strategy)

  # Variables should be created inside scope to be placed on parameter servers.
  # If created outside scope such as `v1` here, it would be placed on the
  # coordinator.
  v1 = tf.Variable(initial_value=0.0)

  with strategy.scope():
    v2 = tf.Variable(initial_value=1.0)
    v3 = tf.Variable(initial_value=2.0)
    v4 = tf.Variable(initial_value=3.0)
    v5 = tf.Variable(initial_value=4.0)

  # v2 through v5 are created in scope and are distributed on parameter servers.
  # Default placement is round-robin but the order should not be relied on.
  assert v2.device == "/job:ps/replica:0/task:0/device:CPU:0"
  assert v3.device == "/job:ps/replica:0/task:1/device:CPU:0"
  assert v4.device == "/job:ps/replica:0/task:2/device:CPU:0"
  assert v5.device == "/job:ps/replica:0/task:0/device:CPU:0"
  ```

  See `tf.distribute.Strategy.scope` for more information.

  __Variable partitioning__

  Having dedicated servers to store variables means being able to divide up, or
  "shard", the variables across the parameter servers. Partitioning a large
  variable among parameter servers is a commonly used technique to boost
  training throughput and mitigate memory constraints. It enables parallel
  computations and updates on different shards of a variable, and often yields
  better load balancing across parameter servers. Without sharding, models with
  large variables (e.g., embeddings) that can't fit into one machine's memory
  would otherwise be unable to train.

  With `tf.distribute.experimental.ParameterServerStrategy`, if a
  `variable_partitioner` is provided to `__init__` and certain conditions are
  satisfied, the resulting variables created in scope are sharded across the
  parameter servers, in a round-robin fashion. The variable reference returned
  from `tf.Variable` becomes a type that serves as the container of the sharded
  variables. One can access the `variables` attribute of this container for the
  actual variable components. If building a model with `tf.Module` or Keras,
  the variable components are collected in the `variables`-like attributes.

  It is recommended to use size-based partitioners like
  `tf.distribute.experimental.partitioners.MinSizePartitioner` to avoid
  partitioning small variables, which could have a negative impact on model
  training speed.

  ```python
  # Partition the embedding layer into 2 shards.
  variable_partitioner = (
    tf.distribute.experimental.partitioners.MinSizePartitioner(
      min_shard_bytes=(256 << 10),
      max_shards=2))
  strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...,
    variable_partitioner=variable_partitioner)
  with strategy.scope():
    embedding = tf.keras.layers.Embedding(input_dim=1024, output_dim=1024)
  assert len(embedding.variables) == 2
  assert isinstance(embedding.variables[0], tf.Variable)
  assert isinstance(embedding.variables[1], tf.Variable)
  assert embedding.variables[0].shape == (512, 1024)
  assert embedding.variables[1].shape == (512, 1024)
  ```

  The sharded variable container can be converted to a `Tensor` via
  `tf.convert_to_tensor`. This means the container can be directly used in most
  Python Ops where such `Tensor` conversion automatically happens. For example,
  using the container in an expression such as `x * w` would implicitly apply
  the said tensor conversion. Note that such conversion can be expensive, as the
  variable components need to be transferred from multiple parameter servers to
  where the value is used.
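
  The conversion can also be triggered explicitly. This is a minimal sketch,
  assuming `strategy` was created with a `variable_partitioner` as in the
  example above:

  ```python
  with strategy.scope():
    w = tf.Variable(initial_value=tf.random.uniform(shape=(1024, 1024)))

  # `w` may be a sharded container; the conversion gathers all of its
  # components from the parameter servers into one dense tensor.
  w_dense = tf.convert_to_tensor(w)
  assert w_dense.shape == (1024, 1024)
  ```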

  `tf.nn.embedding_lookup`, on the other hand, doesn't apply the tensor
  conversion, and performs parallel lookups on the variable components instead.
  This is crucial to scale up embedding lookups when the embedding table
  variable is large.

  When a partitioned variable is saved to a `SavedModel`, it will be saved as if
  it were one single variable. This improves serving efficiency by eliminating
  a number of Ops that handle the partition aspects.

  Known limitations of variable partitioning:

  * The number of partitions must not change across checkpoint saving/loading.

  * After saving partitioned variables to a SavedModel, the SavedModel can't be
    loaded via `tf.saved_model.load`.

  * Partitioned variables don't directly work with `tf.GradientTape`; please use
    the `variables` attribute to get the actual variable components, and use
    them in gradient APIs instead, as sketched below.
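
  A minimal sketch of the `tf.GradientTape` workaround mentioned above,
  assuming the 2-shard `embedding` from the partitioning example:

  ```python
  with tf.GradientTape() as tape:
    ids = tf.constant([[3, 5, 7]])
    loss = tf.reduce_sum(embedding(ids))

  # Compute gradients against the actual variable components, not the
  # sharded container itself.
  gradients = tape.gradient(loss, embedding.variables)
  assert len(gradients) == len(embedding.variables) == 2
  ```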

  __Dataset preparation__

  With `tf.distribute.experimental.ParameterServerStrategy`, a dataset is
  created in each of the workers to be used for training. This is done by
  creating a `dataset_fn` that takes no argument and returns a
  `tf.data.Dataset`, and passing the `dataset_fn` into
  `tf.distribute.experimental.coordinator.
  ClusterCoordinator.create_per_worker_dataset`. We recommend that the dataset
  be shuffled and repeated so that examples run through the training as evenly
  as possible.

  ```python
  def dataset_fn():
    filenames = ...
    dataset = tf.data.Dataset.from_tensor_slices(filenames)

    # The dataset is recommended to be shuffled and repeated.
    return dataset.shuffle(buffer_size=...).repeat().batch(batch_size=...)

  coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
      strategy=...)
  distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
  ```

  __Limitations__

  * `tf.distribute.experimental.ParameterServerStrategy` in TF2 is experimental,
  and the API is subject to further changes.

  * When using `Model.fit`, `tf.distribute.experimental.ParameterServerStrategy`
  must be used with a `tf.keras.utils.experimental.DatasetCreator`, and
  `steps_per_epoch` must be specified.
  """

  # pyformat: disable
  def __init__(self, cluster_resolver, variable_partitioner=None):
    """Initializes the TF2 parameter server strategy.

    This initializes the `tf.distribute.experimental.ParameterServerStrategy`
    object to be ready for use with
    `tf.distribute.experimental.coordinator.ClusterCoordinator`.

    Args:
      cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`
        object.
      variable_partitioner:
        a `tf.distribute.experimental.partitioners.Partitioner` that specifies
        how to partition variables. If `None`, variables will not be
        partitioned.

        * Predefined partitioners in `tf.distribute.experimental.partitioners`
        can be used for this argument. A commonly used partitioner is
        `MinSizePartitioner(min_shard_bytes = 256 << 10, max_shards = num_ps)`,
        which allocates at least 256K per shard, and each ps gets at most one
        shard.

        * `variable_partitioner` will be called for each variable created under
        strategy `scope` to instruct how the variable should be partitioned.
        Variables that have only one partition along the partitioning axis
        (i.e., no need for partition) will be created as a normal `tf.Variable`.

        * Only the first / outermost axis partitioning is supported.

        * Div partition strategy is used to partition variables. Assuming we
        assign consecutive integer ids along the first axis of a variable, then
        ids are assigned to shards in a contiguous manner, while attempting to
        keep each shard size identical. If the ids do not divide evenly among
        the shards, each of the first several shards is assigned one more id.
        For instance, a variable whose first dimension is 13 has 13 ids, and
        they are split across 5 shards as:
        `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

        * Variables created under `strategy.extended.colocate_vars_with` will
        not be partitioned.
    """
    # pyformat: enable
    self._cluster_resolver = cluster_resolver

    self._verify_args_and_config(cluster_resolver)
    self._cluster_coordinator = None
    logging.info(
        "`tf.distribute.experimental.ParameterServerStrategy` is initialized "
        "with cluster_spec: %s", cluster_resolver.cluster_spec())

    # TODO(b/167894802): Make coordinator, worker, and ps names customizable.
    self._connect_to_cluster(coordinator_name="chief")
    self._extended = ParameterServerStrategyV2Extended(self, cluster_resolver,
                                                       variable_partitioner)
    super(ParameterServerStrategyV2, self).__init__(self._extended)
    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
        "ParameterServerStrategy")
    self._should_use_with_coordinator = True
    # Used while constructing distributed iterators.
    self._canonicalize_devices = False

  def _connect_to_cluster(self, coordinator_name):
    if coordinator_name in ["worker", "ps"]:
      raise ValueError("coordinator name should not be 'worker' or 'ps'.")
    cluster_spec = self._cluster_resolver.cluster_spec()
    self._num_workers = len(cluster_spec.as_dict().get("worker", ()))
    self._num_ps = len(cluster_spec.as_dict().get("ps", ()))

    device_filters = server_lib.ClusterDeviceFilters()
    # For any worker, only the devices on ps and coordinator nodes are visible.
    for i in range(self._num_workers):
      device_filters.set_device_filters(
          "worker", i, ["/job:ps", "/job:%s" % coordinator_name])
    # Similarly for any ps, only the devices on workers and the coordinator are
    # visible.
    for i in range(self._num_ps):
      device_filters.set_device_filters(
          "ps", i, ["/job:worker", "/job:%s" % coordinator_name])

    # Allow at most one outstanding RPC for each worker at a certain time. This
    # is to simplify worker failure handling in the runtime.
    os.environ["TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"] = "False"

    logging.info("%s is now connecting to cluster with cluster_spec: %r",
                 self.__class__.__name__, cluster_spec)
    remote.connect_to_cluster(
        cluster_spec,
        job_name=coordinator_name,
        protocol=self._cluster_resolver.rpc_layer,
        cluster_device_filters=device_filters)

    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "ps_strategy_num_workers").set(self._num_workers)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "ps_strategy_num_ps").set(self._num_ps)

  def _verify_args_and_config(self, cluster_resolver):
    if not cluster_resolver.cluster_spec():
      raise ValueError("Cluster spec must be non-empty in "
                       "`tf.distribute.cluster_resolver.ClusterResolver`.")
    cluster_spec = cluster_resolver.cluster_spec()

    # The following checks that the task types are allowed (chief, ps, worker).
    multi_worker_util._validate_cluster_spec(  # pylint: disable=protected-access
        cluster_spec,
        cluster_resolver.task_type,
        cluster_resolver.task_id)

    if multi_worker_util.task_count(cluster_spec, "ps") < 1:
      raise ValueError("There must be at least one ps.")

    if multi_worker_util.task_count(cluster_spec, "worker") < 1:
      raise ValueError("There must be at least one worker.")


class ParameterServerStrategyV2Extended(
    parameter_server_strategy.ParameterServerStrategyExtended):
  """Extended class for ParameterServerStrategyV2.

  Please see `tf.distribute.StrategyExtended` doc for more information.
  """

  def __init__(self, container_strategy, cluster_resolver,
               variable_partitioner):
    """Initialization of ParameterServerStrategyV2Extended."""
    super(ParameterServerStrategyV2Extended, self).__init__(container_strategy)
    self._num_ps = len(cluster_resolver.cluster_spec().as_dict().get("ps", []))
    self._num_workers = len(cluster_resolver.cluster_spec().as_dict().get(
        "worker", []))
    self._variable_count = 0

    self._variable_partitioner = variable_partitioner
    # The following two attrs are to verify that `ParameterServerStrategy`
    # methods are properly used with a `ClusterCoordinator`.
    self._used_with_coordinator = False
    self._being_scheduled = False
    self._set_num_gpus()
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "num_gpus_per_worker").set(self._num_gpus_per_worker)

    # Don't canonicalize the devices here since this code is executed on the
    # chief, but we want the reduce evaluation to be done on each worker.
    # Placer will automatically choose the right device based on current
    # context.
    # TODO(ishark): Use select_cross_device_ops instead.
    self._cross_device_ops = cross_device_ops_lib.ReductionToOneDevice(
        reduce_to_device="/device:CPU:0")
    self._cross_device_ops._canonicalize_devices = False  # pylint: disable=protected-access
    self._allow_run_without_coordinator = False

  def _set_num_gpus(self):
    devices = config.list_logical_devices("GPU")
    per_worker_gpus = {}
    for d in devices:
      d_spec = tf_device.DeviceSpec.from_string(d.name)
      if d_spec.device_type == "GPU" and d_spec.job == "worker":
        # TODO(b/167894802): update if worker name is customizable
        job_spec = d_spec.replace(device_type=None, device_index=None)
        per_worker_gpus[job_spec] = per_worker_gpus.get(job_spec, 0) + 1

    num_gpus = 0
    for _, count in per_worker_gpus.items():
      if num_gpus > 0 and count != num_gpus:
        raise ValueError("Mismatched number of GPUs per worker")
      num_gpus = count

    self._num_gpus_per_worker = num_gpus
    logging.info(f"Number of GPUs on workers: {self._num_gpus_per_worker}")

  @property
  def _num_replicas_in_sync(self):
    return self._num_gpus_per_worker or 1

  def _create_var_creator(self, next_creator, **kwargs):
    aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)

    def var_creator(**kwargs):
      """Create an AggregatingVariable."""
      # Create and wrap the variable.
      v = next_creator(**kwargs)
      wrapped_v = ps_values.CachingVariable(v)
      wrapped = ps_values.AggregatingVariable(self._container_strategy(),
                                              wrapped_v, aggregation)
      return wrapped

    if self._num_replicas_in_sync > 1:
      if aggregation not in (
          vs.VariableAggregation.NONE,
          vs.VariableAggregation.SUM,
          vs.VariableAggregation.MEAN,
          vs.VariableAggregation.ONLY_FIRST_REPLICA
      ):
        raise ValueError(
            "Invalid variable aggregation mode: %s for variable: %s" %
            (aggregation, kwargs["name"]))
      return var_creator
    else:
      def variable_creator_single_replica(**kwargs):
        v = next_creator(**kwargs)
        return ps_values.CachingVariable(v)
      return variable_creator_single_replica

  def _create_variable(self, next_creator, **kwargs):
    """Implements StrategyExtendedV2._create_variable.

    Creates a `Variable` or a `ShardedVariable`. A `ShardedVariable` will be
    created if satisfying all the following criteria:
      1. `self._variable_partitioner` results in more than one partition on the
         first axis.
      2. variable's rank is greater than 0.
      3. variable is not colocated with another variable.
    Otherwise a `Variable` will be created.

    Args:
      next_creator: See `variable_scope.variable_creator_scope`; the next
        creator in the chain.
      **kwargs: Passed through to the next creator.

    Returns:
      A `Variable` or `ShardedVariable`.
    """

    var_creator = self._create_var_creator(next_creator, **kwargs)
    if "colocate_with" in kwargs:  # Never partition colocated_with variables.
      colocate_with = kwargs["colocate_with"]
      # Clear the variable scope to avoid possible conflicts between device
      # scope and colocation scope.
      with ops.device(None):
        with ops.colocate_with(colocate_with):
          var = var_creator(**kwargs)
          logging.debug(
              "Creating variable (name:%s, shape:%r) that colocates with %s",
              var.name, var.shape, kwargs["colocate_with"].name)
          return var

    if self._variable_partitioner is None:
      return self._create_variable_round_robin(var_creator, **kwargs)

    name = kwargs.get("name", None)
    initial_value = kwargs.get("initial_value", None)
    if initial_value is None:
      raise ValueError(
          "It looks like you are using `ParameterServerStrategy` with a "
          "`variable_partitioner`, and trying to create a variable without "
          "specifying `initial_value`. This is not allowed. Please specify the "
          "`initial_value`. This can also happen if you are trying to load a "
          "saved_model within a `ParameterServerStrategy` scope. Loading a "
          "saved_model with `variable_partitioner` is not supported.")

    # Two cases where initial_value can be a callable:
    #   1. initial_value is passed as a callable, e.g., an `initializer` class.
    #   2. restoring from checkpoint, initial_value is a
    #     "CheckpointInitialValueCallable".
    init_from_fn = callable(initial_value)

    dtype = kwargs.get("dtype", None)
    shape = kwargs.get("shape", None)
    if init_from_fn and (shape is None or dtype is None):
      init_from_fn = False
      initial_value = initial_value()
    if not init_from_fn:
      # The initial_value is created on the coordinator and needs to be sent to
      # the ps for variable initialization, which can be inefficient and can
      # potentially hit the 2GB limit on protobuf serialization.
      initial_value = ops.convert_to_tensor(initial_value, dtype=dtype)
      dtype = initial_value.dtype
      shape = initial_value.shape
    else:
      shape = tensor_shape.as_shape(shape)

    if shape.rank == 0:  # Skip partitioning rank-0 variable.
      return self._create_variable_round_robin(var_creator, **kwargs)

    num_partitions = self._variable_partitioner(shape=shape, dtype=dtype)
    if not num_partitions or num_partitions[0] == 0 or any(
        v != 1 for v in num_partitions[1:]):
      raise ValueError(
          "variable_partitioner must return a list/tuple whose elements are "
          "all 1 except the first element, which must be non-zero; got: %r" %
          num_partitions)

    if num_partitions[0] == 1:  # no partition
      return self._create_variable_round_robin(var_creator, **kwargs)

    # Use "div" partition strategy to partition the variable.
    num_partitions = min(num_partitions[0], shape[0])
    base = shape[0] // num_partitions
    extra = shape[0] % num_partitions
    # An example: num_partitions=4, shape[0]=10, partitions: [3, 3, 2, 2]
    # offsets: [0, 3, 6, 8, 10]
    offsets = []
    for i in range(num_partitions):
      if i == 0:
        offsets.append(0)
      else:
        prev_shard_size = base + (1 if i - 1 < extra else 0)
        offsets.append(offsets[i - 1] + prev_shard_size)
    offsets.append(shape[0])

    def init_shard_fn(shard_index):
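      # Returns the initial value of shard `shard_index`: either a slice of the
      # fully materialized `initial_value`, or the output of the initializer
      # restricted to this partition when it accepts partition-aware arguments
      # (`partition_shape`/`partition_offset` or `shard_info`).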
      if not init_from_fn:
        logging.log_if(
            logging.WARN, _INEFFICIENT_INIT_WARNING % name, shard_index == 0 and
            shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)
        return initial_value[offsets[shard_index]:offsets[shard_index + 1]]
      partition_shape = (offsets[shard_index + 1] -
                         offsets[shard_index],) + shape[1:]
      partition_offset = (offsets[shard_index],) + (0,) * len(shape[1:])
      arg_spec = tf_inspect.getfullargspec(initial_value)
      if ("shard_info" not in arg_spec.args and
          "shard_info" not in arg_spec.kwonlyargs):
        try:
          value = initial_value(
              partition_shape=partition_shape,
              partition_offset=partition_offset)
        except (TypeError, ValueError):
          # TypeError: Initializer doesn't accept kwargs
          # ValueError: Initializer doesn't accept partition kwargs
          # In both cases we go ahead creating the full value and then slice.
          value = initial_value()

        if value.shape == partition_shape:
          # Initializer supports partition: value is the partition value.
          return value
        else:
          # Initializer doesn't support partition: value is the full value
          # and needs to be sliced to get the partition value.
          logging.log_if(
              logging.WARN, _INEFFICIENT_INIT_WARNING % name,
              shard_index == 0 and
              shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)
          return value[offsets[shard_index]:offsets[shard_index + 1]]
      else:
        # For compatibility with `CheckpointInitialValueCallable`.
        return initial_value(
            shard_info=trackable.ShardInfo(
                shape=tensor_shape.as_shape(partition_shape),
                offset=partition_offset))

    var_list = []
    for i in range(num_partitions):
      kwargs["shape"] = (offsets[i + 1] - offsets[i],) + shape[1:]
      kwargs["initial_value"] = lambda: init_shard_fn(i)
      if name is not None:
        kwargs["name"] = "{}/part_{}".format(name, i)
      var_list.append(self._create_variable_round_robin(var_creator, **kwargs))

    result = sharded_variable.ShardedVariable(var_list)
    return result

  def _create_variable_round_robin(self, next_creator, **kwargs):
    # Clear the colocation scope to avoid possible conflicts between device
    # scope and colocation scope.
    with ops.colocate_with(None, ignore_existing=True):
      # Explicitly set CPU:0 device for the ps in case variable creation is
      # called inside replica_fn and the worker is under a GPU:0 device scope.
      with ops.device("/job:ps/task:%d/device:CPU:0" %
                      (self._variable_count % self._num_ps)):
        var = next_creator(**kwargs)
        logging.debug(
            "Creating variable (name:%s, shape:%r) on "
            "/job:ps/task:%d/device:CPU:0",
            var.name, var.shape, (self._variable_count % self._num_ps))
        self._variable_count += 1
        return var

  def _assert_used_with_cluster_coordinator(self):
    if (not self._used_with_coordinator and
        not self._allow_run_without_coordinator):
      raise NotImplementedError(
          "`tf.distribute.experimental.ParameterServerStrategy` must be used "
          "with `tf.distribute.experimental.coordinator.ClusterCoordinator` in "
          "a custom training loop. If you are using `Model.fit`, please supply "
          "a dataset function directly to a "
          "`tf.keras.utils.experimental.DatasetCreator` instead.")

  def _assert_being_scheduled_by_cluster_coordinator(self):
    if not self._being_scheduled and not self._allow_run_without_coordinator:
      logging.warning(
          "It is detected that a function used with "
          "`tf.distribute.experimental.ParameterServerStrategy` "
          "is executed locally on the coordinator. This is inefficient but may "
          "be valid for one-off tasks such as inferring output signature. "
          "To properly distribute functions to run on workers, `run` or "
          "`reduce` should be used within a function passed to `"
          "tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`."
      )

  # `options` is not used right now. But we may want to support options while
  # creating InputWorkers in the future, similar to MirroredStrategy.
  def _input_workers_with_options(self, options=None):
    input_workers_devices = (
        ("/device:CPU:0", self.worker_devices),)
    return input_lib.InputWorkers(
        input_workers_devices, canonicalize_devices=False)

  def _experimental_distribute_dataset(self, dataset, options):
    input_workers_devices = self._input_workers_with_options()

    # If this DistributedDataset is created outside ClusterCoordinator, i.e.,
    # outside a tf.function, we don't build its underlying datasets until it is
    # passed to ClusterCoordinator.create_per_worker_dataset.
    return input_lib.get_distributed_dataset(
        dataset,
        input_workers_devices,
        self._container_strategy(),
        num_replicas_in_sync=self._num_replicas_in_sync,
        options=options,
        build=ops.inside_function())  # will be built by ClusterCoordinator

  def _distribute_datasets_from_function(self, dataset_fn, options):
    # There is no synchronization across workers, and thus the number of
    # input pipelines in sync is only 1 per worker.
    input_pipeline_id_in_sync = 0
    num_input_pipelines_in_sync = 1

    input_context = distribute_lib.InputContext(
        num_input_pipelines=num_input_pipelines_in_sync,
        input_pipeline_id=input_pipeline_id_in_sync,
        num_replicas_in_sync=self._num_replicas_in_sync)

    # If this DistributedDatasetFromFunction is created outside
    # ClusterCoordinator, i.e., outside a tf.function, we don't build its
    # underlying datasets until it is passed to
    # ClusterCoordinator.create_per_worker_dataset.
    return input_lib.get_distributed_datasets_from_function(
        dataset_fn,
        self._input_workers_with_options(options),
        [input_context],
        self._container_strategy(),
        options=options,
        build=ops.inside_function())  # will be built by ClusterCoordinator

  @property
  def worker_devices(self):
    num_gpus = self._num_gpus_per_worker
    if num_gpus > 0:
      compute_devices = tuple("/device:GPU:%d" % (i,) for i in range(num_gpus))
    else:
      compute_devices = ("/device:CPU:0",)
    return compute_devices

  def _call_for_each_replica(self, fn, args, kwargs):
    self._assert_being_scheduled_by_cluster_coordinator()

    return mirrored_run.call_for_each_replica(self._container_strategy(), fn,
                                              args, kwargs)

  def _reduce(self, reduce_op, value):
    self._assert_being_scheduled_by_cluster_coordinator()
    dst = device_util.current() or self._default_device or "/device:CPU:0"
    destinations = device_util.canonicalize_without_job_and_task(dst)
    result = self._local_results(
        self.reduce_to(reduce_op, value, destinations))[0]
    return result

  def _reduce_to(self, reduce_op, value, destinations, options):
    self._assert_being_scheduled_by_cluster_coordinator()

    def get_values(x):
      if isinstance(x, values.DistributedValues):
        return self._cross_device_ops.reduce(
            reduce_op, x, destinations=destinations)  # pylint: disable=protected-access
      return x

    return nest.map_structure(get_values, value)


# The warning that will be logged if the way we initialize sharded variables
# is memory-inefficient.
_INEFFICIENT_INIT_WARNING = (
    "Large variable %s is partitioned but not initialized in a "
    "memory-efficient way. On each shard, the full value is first being "
    "created and then sliced into smaller values. To reduce the memory "
    "footprint, explicitly specify `dtype` and `shape` when creating "
    "variables, and use `tf.initializers` to initialize the variable. "
    "Note that some initializers (e.g., orthogonal) don't support "
    "memory-efficient initialization and there is not much you can do here.")

_LARGE_VARIABLE_NUM_ELEMENTS = 1e9