# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for parameter_server_strategy_v2.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import functools
import os

from absl.testing import parameterized
import numpy as np

from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import linalg_ops_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import save
from tensorflow.python.training.server_lib import ClusterSpec
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util as tracking_util

# We create one cluster to share between tests. The cluster should be large
# enough to accommodate all the tests. Adjust the following constants as needed
# but be aware of resource limitations in OSS tests.
MAX_NUM_WORKER = 2
MAX_NUM_PS = 3

_cluster = None


def get_cluster_def(num_workers, num_ps):
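  """Returns a cluster def with `num_workers` workers and `num_ps` ps.

  The shared in-process cluster is created lazily on the first call and then
  sliced down to the requested number of worker and ps addresses.
  """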
  if num_workers > MAX_NUM_WORKER or num_ps > MAX_NUM_PS:
    raise ValueError("Requesting more servers than the maximum, adjust "
                     "MAX_NUM_PS and MAX_NUM_WORKER")
  global _cluster
  if _cluster is None:
    _cluster = multi_worker_test_base.create_in_process_cluster(
        num_workers=MAX_NUM_WORKER, num_ps=MAX_NUM_PS)
  return {
      "worker": _cluster["worker"][:num_workers],
      "ps": _cluster["ps"][:num_ps],
  }


class ParameterServerStrategyV2Test(test.TestCase):

  def setUp(self):
    super().setUp()
    cluster_def = get_cluster_def(num_workers=2, num_ps=3)
    self.cluster_resolver = SimpleClusterResolver(ClusterSpec(cluster_def))

  def tearDown(self):
    super().tearDown()
    # Reset the context to disconnect from the cluster.
    context._reset_context()

  def testVariablePlacement(self):

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    v1 = variables.Variable(initial_value=0.0)
    with strategy.scope():
      v2 = variables.Variable(initial_value=1.0)
      v3 = variables.Variable(initial_value=2.0)
      v4 = variables.Variable(initial_value=3.0)
      v5 = variables.Variable(initial_value=4.0)
    # v1 was created outside the strategy scope, so it should be placed on the
    # client (chief).
    gpu_devices = context.num_gpus()
    if gpu_devices:
      # For tests with GPUs.
      self.assertEqual(v1.device, "/job:chief/replica:0/task:0/device:GPU:0")
    else:
      self.assertEqual(v1.device, "/job:chief/replica:0/task:0/device:CPU:0")
    # v2 through v5 are created inside the scope and placed on ps tasks in a
    # round-robin manner.
    self.assertEqual(v2.device, "/job:ps/replica:0/task:0/device:CPU:0")
    self.assertEqual(v3.device, "/job:ps/replica:0/task:1/device:CPU:0")
    self.assertEqual(v4.device, "/job:ps/replica:0/task:2/device:CPU:0")
    self.assertEqual(v5.device, "/job:ps/replica:0/task:0/device:CPU:0")

  def testInteractionWithDeviceScope(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)

    # The strategy scope always wins.
    with strategy.scope():
      with ops.device("/job:ps/replica:0/task:1"):
        v0 = variables.Variable(initial_value=0.0)
      self.assertEqual(v0.device, "/job:ps/replica:0/task:0/device:CPU:0")

      with ops.device("/job:ps/replica:0/task:0"):
        v1 = variables.Variable(initial_value=0.0)
      self.assertEqual(v1.device, "/job:ps/replica:0/task:1/device:CPU:0")

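    # Reversing the nesting order makes no difference: round-robin placement
    # from the strategy scope still overrides the explicit device scope.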
    with ops.device("/job:ps/replica:0/task:1"):
      with strategy.scope():
        v2 = variables.Variable(initial_value=0.0)
        self.assertEqual(v2.device, "/job:ps/replica:0/task:2/device:CPU:0")

        v3 = variables.Variable(initial_value=0.0)
        self.assertEqual(v3.device, "/job:ps/replica:0/task:0/device:CPU:0")

  def testInteractionWithVariableCreatorScope(self):

    def var_creator(next_creator, **kwargs):
      if "colocate_with" in kwargs:
        with ops.device(None):
          with ops.colocate_with(kwargs["colocate_with"]):
            return next_creator(**kwargs)

      self.assertIn("ps1", kwargs["name"])
      with ops.device("/job:ps/task:1"):
        return next_creator(**kwargs)

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)

    # variable_creator_scope on its own takes effect.
    with variable_scope.variable_creator_scope(var_creator):
      v0 = variables.Variable(initial_value=0.0, name="ps1_0")
    self.assertEqual(v0.device, "/job:ps/replica:0/task:1/device:CPU:0")

    # variable_creator_scope inside strategy.scope is overridden by the
    # strategy's placement.
    with strategy.scope():
      with variable_scope.variable_creator_scope(var_creator):
        v1 = variables.Variable(initial_value=0.0, name="ps1_1")
    self.assertEqual(v1.device, "/job:ps/replica:0/task:0/device:CPU:0")

    # strategy.scope still assigns variables in a round-robin fashion.
    with strategy.scope():
      v2 = variables.Variable(initial_value=0.0, name="ps1_2")
    self.assertEqual(v2.device, "/job:ps/replica:0/task:1/device:CPU:0")

    with strategy.scope():
      v3 = variables.Variable(initial_value=0.0, name="ps1_3")
    self.assertEqual(v3.device, "/job:ps/replica:0/task:2/device:CPU:0")

    # variable_creator_scope outside strategy.scope takes effect.
    with variable_scope.variable_creator_scope(var_creator):
      with strategy.scope():
        v4 = variables.Variable(initial_value=0.0, name="ps1_4")
    self.assertEqual(v4.device, "/job:ps/replica:0/task:1/device:CPU:0")

    with variable_scope.variable_creator_scope(var_creator):
      with strategy.scope():
        v5 = variables.Variable(initial_value=0.0, name="ps1_5")
    self.assertEqual(v5.device, "/job:ps/replica:0/task:1/device:CPU:0")

    # variable_creator_scope can be made to respect "colocate_with" as well.
    with variable_scope.variable_creator_scope(var_creator):
      with strategy.scope():
        with strategy.extended.colocate_vars_with(v1):
          v6 = variables.Variable(initial_value=0.0, name="ps1_6")
    self.assertEqual(v6.device, "/job:ps/replica:0/task:0/device:CPU:0")

  @contextlib.contextmanager
  def _assertRaisesUsageWarningWithSchedule(self):
    with self.assertLogs(level="WARNING") as logs:
      yield

    self.assertIn(
        "It is detected that a function used with "
        "`tf.distribute.experimental.ParameterServerStrategy` "
        "is executed locally on the coordinator. This is inefficient but may "
        "be valid for one-off tasks such as inferring output signature. "
        "To properly distribute functions to run on workers, `run` or "
        "`reduce` should be used within a function passed to `"
        "tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`.",
        logs.output[0])

  def testRunNotUsedWithClusterCoordinator(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    dataset = dataset_ops.DatasetV2.range(8)
    with strategy.scope():
      v = variables.Variable(1, dtype=dtypes.int64)

    def step_fn(iterator):
      return next(iterator) + v

    with self._assertRaisesUsageWarningWithSchedule():
      strategy.run(step_fn, args=(iter(dataset),))

  def testRunUsedWithTestOnlyMode(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    strategy.extended._allow_run_without_coordinator = True
    dataset = dataset_ops.DatasetV2.range(15)
    with strategy.scope():
      v = variables.Variable(1, dtype=dtypes.int64)

    def step_fn(iterator):
      return next(iterator) + v

    strategy.run(step_fn, args=(iter(dataset),))

  def testReduceNotUsedWithClusterCoordinator(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    with self._assertRaisesUsageWarningWithSchedule():
      strategy.reduce("SUM", None, axis=None)

  def testDistributeDatasetUsedDirectly(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    dataset = dataset_ops.DatasetV2.range(3)
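    # Directly iterating the distributed dataset on the coordinator (without a
    # ClusterCoordinator) is expected to raise a ValueError.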
    distributed_dataset = strategy.experimental_distribute_dataset(dataset)
    with self.assertRaises(ValueError):
      iter(distributed_dataset)

    distributed_dataset = strategy.distribute_datasets_from_function(
        lambda: dataset)
    with self.assertRaises(ValueError):
      iter(distributed_dataset)

  def testSparselyReadForEmbeddingLookup(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)

    class FakeModel(module.Module):

      def __init__(self):
        self._var0 = variables.Variable([1.0, 2.0, 3.0, 4.0])
        self._var1 = variables.Variable([5.0, 6.0, 7.0, 8.0])

      @def_function.function(input_signature=[
          tensor_spec.TensorSpec(shape=[2], dtype=dtypes.int32, name="inputs")
      ])
      def func(self, x):
        return embedding_ops.embedding_lookup([self._var0, self._var1], x)

    with strategy.scope():
      model = FakeModel()

    # Assert that a ResourceGather op, rather than a Gather op, appears in the
    # training function.
    found_resource_gather = False
    found_gather = False

    for n in model.func.get_concrete_function().graph.as_graph_def().node:
      if n.op == "ResourceGather":
        found_resource_gather = True
      elif n.op == "Gather":
        found_gather = True
    self.assertTrue(found_resource_gather)
    self.assertFalse(found_gather)

    # Assert that a ResourceGather op, rather than a Gather op, appears in the
    # saved_model.
    found_resource_gather = False
    found_gather = False

    tmp_dir = self.get_temp_dir()
    save.save(model, tmp_dir, signatures=model.func)

    with gfile.Open("%s/saved_model.pb" % tmp_dir, "rb") as f:
      saved_model_proto = saved_model_pb2.SavedModel().FromString(f.read())

    for function in saved_model_proto.meta_graphs[0].graph_def.library.function:
      for n in function.node_def:
        if n.op == "ResourceGather":
          found_resource_gather = True
          resource_gather_device = n.device
        elif n.op == "Gather":
          found_gather = True
    self.assertTrue(found_resource_gather)
    self.assertFalse(found_gather)

    # We also assert that the colocate_with in embedding_ops does not result in
    # a hard-coded device string.
    self.assertEmpty(resource_gather_device)


class PartitionAwareIdentity(object):
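  """Identity-matrix initializer that returns the requested partition slice."""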

  def __call__(self, shape, dtype, **kwargs):
    value = linalg_ops_impl.eye(*shape, dtype=dtype)
    if "partition_shape" in kwargs and "partition_offset" in kwargs:
      return array_ops.slice(value, kwargs["partition_offset"],
                             kwargs["partition_shape"])
    raise AssertionError("PartitionAwareIdentity does not support "
                         "non-partitioned initialization")


class VariablePartitioningTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    cluster_def = get_cluster_def(num_workers=2, num_ps=2)
    self.cluster_resolver = SimpleClusterResolver(ClusterSpec(cluster_def))

  def tearDown(self):
    super().tearDown()
    # Reset the context to disconnect from the cluster.
    context._reset_context()

  def testDefaultNoPartition(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    with strategy.scope():
      v = variables.Variable([0, 1, 2, 3])

    self.assertIsInstance(v, variables.Variable)

  def testBasic(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
    with strategy.scope():
      init1 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
      v1 = variables.Variable(
          initial_value=lambda: init1(shape=(5, 2), dtype=dtypes.int64),
          shape=(5, 2),
          dtype=dtypes.int64)

      init2 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5])
      v2 = variables.Variable(
          initial_value=lambda: init2(shape=(6, 1), dtype=dtypes.int64),
          shape=(6, 1),
          dtype=dtypes.int64)

    self.assertIsInstance(v1, sharded_variable.ShardedVariable)
    self.assertLen(v1.variables, 2)
    self.assertRegex(v1.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v1.variables[1].device, "/job:ps/replica:0/task:1")
    self.assertAllEqual(v1.variables[0], [[0, 1], [2, 3], [4, 5]])
    self.assertAllEqual(v1.variables[1], [[6, 7], [8, 9]])

    self.assertIsInstance(v2, sharded_variable.ShardedVariable)
    self.assertLen(v2.variables, 2)
    self.assertRegex(v2.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v2.variables[1].device, "/job:ps/replica:0/task:1")
    self.assertAllEqual(v2.variables[0], [[0], [1], [2]])
    self.assertAllEqual(v2.variables[1], [[3], [4], [5]])

  def testBasicVariableWithAggregation(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    strategy.extended._allow_run_without_coordinator = True
    with strategy.scope():
      v = variables.Variable(
          initial_value=[0, 0, 0, 0, 0, 0, 0, 0],
          dtype=dtypes.float32,
          aggregation=variable_scope.VariableAggregation.SUM)

    if strategy.num_replicas_in_sync > 1:
      self.assertIsInstance(v, ps_values.AggregatingVariable)
    else:
      self.assertIsInstance(v, variables.Variable)

    def replica_fn():
      replica_id = distribution_strategy_context.get_replica_context(
      ).replica_id_in_sync_group
      val = array_ops.reshape(
          math_ops.cast(replica_id + 10, dtype=v.dtype), [1])
      v.assign(
          array_ops.concat(
              [val, constant_op.constant([1., 2., 3., 4., 5., 6., 7.])], 0))

    strategy.run(replica_fn)

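    # Each replica assigns [replica_id + 10, 1., ..., 7.]; with SUM aggregation
    # the test expects the final value to be the element-wise sum of those
    # per-replica vectors.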
    expected_result = np.arange(8.) * strategy.num_replicas_in_sync
    for i in range(strategy.num_replicas_in_sync):
      expected_result[0] = expected_result[0] + i + 10
    self.assertAllEqual(v, expected_result)

  def testBasicShardedVariableWithAggregation(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
    strategy.extended._allow_run_without_coordinator = True
    with strategy.scope():
      v = variables.Variable(
          initial_value=[0, 0, 0, 0, 0, 0, 0, 0],
          dtype=dtypes.float32,
          aggregation=variable_scope.VariableAggregation.SUM)

    self.assertIsInstance(v, sharded_variable.ShardedVariable)
    self.assertLen(v.variables, 2)
    if strategy.num_replicas_in_sync > 1:
      self.assertIsInstance(v.variables[0], ps_values.AggregatingVariable)
    else:
      self.assertIsInstance(v.variables[0], variables.Variable)

    def replica_fn():
      replica_id = distribution_strategy_context.get_replica_context(
      ).replica_id_in_sync_group
      val = array_ops.reshape(
          math_ops.cast(replica_id + 10, dtype=v.dtype), [1])
      v.assign(
          array_ops.concat(
              [val, constant_op.constant([1., 2., 3., 4., 5., 6., 7.])], 0))

    strategy.run(replica_fn)

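    # Same expected aggregated value as above, but split across the two shards.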
    expected_result = np.arange(8.) * strategy.num_replicas_in_sync
    for i in range(strategy.num_replicas_in_sync):
      expected_result[0] = expected_result[0] + i + 10
    expected_result = np.array_split(expected_result, 2)
    self.assertAllEqual(expected_result[0], v.variables[0])
    self.assertAllEqual(expected_result[1], v.variables[1])

  def testNonCallableInitialValue(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(4))
    with strategy.scope():
      v = variables.Variable([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

    self.assertIsInstance(v, sharded_variable.ShardedVariable)
    self.assertLen(v.variables, 4)
    self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
    self.assertRegex(v.variables[2].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v.variables[3].device, "/job:ps/replica:0/task:1")
    self.assertAllEqual(v.variables[0], [0, 1, 2])
    self.assertAllEqual(v.variables[1], [3, 4, 5])
    self.assertAllEqual(v.variables[2], [6, 7])
    self.assertAllEqual(v.variables[3], [8, 9])

  def testNumPartitionsLargerThanSize(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(4))
    with strategy.scope():
      v = variables.Variable([0, 1, 2])

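    # With only three elements to shard, the 4-shard partitioner results in
    # three single-element shards.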
    self.assertIsInstance(v, sharded_variable.ShardedVariable)
    self.assertLen(v.variables, 3)
    self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
    self.assertRegex(v.variables[2].device, "/job:ps/replica:0/task:0")
    self.assertAllEqual(v.variables[0], [0])
    self.assertAllEqual(v.variables[1], [1])
    self.assertAllEqual(v.variables[2], [2])

  def testPartitionToOne(self):
    # For small variables there is only one partition.
    variable_partitioner = sharded_variable.MinSizePartitioner(
        min_shard_bytes=64 << 20, max_shards=2)
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, variable_partitioner)
    with strategy.scope():
      initializer = init_ops_v2.Constant([0] * 10)
      v1 = variables.Variable(
          initial_value=lambda: initializer(shape=(10,), dtype=dtypes.int64),
          shape=(10,),
          dtype=dtypes.int64)

      v2 = variables.Variable([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

    self.assertIsInstance(v1, variables.Variable)
    self.assertNotIsInstance(v1, sharded_variable.ShardedVariable)
    self.assertRegex(v1.device, "/job:ps/replica:0/task:0")
    self.assertAllEqual(v1, [0] * 10)

    self.assertIsInstance(v2, variables.Variable)
    self.assertNotIsInstance(v2, sharded_variable.ShardedVariable)
    self.assertRegex(v2.device, "/job:ps/replica:0/task:1")
    self.assertAllEqual(v2, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

  def testColocateWith(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
    with strategy.scope():
      v1 = variables.Variable([0, 1, 2, 3])

      with strategy.extended.colocate_vars_with(v1.variables[0]):
        v2 = variables.Variable([4, 5])

    self.assertIsInstance(v1, sharded_variable.ShardedVariable)

    self.assertIsInstance(v2, variables.Variable)
    self.assertNotIsInstance(v2, sharded_variable.ShardedVariable)
    self.assertEqual(v2.device, v1.variables[0].device)
    self.assertAllEqual(v2, [4, 5])

  def testCustomPartitionAwareInitializer(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
    with strategy.scope():
      initializer = PartitionAwareIdentity()
      initial_value = functools.partial(
          initializer, shape=(4, 4), dtype=dtypes.int64)
      v = variables.Variable(
          initial_value=initial_value, shape=(4, 4), dtype=dtypes.int64)

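    # Each of the two shards holds the matching 2x4 slice of the 4x4 identity
    # matrix.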
    self.assertIsInstance(v, sharded_variable.ShardedVariable)
    self.assertLen(v.variables, 2)
    self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
    self.assertAllEqual(v.variables[0], [[1, 0, 0, 0], [0, 1, 0, 0]])
    self.assertAllEqual(v.variables[1], [[0, 0, 1, 0], [0, 0, 0, 1]])

  def testPartitionWhenLackOfInfo(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
    with strategy.scope():
      initializer = init_ops_v2.Constant([0, 1, 2, 3])
      # Shape is not explicitly specified.
      v1 = variables.Variable(
          initial_value=lambda: initializer(shape=(4,), dtype=dtypes.int64),
          dtype=dtypes.int64)
      # Dtype is not explicitly specified.
      v2 = variables.Variable(
          initial_value=lambda: initializer(shape=(4,), dtype=dtypes.int64),
          shape=(4,))
      # Neither shape nor dtype is explicitly specified.
      v3 = variables.Variable(
          initial_value=lambda: initializer(shape=(4,), dtype=dtypes.int64))

    for v in [v1, v2, v3]:
      self.assertIsInstance(v, sharded_variable.ShardedVariable)
      self.assertLen(v.variables, 2)
      self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
      self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
      self.assertAllEqual(v.variables[0], [0, 1])
      self.assertAllEqual(v.variables[1], [2, 3])

  def testInvalidPartitioner(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, lambda shape, dtype: None)
    with self.assertRaisesRegex(ValueError, "variable_partitioner"):
      with strategy.scope():
        variables.Variable([[[0, 1], [2, 3]], [[0, 1], [2, 3]]])

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, lambda shape, dtype: [])
    with self.assertRaisesRegex(ValueError, "variable_partitioner"):
      with strategy.scope():
        variables.Variable([[[0, 1], [2, 3]], [[0, 1], [2, 3]]])

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, lambda shape, dtype: [0, 1, 1])
    with self.assertRaisesRegex(ValueError, "variable_partitioner"):
      with strategy.scope():
        variables.Variable([[[0, 1], [2, 3]], [[0, 1], [2, 3]]])

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, lambda shape, dtype: [2, 2, 1])
    with self.assertRaisesRegex(ValueError, "variable_partitioner"):
      with strategy.scope():
        variables.Variable([[[0, 1], [2, 3]], [[0, 1], [2, 3]]])

  def testCreateInsideTFFunction(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))

    collection = []

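    # Variables created inside a tf.function under the strategy scope should
    # still come out as ShardedVariables placed across the two ps tasks.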
    @def_function.function
    def create_vars():
      if not collection:
        identity = init_ops_v2.Identity()
        v1 = variables.Variable([[1., 0.], [0., 1.]], dtype=dtypes.float32)
        v2 = variables.Variable(lambda: identity((2, 2), dtypes.float32))
        v3 = variables.Variable(
            lambda: identity((2, 2), dtypes.float32),
            dtype=dtypes.float32,
            shape=(2, 2))
        collection.extend([v1, v2, v3])

    with strategy.scope():
      create_vars()
      for v in collection:
        self.assertIsInstance(v, sharded_variable.ShardedVariable)
        self.assertLen(v.variables, 2)
        self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
        self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
        self.assertAllEqual(v.variables[0], [[1., 0.]])
        self.assertAllEqual(v.variables[1], [[0., 1.]])

  @parameterized.named_parameters(
      ("Restore", False, 2),
      ("RestoreDiffShards", False, 4),
      ("DelayedRestore", True, 2),
      ("DelayedRestoreDiffShards", True, 4),
  )
  def testCheckpoint(self, delayed, restore_shards):

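    # Save the sharded variable with two shards, then restore it with either
    # the same or a different number of shards, optionally issuing the restore
    # before the variable is built (delayed restoration).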
    def make_variable(name, shape, dtype, initializer):
      initial_value = functools.partial(initializer, shape, dtype=dtype)
      return variables.Variable(
          name=name, initial_value=initial_value, shape=shape, dtype=dtype)

    class Model(tracking.AutoTrackable):

      def build(self):
        self.w = self._add_variable_with_custom_getter(
            "w",
            shape=(4,),
            initializer=init_ops_v2.Ones(),
            getter=make_variable)

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
    ckpt_dir = os.path.join(self.get_temp_dir(), "checkpoint")

    with strategy.scope():
      model1 = Model()
      model1.build()
      self.assertIsInstance(model1.w, sharded_variable.ShardedVariable)
      self.assertLen(model1.w.variables, 2)
      model1.w.assign([1., 2., 3., 4.])

      cp1 = tracking_util.Checkpoint(model=model1)
      cp1.write(ckpt_dir)

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver,
        sharded_variable.FixedShardsPartitioner(restore_shards))

    with strategy.scope():
      model2 = Model()
      cp2 = tracking_util.Checkpoint(model=model2)
      if delayed:
        cp2.restore(ckpt_dir)
        model2.build()
      else:
        model2.build()
        cp2.restore(ckpt_dir)
      self.assertIsInstance(model2.w, sharded_variable.ShardedVariable)
      self.assertLen(model2.w.variables, restore_shards)
      if restore_shards == 2:
        self.assertAllEqual(model2.w.variables[0], [1., 2.])
        self.assertAllEqual(model2.w.variables[1], [3., 4.])
      elif restore_shards == 4:
        self.assertAllEqual(model2.w.variables[0], [1.])
        self.assertAllEqual(model2.w.variables[1], [2.])
        self.assertAllEqual(model2.w.variables[2], [3.])
        self.assertAllEqual(model2.w.variables[3], [4.])


class ClusterTypeNameTest(test.TestCase):
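  """Tests that invalid cluster configurations are rejected at construction."""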

  def testArbitraryJobName(self):
    cluster_def = multi_worker_test_base.create_cluster_spec(
        num_workers=1, num_ps=1, has_chief=True)
    cluster_def["some_arbitrary_name"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def), rpc_layer="grpc")
    with self.assertRaisesRegex(ValueError, "Disallowed task type found in"):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)

  def testArbitraryCurrentTaskType(self):
    cluster_def = multi_worker_test_base.create_cluster_spec(
        num_workers=1, num_ps=1, has_chief=True)
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def), rpc_layer="grpc", task_type="foobar")
    with self.assertRaisesRegex(ValueError, "Unrecognized task_type: foobar"):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)

  def testMoreThanOneChief(self):
    cluster_def = multi_worker_test_base.create_cluster_spec(
        num_workers=1, num_ps=1)
    chief_ports = [multi_worker_test_base.pick_unused_port() for _ in range(3)]
    cluster_def["chief"] = ["localhost:%s" % port for port in chief_ports]
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def),
        rpc_layer="grpc",
        task_type="chief",
        task_id=1)
    with self.assertRaisesRegex(ValueError,
                                "There must be at most one 'chief' job."):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)

  def testLessThanOneWorker(self):
    cluster_def = multi_worker_test_base.create_cluster_spec(
        num_workers=0, num_ps=1, has_chief=True)
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def), rpc_layer="grpc", task_type="ps", task_id=0)
    with self.assertRaisesRegex(ValueError,
                                "There must be at least one worker."):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)

  def testLessThanOnePs(self):
    cluster_def = multi_worker_test_base.create_cluster_spec(
        num_workers=1, num_ps=0, has_chief=True)
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def),
        rpc_layer="grpc",
        task_type="worker",
        task_id=0)
    with self.assertRaisesRegex(ValueError, "There must be at least one ps."):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)


if __name__ == "__main__":
  v2_compat.enable_v2_behavior()
  test.main()