# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for parameter_server_strategy_v2.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import functools
import os

from absl.testing import parameterized
import numpy as np

from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import parameter_server_strategy_v2
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import linalg_ops_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import save
from tensorflow.python.training.server_lib import ClusterSpec
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util as tracking_util

# We create one cluster to share between tests. The cluster should be large
# enough to accommodate all the tests. Adjust the following constants as needed
# but be aware of resource limitations in OSS tests.
MAX_NUM_WORKER = 2
MAX_NUM_PS = 3

_cluster = None


def get_cluster_def(num_workers, num_ps):
  if num_workers > MAX_NUM_WORKER or num_ps > MAX_NUM_PS:
    raise ValueError("Requesting more servers than the maximum, adjust "
                     "MAX_NUM_PS and MAX_NUM_WORKER")
  global _cluster
  if _cluster is None:
    _cluster = multi_worker_test_base.create_in_process_cluster(
        num_workers=MAX_NUM_WORKER, num_ps=MAX_NUM_PS)
  return {
      "worker": _cluster["worker"][:num_workers],
      "ps": _cluster["ps"][:num_ps],
  }


class ParameterServerStrategyV2Test(test.TestCase):

  def setUp(self):
    super().setUp()
    cluster_def = get_cluster_def(num_workers=2, num_ps=3)
    self.cluster_resolver = SimpleClusterResolver(ClusterSpec(cluster_def))

  def tearDown(self):
    super().tearDown()
    # Reset the context to disconnect from the cluster.
    context._reset_context()

  def testVariablePlacement(self):

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    v1 = variables.Variable(initial_value=0.0)
    with strategy.scope():
      v2 = variables.Variable(initial_value=1.0)
      v3 = variables.Variable(initial_value=2.0)
      v4 = variables.Variable(initial_value=3.0)
      v5 = variables.Variable(initial_value=4.0)
    # v1 was created outside scope so should be on the client.
    gpu_devices = context.num_gpus()
    if gpu_devices:
      # For tests with GPUs.
      self.assertEqual(v1.device, "/job:chief/replica:0/task:0/device:GPU:0")
    else:
      self.assertEqual(v1.device, "/job:chief/replica:0/task:0/device:CPU:0")
    # v2 through v5 are created in scope and in a round-robin manner.
    self.assertEqual(v2.device, "/job:ps/replica:0/task:0/device:CPU:0")
    self.assertEqual(v3.device, "/job:ps/replica:0/task:1/device:CPU:0")
    self.assertEqual(v4.device, "/job:ps/replica:0/task:2/device:CPU:0")
    self.assertEqual(v5.device, "/job:ps/replica:0/task:0/device:CPU:0")

  def testInteractionWithDeviceScope(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)

    # The strategy scope always wins.
    with strategy.scope():
      with ops.device("/job:ps/replica:0/task:1"):
        v0 = variables.Variable(initial_value=0.0)
        self.assertEqual(v0.device, "/job:ps/replica:0/task:0/device:CPU:0")

      with ops.device("/job:ps/replica:0/task:0"):
        v1 = variables.Variable(initial_value=0.0)
        self.assertEqual(v1.device, "/job:ps/replica:0/task:1/device:CPU:0")

    with ops.device("/job:ps/replica:0/task:1"):
      with strategy.scope():
        v2 = variables.Variable(initial_value=0.0)
        self.assertEqual(v2.device, "/job:ps/replica:0/task:2/device:CPU:0")

        v3 = variables.Variable(initial_value=0.0)
        self.assertEqual(v3.device, "/job:ps/replica:0/task:0/device:CPU:0")

  def testInteractionWithVariableCreatorScope(self):

    def var_creator(next_creator, **kwargs):
      if "colocate_with" in kwargs:
        with ops.device(None):
          with ops.colocate_with(kwargs["colocate_with"]):
            return next_creator(**kwargs)

      self.assertIn("ps1", kwargs["name"])
      with ops.device("/job:ps/task:1"):
        return next_creator(**kwargs)

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)

    # variable_creator_scope itself will work.
    with variable_scope.variable_creator_scope(var_creator):
      v0 = variables.Variable(initial_value=0.0, name="ps1_0")
    self.assertEqual(v0.device, "/job:ps/replica:0/task:1/device:CPU:0")

    # variable_creator_scope inside strategy.scope will not work.
    with strategy.scope():
      with variable_scope.variable_creator_scope(var_creator):
        v1 = variables.Variable(initial_value=0.0, name="ps1_1")
    self.assertEqual(v1.device, "/job:ps/replica:0/task:0/device:CPU:0")

    # strategy.scope still assigns variables in a round-robin fashion.
    with strategy.scope():
      v2 = variables.Variable(initial_value=0.0, name="ps1_2")
    self.assertEqual(v2.device, "/job:ps/replica:0/task:1/device:CPU:0")

    with strategy.scope():
      v3 = variables.Variable(initial_value=0.0, name="ps1_3")
    self.assertEqual(v3.device, "/job:ps/replica:0/task:2/device:CPU:0")

    # variable_creator_scope outside strategy.scope will work.
    with variable_scope.variable_creator_scope(var_creator):
      with strategy.scope():
        v4 = variables.Variable(initial_value=0.0, name="ps1_4")
    self.assertEqual(v4.device, "/job:ps/replica:0/task:1/device:CPU:0")

    with variable_scope.variable_creator_scope(var_creator):
      with strategy.scope():
        v5 = variables.Variable(initial_value=0.0, name="ps1_5")
    self.assertEqual(v5.device, "/job:ps/replica:0/task:1/device:CPU:0")

    # variable_creator_scope can be made to respect "colocate_with" as well.
    with variable_scope.variable_creator_scope(var_creator):
      with strategy.scope():
        with strategy.extended.colocate_vars_with(v1):
          v6 = variables.Variable(initial_value=0.0, name="ps1_6")
    self.assertEqual(v6.device, "/job:ps/replica:0/task:0/device:CPU:0")

  @contextlib.contextmanager
  def _assertRaisesUsageWarningWithSchedule(self):
    with self.assertLogs(level="WARNING") as logs:
      yield

    self.assertIn(
        "It is detected that a function used with "
        "`tf.distribute.experimental.ParameterServerStrategy` "
        "is executed locally on the coordinator. This is inefficient but may "
        "be valid for one-off tasks such as inferring output signature. "
" 202 "To properly distribute functions to run on workers, `run` or " 203 "`reduce` should be used within a function passed to `" 204 "tf.distribute.experimental.coordinator.ClusterCoordinator.schedule`.", 205 logs.output[0]) 206 207 def testRunNotUsedWithClusterCoordinator(self): 208 strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( 209 self.cluster_resolver) 210 dataset = dataset_ops.DatasetV2.range(8) 211 with strategy.scope(): 212 v = variables.Variable(1, dtype=dtypes.int64) 213 214 def step_fn(iterator): 215 return next(iterator) + v 216 217 with self._assertRaisesUsageWarningWithSchedule(): 218 strategy.run(step_fn, args=(iter(dataset),)) 219 220 def testRunUsedWithTestOnlyMode(self): 221 strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( 222 self.cluster_resolver) 223 strategy.extended._allow_run_without_coordinator = True 224 dataset = dataset_ops.DatasetV2.range(15) 225 with strategy.scope(): 226 v = variables.Variable(1, dtype=dtypes.int64) 227 228 def step_fn(iterator): 229 return next(iterator) + v 230 231 strategy.run(step_fn, args=(iter(dataset),)) 232 233 def testReduceNotUsedWithClusterCoordinator(self): 234 strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( 235 self.cluster_resolver) 236 with self._assertRaisesUsageWarningWithSchedule(): 237 strategy.reduce("SUM", None, axis=None) 238 239 def testDistributeDatasetUsedDirectly(self): 240 strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( 241 self.cluster_resolver) 242 dataset = dataset_ops.DatasetV2.range(3) 243 distributed_dataset = strategy.experimental_distribute_dataset(dataset) 244 with self.assertRaises(ValueError): 245 iter(distributed_dataset) 246 247 distributed_dataset = strategy.distribute_datasets_from_function( 248 lambda: dataset) 249 with self.assertRaises(ValueError): 250 iter(distributed_dataset) 251 252 def testSparselyReadForEmbeddingLookup(self): 253 strategy = parameter_server_strategy_v2.ParameterServerStrategyV2( 254 self.cluster_resolver) 255 256 class FakeModel(module.Module): 257 258 def __init__(self): 259 self._var0 = variables.Variable([1.0, 2.0, 3.0, 4.0]) 260 self._var1 = variables.Variable([5.0, 6.0, 7.0, 8.0]) 261 262 @def_function.function(input_signature=[ 263 tensor_spec.TensorSpec(shape=[2], dtype=dtypes.int32, name="inputs") 264 ]) 265 def func(self, x): 266 return embedding_ops.embedding_lookup([self._var0, self._var1], x) 267 268 with strategy.scope(): 269 model = FakeModel() 270 271 # Assert that ResourceGather op exists instead of Gather in training 272 # function. 273 found_resource_gather = False 274 found_gather = False 275 276 for n in model.func.get_concrete_function().graph.as_graph_def().node: 277 if n.op == "ResourceGather": 278 found_resource_gather = True 279 elif n.op == "Gather": 280 found_gather = True 281 self.assertTrue(found_resource_gather) 282 self.assertFalse(found_gather) 283 284 # Assert that ResourceGather op exists instead of Gather in saved_model. 
    found_resource_gather = False
    found_gather = False

    tmp_dir = self.get_temp_dir()
    save.save(model, tmp_dir, signatures=model.func)

    with gfile.Open("%s/saved_model.pb" % tmp_dir, "rb") as f:
      saved_model_proto = saved_model_pb2.SavedModel().FromString(f.read())

    for function in saved_model_proto.meta_graphs[0].graph_def.library.function:
      for n in function.node_def:
        if n.op == "ResourceGather":
          found_resource_gather = True
          resource_gather_device = n.device
        elif n.op == "Gather":
          found_gather = True
    self.assertTrue(found_resource_gather)
    self.assertFalse(found_gather)

    # We also assert that the colocate_with in embedding_ops will not result in
    # a hard-coded device string.
    self.assertEmpty(resource_gather_device)


class PartitionAwareIdentity(object):

  def __call__(self, shape, dtype, **kwargs):
    value = linalg_ops_impl.eye(*shape, dtype=dtype)
    if "partition_shape" in kwargs and "partition_offset" in kwargs:
      return array_ops.slice(value, kwargs["partition_offset"],
                             kwargs["partition_shape"])
    raise AssertionError("PartitionAwareIdentity does not support "
                         "non-partitioned initialization")


class VariablePartitioningTest(test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    cluster_def = get_cluster_def(num_workers=2, num_ps=2)
    self.cluster_resolver = SimpleClusterResolver(ClusterSpec(cluster_def))

  def tearDown(self):
    super().tearDown()
    # Reset the context to disconnect from the cluster.
    context._reset_context()

  def testDefaultNoPartition(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    with strategy.scope():
      v = variables.Variable([0, 1, 2, 3])

    self.assertIsInstance(v, variables.Variable)

  def testBasic(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
    with strategy.scope():
      init1 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
      v1 = variables.Variable(
          initial_value=lambda: init1(shape=(5, 2), dtype=dtypes.int64),
          shape=(5, 2),
          dtype=dtypes.int64)

      init2 = init_ops_v2.Constant([0, 1, 2, 3, 4, 5])
      v2 = variables.Variable(
          initial_value=lambda: init2(shape=(6, 1), dtype=dtypes.int64),
          shape=(6, 1),
          dtype=dtypes.int64)

    self.assertIsInstance(v1, sharded_variable.ShardedVariable)
    self.assertLen(v1.variables, 2)
    self.assertRegex(v1.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v1.variables[1].device, "/job:ps/replica:0/task:1")
    self.assertAllEqual(v1.variables[0], [[0, 1], [2, 3], [4, 5]])
    self.assertAllEqual(v1.variables[1], [[6, 7], [8, 9]])

    self.assertIsInstance(v2, sharded_variable.ShardedVariable)
    self.assertLen(v2.variables, 2)
    self.assertRegex(v2.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v2.variables[1].device, "/job:ps/replica:0/task:1")
    self.assertAllEqual(v2.variables[0], [[0], [1], [2]])
    self.assertAllEqual(v2.variables[1], [[3], [4], [5]])

  def testBasicVariableWithAggregation(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver)
    strategy.extended._allow_run_without_coordinator = True
    with strategy.scope():
      v = variables.Variable(
          initial_value=[0, 0, 0, 0, 0, 0, 0, 0],
          dtype=dtypes.float32,
          aggregation=variable_scope.VariableAggregation.SUM)

    if strategy.num_replicas_in_sync > 1:
      self.assertIsInstance(v, ps_values.AggregatingVariable)
    else:
      self.assertIsInstance(v, variables.Variable)

    def replica_fn():
      replica_id = distribution_strategy_context.get_replica_context(
      ).replica_id_in_sync_group
      val = array_ops.reshape(
          math_ops.cast(replica_id + 10, dtype=v.dtype), [1])
      v.assign(
          array_ops.concat(
              [val, constant_op.constant([1., 2., 3., 4., 5., 6., 7.])], 0))

    strategy.run(replica_fn)

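    # Each replica assigns [replica_id + 10, 1., 2., ..., 7.]; with SUM
    # aggregation the final value should be the element-wise sum of those
    # per-replica values.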
    expected_result = np.arange(8.) * strategy.num_replicas_in_sync
    for i in range(strategy.num_replicas_in_sync):
      expected_result[0] = expected_result[0] + i + 10
    self.assertAllEqual(v, expected_result)

  def testBasicShardedVariableWithAggregation(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
    strategy.extended._allow_run_without_coordinator = True
    with strategy.scope():
      v = variables.Variable(
          initial_value=[0, 0, 0, 0, 0, 0, 0, 0],
          dtype=dtypes.float32,
          aggregation=variable_scope.VariableAggregation.SUM)

    self.assertIsInstance(v, sharded_variable.ShardedVariable)
    self.assertLen(v.variables, 2)
    if strategy.num_replicas_in_sync > 1:
      self.assertIsInstance(v.variables[0], ps_values.AggregatingVariable)
    else:
      self.assertIsInstance(v.variables[0], variables.Variable)

    def replica_fn():
      replica_id = distribution_strategy_context.get_replica_context(
      ).replica_id_in_sync_group
      val = array_ops.reshape(
          math_ops.cast(replica_id + 10, dtype=v.dtype), [1])
      v.assign(
          array_ops.concat(
              [val, constant_op.constant([1., 2., 3., 4., 5., 6., 7.])], 0))

    strategy.run(replica_fn)

    expected_result = np.arange(8.) * strategy.num_replicas_in_sync
    for i in range(strategy.num_replicas_in_sync):
      expected_result[0] = expected_result[0] + i + 10
    expected_result = np.array_split(expected_result, 2)
    self.assertAllEqual(expected_result[0], v.variables[0])
    self.assertAllEqual(expected_result[1], v.variables[1])

  def testNonCallableInitialValue(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(4))
    with strategy.scope():
      v = variables.Variable([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

    self.assertIsInstance(v, sharded_variable.ShardedVariable)
    self.assertLen(v.variables, 4)
    self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
    self.assertRegex(v.variables[2].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v.variables[3].device, "/job:ps/replica:0/task:1")
    self.assertAllEqual(v.variables[0], [0, 1, 2])
    self.assertAllEqual(v.variables[1], [3, 4, 5])
    self.assertAllEqual(v.variables[2], [6, 7])
    self.assertAllEqual(v.variables[3], [8, 9])

  def testNumPartitionsLargerThanSize(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(4))
    with strategy.scope():
      v = variables.Variable([0, 1, 2])

    self.assertIsInstance(v, sharded_variable.ShardedVariable)
    self.assertLen(v.variables, 3)
    self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
    self.assertRegex(v.variables[2].device, "/job:ps/replica:0/task:0")
    self.assertAllEqual(v.variables[0], [0])
    self.assertAllEqual(v.variables[1], [1])
    self.assertAllEqual(v.variables[2], [2])

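  # NOTE: MinSizePartitioner with a 64 MiB minimum shard size should not split
  # the small variables below, so they are expected to come out as plain
  # (non-sharded) variables, still placed on ps tasks in a round-robin fashion.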
  def testPartitionToOne(self):
    # For small variables there is only one partition.
    variable_partitioner = sharded_variable.MinSizePartitioner(
        min_shard_bytes=64 << 20, max_shards=2)
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, variable_partitioner)
    with strategy.scope():
      initializer = init_ops_v2.Constant([0] * 10)
      v1 = variables.Variable(
          initial_value=lambda: initializer(shape=(10,), dtype=dtypes.int64),
          shape=(10,),
          dtype=dtypes.int64)

      v2 = variables.Variable([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

    self.assertIsInstance(v1, variables.Variable)
    self.assertNotIsInstance(v1, sharded_variable.ShardedVariable)
    self.assertRegex(v1.device, "/job:ps/replica:0/task:0")
    self.assertAllEqual(v1, [0] * 10)

    self.assertIsInstance(v2, variables.Variable)
    self.assertNotIsInstance(v2, sharded_variable.ShardedVariable)
    self.assertRegex(v2.device, "/job:ps/replica:0/task:1")
    self.assertAllEqual(v2, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

  def testColocateWith(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
    with strategy.scope():
      v1 = variables.Variable([0, 1, 2, 3])

      with strategy.extended.colocate_vars_with(v1.variables[0]):
        v2 = variables.Variable([4, 5])

    self.assertIsInstance(v1, sharded_variable.ShardedVariable)

    self.assertIsInstance(v2, variables.Variable)
    self.assertNotIsInstance(v2, sharded_variable.ShardedVariable)
    self.assertEqual(v2.device, v1.variables[0].device)
    self.assertAllEqual(v2, [4, 5])

  def testCustomPartitionAwareInitializer(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
    with strategy.scope():
      initializer = PartitionAwareIdentity()
      initial_value = functools.partial(
          initializer, shape=(4, 4), dtype=dtypes.int64)
      v = variables.Variable(
          initial_value=initial_value, shape=(4, 4), dtype=dtypes.int64)

    self.assertIsInstance(v, sharded_variable.ShardedVariable)
    self.assertLen(v.variables, 2)
    self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
    self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
    self.assertAllEqual(v.variables[0], [[1, 0, 0, 0], [0, 1, 0, 0]])
    self.assertAllEqual(v.variables[1], [[0, 0, 1, 0], [0, 0, 0, 1]])

  def testPartitionWhenLackOfInfo(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
    with strategy.scope():
      initializer = init_ops_v2.Constant([0, 1, 2, 3])
      # Shape is not explicitly specified.
      v1 = variables.Variable(
          initial_value=lambda: initializer(shape=(4,), dtype=dtypes.int64),
          dtype=dtypes.int64)
      # Dtype is not explicitly specified.
      v2 = variables.Variable(
          initial_value=lambda: initializer(shape=(4,), dtype=dtypes.int64),
          shape=(4,))
      # Neither shape nor dtype is explicitly specified.
      v3 = variables.Variable(
          initial_value=lambda: initializer(shape=(4,), dtype=dtypes.int64))

    for v in [v1, v2, v3]:
      self.assertIsInstance(v, sharded_variable.ShardedVariable)
      self.assertLen(v.variables, 2)
      self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
      self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
      self.assertAllEqual(v.variables[0], [0, 1])
      self.assertAllEqual(v.variables[1], [2, 3])

  def testInvalidPartitioner(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, lambda shape, dtype: None)
    with self.assertRaisesRegex(ValueError, "variable_partitioner"):
      with strategy.scope():
        variables.Variable([[[0, 1], [2, 3]], [[0, 1], [2, 3]]])

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, lambda shape, dtype: [])
    with self.assertRaisesRegex(ValueError, "variable_partitioner"):
      with strategy.scope():
        variables.Variable([[[0, 1], [2, 3]], [[0, 1], [2, 3]]])

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, lambda shape, dtype: [0, 1, 1])
    with self.assertRaisesRegex(ValueError, "variable_partitioner"):
      with strategy.scope():
        variables.Variable([[[0, 1], [2, 3]], [[0, 1], [2, 3]]])

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, lambda shape, dtype: [2, 2, 1])
    with self.assertRaisesRegex(ValueError, "variable_partitioner"):
      with strategy.scope():
        variables.Variable([[[0, 1], [2, 3]], [[0, 1], [2, 3]]])

  def testCreateInsideTFFunction(self):
    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))

    collection = []

    @def_function.function
    def create_vars():
      if not collection:
        identity = init_ops_v2.Identity()
        v1 = variables.Variable([[1., 0.], [0., 1.]], dtype=dtypes.float32)
        v2 = variables.Variable(lambda: identity((2, 2), dtypes.float32))
        v3 = variables.Variable(
            lambda: identity((2, 2), dtypes.float32),
            dtype=dtypes.float32,
            shape=(2, 2))
        collection.extend([v1, v2, v3])

    with strategy.scope():
      create_vars()
      for v in collection:
        self.assertIsInstance(v, sharded_variable.ShardedVariable)
        self.assertLen(v.variables, 2)
        self.assertRegex(v.variables[0].device, "/job:ps/replica:0/task:0")
        self.assertRegex(v.variables[1].device, "/job:ps/replica:0/task:1")
        self.assertAllEqual(v.variables[0], [[1., 0.]])
        self.assertAllEqual(v.variables[1], [[0., 1.]])

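  # Writes a checkpoint with 2 shards and restores it into `restore_shards`
  # shards, optionally before the variables are built (delayed restore), to
  # verify that a ShardedVariable checkpoint can be restored under a different
  # partitioning.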
  @parameterized.named_parameters(
      ("Restore", False, 2),
      ("RestoreDiffShards", False, 4),
      ("DelayedRestore", True, 2),
      ("DelayedRestoreDiffShards", True, 4),
  )
  def testCheckpoint(self, delayed, restore_shards):

    def make_variable(name, shape, dtype, initializer):
      initial_value = functools.partial(initializer, shape, dtype=dtype)
      return variables.Variable(
          name=name, initial_value=initial_value, shape=shape, dtype=dtype)

    class Model(tracking.AutoTrackable):

      def build(self):
        self.w = self._add_variable_with_custom_getter(
            "w",
            shape=(4,),
            initializer=init_ops_v2.Ones(),
            getter=make_variable)

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver, sharded_variable.FixedShardsPartitioner(2))
    ckpt_dir = os.path.join(self.get_temp_dir(), "checkpoint")

    with strategy.scope():
      model1 = Model()
      model1.build()
      self.assertIsInstance(model1.w, sharded_variable.ShardedVariable)
      self.assertLen(model1.w.variables, 2)
      model1.w.assign([1., 2., 3., 4.])

    cp1 = tracking_util.Checkpoint(model=model1)
    cp1.write(ckpt_dir)

    strategy = parameter_server_strategy_v2.ParameterServerStrategyV2(
        self.cluster_resolver,
        sharded_variable.FixedShardsPartitioner(restore_shards))

    with strategy.scope():
      model2 = Model()
      cp2 = tracking_util.Checkpoint(model=model2)
      if delayed:
        cp2.restore(ckpt_dir)
        model2.build()
      else:
        model2.build()
        cp2.restore(ckpt_dir)
      self.assertIsInstance(model2.w, sharded_variable.ShardedVariable)
      self.assertLen(model2.w.variables, restore_shards)
      if restore_shards == 2:
        self.assertAllEqual(model2.w.variables[0], [1., 2.])
        self.assertAllEqual(model2.w.variables[1], [3., 4.])
      elif restore_shards == 4:
        self.assertAllEqual(model2.w.variables[0], [1.])
        self.assertAllEqual(model2.w.variables[1], [2.])
        self.assertAllEqual(model2.w.variables[2], [3.])
        self.assertAllEqual(model2.w.variables[3], [4.])


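# The constructor validates the cluster spec: only "chief", "worker", and "ps"
# job types are accepted, there must be at least one worker and one ps, and at
# most one chief.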
class ClusterTypeNameTest(test.TestCase):

  def testArbitraryJobName(self):
    cluster_def = multi_worker_test_base.create_cluster_spec(
        num_workers=1, num_ps=1, has_chief=True)
    cluster_def["some_arbitrary_name"] = [
        "localhost:%d" % multi_worker_test_base.pick_unused_port()
    ]
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def), rpc_layer="grpc")
    with self.assertRaisesRegexp(ValueError, "Disallowed task type found in"):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)

  def testArbitraryCurrentTaskType(self):
    cluster_def = multi_worker_test_base.create_cluster_spec(
        num_workers=1, num_ps=1, has_chief=True)
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def), rpc_layer="grpc", task_type="foobar")
    with self.assertRaisesRegexp(ValueError, "Unrecognized task_type: foobar"):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)

  def testMoreThanOneChief(self):
    cluster_def = multi_worker_test_base.create_cluster_spec(
        num_workers=1, num_ps=1)
    chief_ports = [multi_worker_test_base.pick_unused_port() for _ in range(3)]
    cluster_def["chief"] = ["localhost:%s" % port for port in chief_ports]
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def),
        rpc_layer="grpc",
        task_type="chief",
        task_id=1)
    with self.assertRaisesRegexp(ValueError,
                                 "There must be at most one 'chief' job."):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)

  def testLessThanOneWorker(self):
    cluster_def = multi_worker_test_base.create_cluster_spec(
        num_workers=0, num_ps=1, has_chief=True)
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def), rpc_layer="grpc", task_type="ps", task_id=0)
    with self.assertRaisesRegexp(ValueError,
                                 "There must be at least one worker."):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)

  def testLessThanOnePs(self):
    cluster_def = multi_worker_test_base.create_cluster_spec(
        num_workers=1, num_ps=0, has_chief=True)
    cluster_resolver = SimpleClusterResolver(
        ClusterSpec(cluster_def),
        rpc_layer="grpc",
        task_type="worker",
        task_id=0)
    with self.assertRaisesRegexp(ValueError, "There must be at least one ps."):
      parameter_server_strategy_v2.ParameterServerStrategyV2(cluster_resolver)


if __name__ == "__main__":
  v2_compat.enable_v2_behavior()
  test.main()