# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Parameter server strategy V2 class.

This is currently under development and the API is subject to change.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.eager import remote
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export

ALLOWED_TASK_TYPES = ("chief", "worker", "ps")


@tf_export("distribute.experimental.ParameterServerStrategy", v1=[])
class ParameterServerStrategyV2(distribute_lib.Strategy):
48  """An multi-worker tf.distribute strategy with parameter servers.
49
50  Parameter server training is a common data-parallel method to scale up a
51  machine learning model on multiple machines. A parameter server training
52  cluster consists of workers and parameter servers. Variables are created on
53  parameter servers and they are read and updated by workers in each step.
54  By default, workers read and update these variables independently without
55  synchronizing with each other. Under this configuration, it is known as
56  asynchronous training.
57
58  In TensorFlow 2, we recommend an architecture based on central coordination
59  for parameter server training. Each worker and parameter server runs a
60  `tf.distribute.Server`, and on top of that, a coordinator task is responsible
61  for creating resources on workers and parameter servers, dispatching
62  functions, and coordinating the training. The coordinator uses a
63  `tf.distribute.experimental.coordinator.ClusterCoordinator` to coordinate the
64  cluster, and a `tf.distribute.experimental.ParameterServerStrategy` to define
65  variables on parameter servers and computation on workers.
66
67  For the training to work, the coordinator dispatches `tf.function`s to be
68  executed on remote workers. Upon receiving requests from the coordinator, a
69  worker executes the `tf.function` by reading the variables from parameter
70  servers, executing the ops, and updating the variables on the parameter
71  servers. Each of the worker only processes the requests from the coordinator,
72  and communicates with parameter servers, without direct interactions with
73  other workers in the cluster.
74
75  As a result, failures of some workers do not prevent the cluster from
76  continuing the work, and this allows the cluster to train with instances that
77  can be occasionally unavailable (e.g. preemptible or spot instances). The
78  coordinator and parameter servers though, must be available at all times for
79  the cluster to make progress.
80
81  Note that the coordinator is not one of the training workers. Instead, it
82  creates resources such as variables and datasets, dispatchs `tf.function`s,
83  saves checkpoints and so on. In addition to workers, parameter servers and
84  the coordinator, an optional evaluator can be run on the side that
85  periodically reads the checkpoints saved by the coordinator and runs
86  evaluations against each checkpoint.
87
88  `tf.distribute.experimental.ParameterServerStrategy` has to work in
89  conjunction with a `tf.distribute.experimental.coordinator.ClusterCoordinator`
90  object. Standalone usage of
91  `tf.distribute.experimental.ParameterServerStrategy` without central
92  coordination is not supported at this time.
93
  __Example code for coordinator__

  Here's an example usage of the API, with a custom training loop to train a
  model. This code snippet is intended to be run on the one task that is
  designated as the coordinator. Note that the `cluster_resolver`,
  `variable_partitioner`, and `dataset_fn` arguments are explained in the
  "Cluster setup", "Variable partitioning", and "Dataset preparation" sections
  below.

  ```python
  # Set the environment variable to allow reporting worker and ps failure to the
  # coordinator. This is a short-term workaround.
  os.environ["GRPC_FAIL_FAST"] = "use_caller"

  # Prepare a strategy to use with the cluster and variable partitioning info.
  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver=...,
      variable_partitioner=...)
  coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
      strategy=strategy)

  # Prepare a distributed dataset that will place datasets on the workers.
  distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn=...)

  with strategy.scope():
    model = ...
    optimizer, metrics = ...  # Keras optimizer/metrics are great choices
    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    checkpoint_manager = tf.train.CheckpointManager(
        checkpoint, checkpoint_dir, max_to_keep=2)
    # `load_checkpoint` infers initial epoch from `optimizer.iterations`.
    initial_epoch = load_checkpoint(checkpoint_manager) or 0

  @tf.function
  def worker_fn(iterator):

    def replica_fn(inputs):
      batch_data, labels = inputs
      # Calculate gradients, apply gradients, update metrics, etc.

    strategy.run(replica_fn, args=(next(iterator),))

  for epoch in range(initial_epoch, num_epoch):
    distributed_iterator = iter(distributed_dataset)  # Reset iterator state.
    for step in range(steps_per_epoch):

      # Asynchronously schedule the `worker_fn` to be executed on an arbitrary
      # worker. This call returns immediately.
      coordinator.schedule(worker_fn, args=(distributed_iterator,))

    # `join` blocks until all scheduled `worker_fn`s finish execution. Once it
    # returns, we can read the metrics and save checkpoints as needed.
    coordinator.join()
    logging.info('Metric result: %r', metrics.result())
    metrics.reset_states()
    checkpoint_manager.save()
  ```

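  The `load_checkpoint` helper above is not part of the API. A minimal sketch
  of what it might look like (an assumption for illustration; it relies on the
  `optimizer` and `steps_per_epoch` defined in this example):

  ```python
  def load_checkpoint(checkpoint_manager):
    # Restore the latest checkpoint, if any, and derive the epoch to resume
    # from using the number of optimizer steps taken so far.
    if checkpoint_manager.latest_checkpoint:
      checkpoint_manager.checkpoint.restore(
          checkpoint_manager.latest_checkpoint)
      return optimizer.iterations.numpy() // steps_per_epoch
    return None
  ```
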
  __Example code for worker and parameter servers__

  In addition to the coordinator, there should be tasks designated as
  "worker" or "ps". They should run the following code to start a TensorFlow
  server, waiting for the coordinator's requests:

  ```python
  # Set the environment variable to allow reporting worker and ps failure to the
  # coordinator.
  os.environ["GRPC_FAIL_FAST"] = "use_caller"

  # Provide a `tf.distribute.cluster_resolver.ClusterResolver` that serves
  # the cluster information. See the "Cluster setup" section below.
  cluster_resolver = ...

  server = tf.distribute.Server(
      cluster_resolver.cluster_spec(),
      job_name=cluster_resolver.task_type,
      task_index=cluster_resolver.task_id,
      protocol="grpc")

  # Block the process that starts a server from exiting.
  server.join()
  ```

  __Cluster setup__

  In order for the tasks in the cluster to know other tasks' addresses, a
  `tf.distribute.cluster_resolver.ClusterResolver` must be used by the
  coordinator, workers, and ps. The
  `tf.distribute.cluster_resolver.ClusterResolver` is responsible for providing
  the cluster information, as well as the task type and id of the current task.
  See `tf.distribute.cluster_resolver.ClusterResolver` for more information.

  If the `TF_CONFIG` environment variable is set, a
  `tf.distribute.cluster_resolver.TFConfigClusterResolver` should be used.

  Since there are assumptions in
  `tf.distribute.experimental.ParameterServerStrategy` around the naming of the
  task types, "chief", "ps", and "worker" should be used in the
  `tf.distribute.cluster_resolver.ClusterResolver` to refer to the coordinator,
  parameter servers, and workers, respectively.

  The following example demonstrates setting `TF_CONFIG` for the task designated
  as a parameter server (task type "ps") and index 1 (the second task), in a
  cluster with 1 chief, 2 parameter servers, and 3 workers. Note that
  `TF_CONFIG` needs to be set before
  `tf.distribute.cluster_resolver.TFConfigClusterResolver` is used.

  Example code for cluster setup:
  ```python
  os.environ['TF_CONFIG'] = '''
  {
    "cluster": {
      "chief": ["chief.example.com:2222"],
      "ps": ["ps0.example.com:2222", "ps1.example.com:2222"],
      "worker": ["worker0.example.com:2222", "worker1.example.com:2222",
                 "worker2.example.com:2222"]
    },
    "task": {
      "type": "ps",
      "index": 1
    }
  }
  '''
  ```

  If you prefer to run the same binary for all tasks, you will need to let the
  binary branch into different roles at the beginning of the program:
  ```python
  os.environ["GRPC_FAIL_FAST"] = "use_caller"
  cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()

  # If coordinator, create a strategy and start the training program.
  if cluster_resolver.task_type == 'chief':
    strategy = tf.distribute.experimental.ParameterServerStrategy(
        cluster_resolver)
    ...

  # If worker/ps, create a server.
  elif cluster_resolver.task_type in ("worker", "ps"):
    server = tf.distribute.Server(...)
    ...
  ```

  Alternatively, you can start several TensorFlow servers in advance and
  connect to them later, as sketched below. The coordinator can be in the same
  cluster or on any machine that has connectivity to workers and parameter
  servers. This is covered in our guide and tutorial.

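  A minimal sketch of this pre-started-servers setup, assuming the server
  addresses are known up front (the host names below are placeholders), using a
  `tf.distribute.cluster_resolver.SimpleClusterResolver` in place of
  `TFConfigClusterResolver`:

  ```python
  # Addresses of the already-running worker and ps servers.
  cluster_spec = tf.train.ClusterSpec({
      "worker": ["worker0.example.com:2222", "worker1.example.com:2222"],
      "ps": ["ps0.example.com:2222", "ps1.example.com:2222"],
  })
  cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
      cluster_spec, rpc_layer="grpc")
  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver)
  ```
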
  __Variable creation with `strategy.scope()`__

  `tf.distribute.experimental.ParameterServerStrategy` follows the
  `tf.distribute` API contract where variable creation is expected to be inside
  the context manager returned by `strategy.scope()`, in order to be correctly
  placed on parameter servers in a round-robin manner:

  ```python
  # In this example, we assume the cluster has 3 ps tasks.
  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver=...)
  coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
      strategy=strategy)

  # Variables should be created inside scope to be placed on parameter servers.
  # If created outside scope, such as `v1` here, it would be placed on the
  # coordinator.
  v1 = tf.Variable(initial_value=0.0)

  with strategy.scope():
    v2 = tf.Variable(initial_value=1.0)
    v3 = tf.Variable(initial_value=2.0)
    v4 = tf.Variable(initial_value=3.0)
    v5 = tf.Variable(initial_value=4.0)

  # v2 through v5 are created in scope and are distributed on parameter servers.
  # Default placement is round-robin but the order should not be relied on.
  assert v2.device == "/job:ps/replica:0/task:0/device:CPU:0"
  assert v3.device == "/job:ps/replica:0/task:1/device:CPU:0"
  assert v4.device == "/job:ps/replica:0/task:2/device:CPU:0"
  assert v5.device == "/job:ps/replica:0/task:0/device:CPU:0"
  ```

  See `tf.distribute.Strategy.scope` for more information.

  __Variable partitioning__

  Having dedicated servers to store variables means being able to divide up, or
  "shard", the variables across the ps. Partitioning a large variable among the
  ps is a commonly used technique to boost training throughput and mitigate
  memory constraints. It enables parallel computations and updates on different
  shards of a variable, and often yields better load balancing across parameter
  servers. Without sharding, models with large variables (e.g., embeddings) that
  can't fit into one machine's memory would otherwise be unable to train.

  With `tf.distribute.experimental.ParameterServerStrategy`, if a
  `variable_partitioner` is provided to `__init__` and certain conditions are
  satisfied, the resulting variables created in scope are sharded across the
  parameter servers, in a round-robin fashion. The variable reference returned
  from `tf.Variable` becomes a type that serves as the container of the sharded
  variables. One can access the `variables` attribute of this container for the
  actual variable components. If building a model with `tf.Module` or Keras,
  the variable components are collected in the `variables`-like attributes.

  ```python
  class Dense(tf.Module):
    def __init__(self, name=None):
      super().__init__(name=name)
      self.w = tf.Variable(tf.random.normal([100, 10]), name='w')

    def __call__(self, x):
      return x * self.w

  # Partition the dense layer into 2 shards.
  variable_partitioner = (
    tf.distribute.experimental.partitioners.FixedShardsPartitioner(
      num_shards=2))
  strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver=...,
    variable_partitioner=variable_partitioner)
  with strategy.scope():
    dense = Dense()
  assert len(dense.variables) == 2
  assert isinstance(dense.variables[0], tf.Variable)
  assert isinstance(dense.variables[1], tf.Variable)
  assert dense.variables[0].shape == (50, 10)
  assert dense.variables[1].shape == (50, 10)
  ```

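  As a variation on the above, a sketch of using the `MinSizePartitioner`
  predefined in `tf.distribute.experimental.partitioners`, which keeps every
  shard above a minimum size so that small variables are not split (the byte
  threshold and `num_ps` here are illustrative placeholders):

  ```python
  # Shard variables into at most `num_ps` pieces, keeping each shard at least
  # 256 KiB.
  variable_partitioner = (
      tf.distribute.experimental.partitioners.MinSizePartitioner(
          min_shard_bytes=256 << 10, max_shards=num_ps))
  strategy = tf.distribute.experimental.ParameterServerStrategy(
      cluster_resolver=...,
      variable_partitioner=variable_partitioner)
  ```
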
  The sharded variable container can be converted to a `Tensor` via
  `tf.convert_to_tensor`. This means the container can be directly used in most
  Python ops where such `Tensor` conversion automatically happens. For example,
  in the `Dense` code snippet above, `x * self.w` would implicitly apply this
  tensor conversion. Note that such conversion can be expensive, as the variable
  components need to be transferred from multiple parameter servers to where
  the value is used.

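  Continuing the `Dense` example, a small sketch of triggering this conversion
  explicitly (the shape values follow from that example):

  ```python
  # Materialize the full value of the sharded variable. The shard values are
  # fetched from the parameter servers, so this can be expensive for large
  # variables.
  full_w = tf.convert_to_tensor(dense.w)
  assert full_w.shape == (100, 10)
  ```
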
  `tf.nn.embedding_lookup`, on the other hand, doesn't apply the tensor
  conversion, and instead performs parallel lookups on the variable components.
  This is crucial to scale up embedding lookups when the embedding table
  variable is large.

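  A minimal sketch, assuming a `strategy` constructed with a
  `variable_partitioner` as above (the table shape and ids are illustrative):

  ```python
  with strategy.scope():
    # Sharded across the parameter servers by `variable_partitioner`.
    embedding_table = tf.Variable(tf.random.uniform([8192, 64]))

  # Looks up rows on the individual shards in parallel, without materializing
  # the full table first.
  ids = tf.constant([3, 17, 42])
  vectors = tf.nn.embedding_lookup(embedding_table, ids)
  ```
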
  When a partitioned variable is saved to a `SavedModel`, it will be saved as if
  it were one single variable. This improves serving efficiency by eliminating
  a number of ops that handle the partition aspects.

  Known limitations of variable partitioning:

  * The number of partitions must not change across checkpoint saving/loading.

  * After saving partitioned variables to a SavedModel, the SavedModel can't be
    loaded via `tf.saved_model.load`.

  * Partitioned variables don't directly work with `tf.GradientTape`; please
    use the `variables` attribute to get the actual variable components and use
    them in gradient APIs instead, as sketched after this list.

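  For instance, reusing the `Dense` module from the partitioning example above
  (a sketch; the input shape is made up):

  ```python
  x = tf.random.normal([1, 10])
  with tf.GradientTape() as tape:
    loss = tf.reduce_sum(dense(x))
  # Differentiate with respect to the variable components, not the container.
  gradients = tape.gradient(loss, dense.variables)
  ```
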
  __Dataset preparation__

  With `tf.distribute.experimental.ParameterServerStrategy`, a dataset is
  created in each of the workers to be used for training. This is done by
  creating a `dataset_fn` that takes no argument and returns a
  `tf.data.Dataset`, and passing the `dataset_fn` into
  `tf.distribute.experimental.coordinator.ClusterCoordinator.create_per_worker_dataset`.
  We recommend that the dataset be shuffled and repeated so that the examples
  run through the training as evenly as possible.

  ```python
  def dataset_fn():
    filenames = ...
    dataset = tf.data.Dataset.from_tensor_slices(filenames)

    # The dataset is recommended to be shuffled and repeated.
    return dataset.shuffle(buffer_size=...).repeat().batch(batch_size=...)

  coordinator = tf.distribute.experimental.coordinator.ClusterCoordinator(
      strategy=...)
  distributed_dataset = coordinator.create_per_worker_dataset(dataset_fn)
  ```

  __Limitations__

  * `tf.distribute.experimental.ParameterServerStrategy` in TF2 is experimental,
  and the API is subject to further changes.

  * `tf.distribute.experimental.ParameterServerStrategy` does not yet support
  training with GPU(s). Support for this is under development.

  * `tf.distribute.experimental.ParameterServerStrategy` currently only supports
  the [custom training loop
  API](https://www.tensorflow.org/tutorials/distribute/custom_training)
  in TF2. Support for usage with the Keras `compile`/`fit` API is under
  development.

  * `tf.distribute.experimental.ParameterServerStrategy` must be used with
  `tf.distribute.experimental.coordinator.ClusterCoordinator`.
  """

  # pyformat: disable
  def __init__(self, cluster_resolver, variable_partitioner=None):
    """Initializes the TF2 parameter server strategy.

    This initializes the `tf.distribute.experimental.ParameterServerStrategy`
    object to be ready for use with
    `tf.distribute.experimental.coordinator.ClusterCoordinator`.

    Args:
      cluster_resolver: a `tf.distribute.cluster_resolver.ClusterResolver`
        object.
      variable_partitioner:
        a `distribute.experimental.partitioners.Partitioner` that specifies
        how to partition variables. If `None`, variables will not be
        partitioned.

        * Predefined partitioners in `tf.distribute.experimental.partitioners`
        can be used for this argument. A commonly used partitioner is
        `MinSizePartitioner(min_shard_bytes=256 << 10, max_shards=num_ps)`,
        which allocates at least 256 KiB per shard, and each ps gets at most
        one shard.

        * `variable_partitioner` will be called for each variable created under
        strategy `scope` to instruct how the variable should be partitioned.
        Variables that have only one partition along the partitioning axis
        (i.e., no partitioning is needed) will be created as normal
        `tf.Variable`s.

        * Only the first / outermost axis partitioning is supported.

        * Div partition strategy is used to partition variables. Assuming we
        assign consecutive integer ids along the first axis of a variable, then
        ids are assigned to shards in a contiguous manner, while attempting to
        keep each shard size identical. If the ids do not evenly divide the
        number of shards, each of the first several shards will be assigned one
        more id. For instance, a variable whose first dimension is 13 has 13
        ids, and they are split across 5 shards as:
        `[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]`.

        * Variables created under `strategy.extended.colocate_vars_with` will
        not be partitioned.
432    """
433    # pyformat: enable
434    self._cluster_resolver = cluster_resolver
435    self._extended = ParameterServerStrategyV2Extended(self, cluster_resolver,
436                                                       variable_partitioner)
437    self._verify_args_and_config(cluster_resolver)
438    self._cluster_coordinator = None
439    logging.info(
440        "`tf.distribute.experimental.ParameterServerStrategy` is initialized "
441        "with cluster_spec: %s", cluster_resolver.cluster_spec())
442
443    # TODO(b/167894802): Make coordinator, worker, and ps names customizable.
444    self._connect_to_cluster(coordinator_name="chief")
445    super(ParameterServerStrategyV2, self).__init__(self._extended)
446    distribute_lib.distribution_strategy_gauge.get_cell("V2").set(
447        "ParameterServerStrategy")
448    self._should_use_with_coordinator = True
449
  def _connect_to_cluster(self, coordinator_name):
    if coordinator_name in ["worker", "ps"]:
      raise ValueError("coordinator name should not be 'worker' or 'ps'.")
    cluster_spec = self._cluster_resolver.cluster_spec()
    self._num_workers = len(cluster_spec.as_dict().get("worker", ()))
    self._num_ps = len(cluster_spec.as_dict().get("ps", ()))

    device_filters = server_lib.ClusterDeviceFilters()
    # For any worker, only the devices on ps and coordinator nodes are visible.
    for i in range(self._num_workers):
      device_filters.set_device_filters(
          "worker", i, ["/job:ps", "/job:%s" % coordinator_name])
    # Similarly, for any ps, only the devices on workers and the coordinator
    # are visible.
    for i in range(self._num_ps):
      device_filters.set_device_filters(
          "ps", i, ["/job:worker", "/job:%s" % coordinator_name])

    # Allow at most one outstanding RPC for each worker at a certain time. This
    # is to simplify worker failure handling in the runtime.
    os.environ["TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"] = "False"

    logging.info("%s is now connecting to cluster with cluster_spec: %r",
                 self.__class__.__name__, cluster_spec)
    remote.connect_to_cluster(
        cluster_spec,
        job_name=coordinator_name,
        protocol=self._cluster_resolver.rpc_layer,
        cluster_device_filters=device_filters)

    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "ps_strategy_num_workers").set(self._num_workers)
    distribute_lib.distribution_strategy_replica_gauge.get_cell(
        "ps_strategy_num_ps").set(self._num_ps)

  def _verify_args_and_config(self, cluster_resolver):
    if not cluster_resolver.cluster_spec():
      raise ValueError("Cluster spec must be non-empty in "
                       "`tf.distribute.cluster_resolver.ClusterResolver`.")
    if self.extended._num_gpus_per_worker > 1:  # pylint: disable=protected-access
      raise NotImplementedError("Multi-gpu is not supported yet.")

    cluster_spec = cluster_resolver.cluster_spec()

    # The following checks if the task types are allowed (chief, ps, worker).
    multi_worker_util._validate_cluster_spec(  # pylint: disable=protected-access
        cluster_spec,
        cluster_resolver.task_type,
        cluster_resolver.task_id)

    if multi_worker_util.task_count(cluster_spec, "ps") < 1:
      raise ValueError("There must be at least one ps.")

    if multi_worker_util.task_count(cluster_spec, "worker") < 1:
      raise ValueError("There must be at least one worker.")


class ParameterServerStrategyV2Extended(
    parameter_server_strategy.ParameterServerStrategyExtended):
  """Extended class for ParameterServerStrategyV2.

  Please see `tf.distribute.StrategyExtended` doc for more information.
  """

  def __init__(self, container_strategy, cluster_resolver,
               variable_partitioner):
    """Initialization of ParameterServerStrategyV2Extended."""
    super(ParameterServerStrategyV2Extended, self).__init__(container_strategy)
    self._num_ps = len(cluster_resolver.cluster_spec().as_dict().get("ps", []))
    self._variable_count = 0
    self._variable_partitioner = variable_partitioner

    # The following two attrs are to verify that `ParameterServerStrategy`
    # methods are properly used with a `ClusterCoordinator`.
    self._used_with_coordinator = False
    self._being_scheduled = False

  def _create_variable(self, next_creator, **kwargs):
    """Implements StrategyExtendedV2._create_variable.

    Creates a `Variable` or a `ShardedVariable`. A `ShardedVariable` will be
    created if all of the following criteria are satisfied:
      1. `self._variable_partitioner` results in more than one partition on the
         first axis.
      2. The variable's rank is greater than 0.
      3. The variable is not colocated with another variable.
    Otherwise, a `Variable` will be created.

    Args:
      next_creator: See `variable_scope.variable_creator_scope`; the next
        creator in the chain.
      **kwargs: Passed through to the next creator.

    Returns:
      A `Variable` or `ShardedVariable`.
    """

547    if "colocate_with" in kwargs:  # Never partition colocated_with variables.
548      colocate_with = kwargs["colocate_with"]
549      # Clear the variable scope to avoid possible conflicts between device
550      # scope and colocation scope.
551      with ops.device(None):
552        with ops.colocate_with(colocate_with):
553          var = next_creator(**kwargs)
554          logging.debug(
555              "Creating variable (name:%s, shape:%r) that colocates with %s",
556              var.name, var.shape, kwargs["colocate_with"].name)
557          return var
558
559    if self._variable_partitioner is None:
560      return self._create_variable_round_robin(next_creator, **kwargs)
561
562    name = kwargs.get("name", None)
563    initial_value = kwargs.get("initial_value", None)
564    if initial_value is None:
565      raise ValueError(
566          "It looks like you are using `ParameterServerStrategy` with a "
567          "`variable_partitioner`, and trying to create a variable without "
568          "specifying `initial_value`. This is not allowed. Please specify the "
569          "`initial_value`. This can also happen if you are trying to load a "
570          "saved_model within a `ParameterServerStrategy` scope. Loading a "
571          "saved_model with `variable_partitioner` is not supported.")
572
573    # Two cases where initial_value can be a callable:
574    #   1. initial_value is passed as a callable, e.g, an `initializer` class.
575    #   2. restoring from checkpoint, initial_value is a
576    #     "CheckpointInitialValueCallable".
577    init_from_fn = callable(initial_value)
578
579    dtype = kwargs.get("dtype", None)
580    shape = kwargs.get("shape", None)
581    if init_from_fn and (shape is None or dtype is None):
582      init_from_fn = False
583      initial_value = initial_value()
584    if not init_from_fn:
585      # The initial_value is created on coordinator, it will need to be sent to
586      # ps for variable initialization, which can be inefficient and can
587      # potentially hit the 2GB limit on protobuf serialization.
588      initial_value = ops.convert_to_tensor(initial_value, dtype=dtype)
589      dtype = initial_value.dtype
590      shape = initial_value.shape
591    else:
592      shape = tensor_shape.as_shape(shape)
593
    if shape.rank == 0:  # Skip partitioning rank-0 variable.
      return self._create_variable_round_robin(next_creator, **kwargs)

    num_partitions = self._variable_partitioner(shape=shape, dtype=dtype)
    if not num_partitions or num_partitions[0] == 0 or any(
        v != 1 for v in num_partitions[1:]):
      raise ValueError(
          "variable_partitioner must return a list/tuple whose elements are 1"
          " besides the first element (non-zero), got: %r" % num_partitions)

    if num_partitions[0] == 1:  # no partition
      return self._create_variable_round_robin(next_creator, **kwargs)

    # Use "div" partition strategy to partition the variable.
    num_partitions = min(num_partitions[0], shape[0])
    base = shape[0] // num_partitions
    extra = shape[0] % num_partitions
    # An example: num_partitions=4, shape[0]=10, partitions: [3, 3, 2, 2]
    # offsets: [0, 3, 6, 8, 10]
    offsets = []
    for i in range(num_partitions):
      if i == 0:
        offsets.append(0)
      else:
        prev_shard_size = base + (1 if i - 1 < extra else 0)
        offsets.append(offsets[i - 1] + prev_shard_size)
    offsets.append(shape[0])

    def init_shard_fn(shard_index):
      if not init_from_fn:
        logging.log_if(
            logging.WARN, _INEFFICIENT_INIT_WARNING % name, shard_index == 0 and
            shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)
        return initial_value[offsets[shard_index]:offsets[shard_index + 1]]
      partition_shape = (offsets[shard_index + 1] -
                         offsets[shard_index],) + shape[1:]
      partition_offset = (offsets[shard_index],) + (0,) * len(shape[1:])
      arg_spec = tf_inspect.getfullargspec(initial_value)
      if ("shard_info" not in arg_spec.args and
          "shard_info" not in arg_spec.kwonlyargs):
        try:
          value = initial_value(
              partition_shape=partition_shape,
              partition_offset=partition_offset)
        except (TypeError, ValueError):
          # TypeError: Initializer doesn't accept kwargs
          # ValueError: Initializer doesn't accept partition kwargs
          # In both cases we go ahead creating the full value and then slice.
          value = initial_value()

        if value.shape == partition_shape:
          # Initializer supports partition: value is the partition value.
          return value
        else:
          # Initializer doesn't support partition: value is the full value
          # and needs to be sliced to get the partition value.
          logging.log_if(
              logging.WARN, _INEFFICIENT_INIT_WARNING % name,
              shard_index == 0 and
              shape.num_elements() > _LARGE_VARIABLE_NUM_ELEMENTS)
          return value[offsets[shard_index]:offsets[shard_index + 1]]
      else:
        # For compatibility with `CheckpointInitialValueCallable`.
        return initial_value(
            shard_info=trackable.ShardInfo(
                shape=tensor_shape.as_shape(partition_shape),
                offset=partition_offset))

    var_list = []
    for i in range(num_partitions):
      kwargs["shape"] = (offsets[i + 1] - offsets[i],) + shape[1:]
      kwargs["initial_value"] = lambda: init_shard_fn(i)
      if name is not None:
        kwargs["name"] = "{}/part_{}".format(name, i)
      var_list.append(self._create_variable_round_robin(next_creator, **kwargs))

    result = sharded_variable.ShardedVariable(var_list)
    return result

  def _create_variable_round_robin(self, next_creator, **kwargs):
    # Clear the colocation scope to avoid possible conflicts between device
    # scope and colocation scope.
    with ops.colocate_with(None, ignore_existing=True):
      with ops.device("/job:ps/task:%d" %
                      (self._variable_count % self._num_ps)):
        var = next_creator(**kwargs)
        logging.debug(
            "Creating variable (name:%s, shape:%r) on /job:ps/task:%d",
            var.name, var.shape, (self._variable_count % self._num_ps))
        self._variable_count += 1
        return var

  def _assert_used_with_cluster_coordinator(self):
    if not self._used_with_coordinator:
      raise NotImplementedError(
          "`tf.distribute.experimental.ParameterServerStrategy` must be used "
          "with `tf.distribute.experimental.coordinator.ClusterCoordinator`.")

  def _assert_being_scheduled_by_cluster_coordinator(self):
    if not self._being_scheduled:
      raise NotImplementedError(
          "`tf.distribute.experimental.ParameterServerStrategy`'s `run` or "
          "`reduce` must be used within a function passed to `"
          "tf.distribute.experimental.coordinator.ClusterCoordinator.schedule"
          "`.")

  def _experimental_distribute_dataset(self, dataset, options):
    self._assert_used_with_cluster_coordinator()
    if not ops.get_default_graph().building_function:
      raise ValueError(
          "The `experimental_distribute_dataset` method must be called inside "
          "a `tf.function` passed to `create_per_worker_dataset` of "
          "`tf.distribute.experimental.coordinator.ClusterCoordinator`")
    return dataset

  def _distribute_datasets_from_function(self, dataset_fn, options):
    self._assert_used_with_cluster_coordinator()
    if not ops.get_default_graph().building_function:
      raise ValueError(
          "The `distribute_datasets_from_function` method must be called "
          "inside a `tf.function` passed to `create_per_worker_dataset` of "
          "`tf.distribute.experimental.coordinator.ClusterCoordinator`")
    return dataset_fn(distribute_lib.InputContext())

  def _call_for_each_replica(self, fn, args, kwargs):
    self._assert_being_scheduled_by_cluster_coordinator()
    with distribute_lib.ReplicaContext(
        self._container_strategy(),
        replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
      # TODO(rchao): Support multi-replica per worker or sync-group.
      return distribute_utils.regroup((fn(*args, **kwargs),))

  def _reduce(self, reduce_op, value):
    self._assert_being_scheduled_by_cluster_coordinator()
    # TODO(rchao): Provide implementation for multi-replica. Also look into why
    # the default implementation is not working.
    return value


# The warning that will be logged if the way we initialize sharded variables
# is memory-inefficient.
_INEFFICIENT_INIT_WARNING = (
    "Large variable %s is partitioned but not initialized in a "
    "memory-efficient way. On each shard, the full value is first being "
    "created and then sliced into smaller values. To reduce the memory "
    "footprint, explicitly specify `dtype` and `shape` when creating "
    "variables, and use `tf.initializers` to initialize the variable. "
    "Note that some initializers (e.g., orthogonal) don't support "
    "memory-efficient initialization and there is not much you can do here.")

_LARGE_VARIABLE_NUM_ELEMENTS = 1e9