# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ParameterServerStrategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import threading

from absl.testing import parameterized
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.estimator import run_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import training_util

CHIEF = run_config.TaskType.CHIEF
WORKER = run_config.TaskType.WORKER
PS = run_config.TaskType.PS


def _get_replica_id_integer():
  replica_id = ds_context.get_replica_context().replica_id_in_sync_group
  if isinstance(replica_id, ops.Tensor):
    replica_id = tensor_util.constant_value(replica_id)
  return replica_id


def create_test_objects(cluster_spec=None,
                        task_type=None,
                        task_id=None,
                        num_gpus=None,
                        sess_config=None):
  sess_config = sess_config or config_pb2.ConfigProto()
  if num_gpus is None:
    num_gpus = context.num_gpus()
  if cluster_spec and task_type and task_id is not None:
    cluster_resolver = SimpleClusterResolver(
        cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
        task_type=task_type,
        task_id=task_id,
        num_accelerators={'GPU': num_gpus})
    distribution = parameter_server_strategy.ParameterServerStrategyV1(
        cluster_resolver)
    target = 'grpc://' + cluster_spec[WORKER][task_id]
  else:
    distribution = (
        central_storage_strategy.CentralStorageStrategy._from_num_gpus(num_gpus)
    )
    target = ''

  sess_config = copy.deepcopy(sess_config)
  sess_config = distribution.update_config_proto(sess_config)

  return distribution, target, sess_config


class ParameterServerStrategyTestBase(
    multi_worker_test_base.MultiWorkerTestBase):

  def setUp(self):
    self._result = 0
    self._lock = threading.Lock()
    self._init_condition = threading.Condition()
    self._init_reached = 0
    self._finish_condition = threading.Condition()
    self._finish_reached = 0
    self._sess_config = config_pb2.ConfigProto(allow_soft_placement=True)
    super(ParameterServerStrategyTestBase, self).setUp()

  def _get_test_objects(self, task_type, task_id, num_gpus):
    return create_test_objects(
        cluster_spec=self._cluster_spec,
        task_type=task_type,
        task_id=task_id,
        num_gpus=num_gpus,
        sess_config=self._sess_config)

  def _test_device_assignment_distributed(self, task_type, task_id, num_gpus):
    worker_device = '/job:%s/replica:0/task:%d' % (task_type, task_id)
    d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus)
    with ops.Graph().as_default(), \
         self.cached_session(target=self._default_target,
                             config=sess_config) as sess, \
         d.scope():

      # Define a variable outside the call_for_each_replica scope.
      n = variable_scope.get_variable('n', initializer=10.0)
      self.assertEqual(n.device, '/job:ps/task:0')

      def model_fn():
        if num_gpus == 0:
          last_part_device = 'device:CPU:0'
        else:
          replica_id = _get_replica_id_integer()
          last_part_device = ('device:GPU:%d' % replica_id)

        a = constant_op.constant(1.0)
        b = constant_op.constant(2.0)
        c = a + b
        self.assertEqual(a.device, worker_device + '/' + last_part_device)
        self.assertEqual(b.device, worker_device + '/' + last_part_device)
        self.assertEqual(c.device, worker_device + '/' + last_part_device)

        # The device scope is ignored for variables but not for normal ops.
        with ops.device('/job:worker/task:0'):
          x = variable_scope.get_variable(
              'x', initializer=10.0,
              aggregation=variable_scope.VariableAggregation.SUM)
          x_add = x.assign_add(c)
          e = a + c
        # The variable x is on task 1 since the device_function has already
        # been called once before the model_fn (when creating n above).
        self.assertEqual(x.device, '/job:ps/task:1')
        self.assertEqual(x_add.device, x.device)
        self.assertEqual(e.device,
                         '/job:worker/replica:0/task:0/%s' % last_part_device)

        # The colocate_vars_with can override the distribution's device.
        with d.extended.colocate_vars_with(x):
          y = variable_scope.get_variable(
              'y', initializer=20.0,
              aggregation=variable_scope.VariableAggregation.SUM)
          # We add an identity here to avoid complaints about summing
          # non-distributed values.
          y_add = y.assign_add(array_ops.identity(x_add))
        self.assertEqual(y.device, '/job:ps/task:1')
        self.assertEqual(y_add.device, y.device)
        self.assertEqual(y.device, x.device)

        z = variable_scope.get_variable(
            'z', initializer=10.0,
            aggregation=variable_scope.VariableAggregation.SUM)
        self.assertEqual(z.device, '/job:ps/task:0')
        self.assertNotEqual(z.device, x.device)

        with ops.control_dependencies([y_add]):
          # We add an identity here to avoid complaints about summing
          # non-distributed values.
          z_add = z.assign_add(array_ops.identity(y))
        with ops.control_dependencies([z_add]):
          f = z + c
        self.assertEqual(f.device, worker_device + '/' + last_part_device)

        # The device scope would merge with the default worker device.
        with ops.device('/CPU:1'):
          g = e + 1.0
        self.assertEqual(g.device, worker_device + '/device:CPU:1')

        # This ops.colocate_with will be ignored when defining a variable but not
        # for a normal tensor.
        with ops.colocate_with(x):
          u = variable_scope.get_variable('u', initializer=30.0)
          v = variable_scope.get_variable('v', initializer=30.0)
          h = f + 1.0
        self.assertIn('/job:ps/', u.device)
        self.assertIn('/job:ps/', v.device)
        # u and v are on different parameter servers.
        self.assertTrue(u.device != x.device or v.device != x.device)
        self.assertTrue(u.device == x.device or v.device == x.device)
        # Here h is not on the worker: the colocation with x places it on a
        # parameter server. Note that h.device is canonical while x.device is
        # not.
        self.assertIn('/job:ps/', h.device)
        return y_add, z_add, f

      y, z, f = d.extended.call_for_each_replica(model_fn)
      self.assertNotEqual(y, None)
      self.assertNotEqual(z, None)
      self.assertNotEqual(f, None)

      if context.num_gpus() >= 1 and num_gpus <= 1:
        self.evaluate(variables.global_variables_initializer())
        y_val, z_val, f_val = sess.run([y, z, f])
        self.assertEqual(y_val, 33.0)
        self.assertEqual(z_val, 43.0)
        self.assertEqual(f_val, 46.0)

  def _test_device_assignment_distributed_enable_partitioner(
      self, task_type, task_id, num_gpus):
    d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus)
    num_shards = len(d.extended.parameter_devices)
    partitioner = partitioned_variables.fixed_size_partitioner(num_shards)
    with ops.Graph().as_default(), \
         self.cached_session(target=self._default_target,
                             config=sess_config) as sess, \
         d.scope():

      n = variable_scope.get_variable(
          'n',
          initializer=constant_op.constant([10.0, 20.0]),
          aggregation=variable_scope.VariableAggregation.SUM,
          partitioner=partitioner)

      for part_id, var in enumerate(n):
        self.assertEqual(var.device, '/job:ps/task:%d' % part_id)

      def model_fn():
        a = constant_op.constant([3.0, 5.0])
        # The device scope is ignored for variables but not for normal ops.
        with ops.device('/job:worker/task:0'):
          x = variable_scope.get_variable(
              'x',
              initializer=constant_op.constant([10.0, 20.0]),
              aggregation=variable_scope.VariableAggregation.SUM,
              partitioner=partitioner)
          x_add = x.assign_add(a, name='x_add')
        # Each shard of the partitioned variable x is placed on its own
        # parameter server task, regardless of the worker device scope above.
        for part_id, var in enumerate(x):
          self.assertEqual(var.device, '/job:ps/task:%d' % part_id)
          self.assertEqual(var.device, x_add[part_id].device)

        return x_add

      x = d.extended.call_for_each_replica(model_fn)

      if context.num_gpus() >= 1:
        self.evaluate(variables.global_variables_initializer())
        x_val = sess.run(x)
        if num_gpus < 1:
          self.assertEqual(x_val, [13.0, 25.0])
        else:
          x_expect = [10.0 + 3 * num_gpus, 20.0 + 5 * num_gpus]
          self.assertEqual(x_val, x_expect)

  def _test_device_assignment_local(self,
                                    d,
                                    compute_device='CPU',
                                    variable_device='CPU',
                                    num_gpus=0):
    with ops.Graph().as_default(), \
         self.cached_session(target=self._default_target,
                             config=self._sess_config) as sess, \
         d.scope():

      def model_fn():
        if 'CPU' in compute_device:
          replica_compute_device = '/device:CPU:0'
        else:
          replica_id = _get_replica_id_integer()
          replica_compute_device = ('/device:GPU:%d' % replica_id)
        replica_compute_device = device_util.canonicalize(
            replica_compute_device)

        if 'CPU' in variable_device:
          replica_variable_device = '/device:CPU:0'
        else:
          replica_id = _get_replica_id_integer()
          replica_variable_device = ('/device:GPU:%d' % replica_id)
        replica_variable_device = device_util.canonicalize(
            replica_variable_device)

        a = constant_op.constant(1.0)
        b = constant_op.constant(2.0)
        c = a + b
        self.assertEqual(a.device, replica_compute_device)
        self.assertEqual(b.device, replica_compute_device)
        self.assertEqual(c.device, replica_compute_device)

        # The device scope is ignored for variables but not for normal ops.
        with ops.device('/device:GPU:2'):
          x = variable_scope.get_variable(
              'x', initializer=10.0,
              aggregation=variable_scope.VariableAggregation.SUM)
          x_add = x.assign_add(c)
          e = a + c
        self.assertEqual(
            device_util.canonicalize(x.device), replica_variable_device)
        self.assertEqual(x_add.device, x.device)
        self.assertEqual(e.device, device_util.canonicalize('/device:GPU:2'))

        # The colocate_vars_with can override the distribution's device.
        with d.extended.colocate_vars_with(x):
          y = variable_scope.get_variable(
              'y', initializer=20.0,
              aggregation=variable_scope.VariableAggregation.SUM)
          # We add an identity here to avoid complaints about summing
          # non-distributed values.
          y_add = y.assign_add(array_ops.identity(x_add))
        self.assertEqual(
            device_util.canonicalize(y.device), replica_variable_device)
        self.assertEqual(y_add.device, y.device)
        self.assertEqual(y.device, x.device)

        z = variable_scope.get_variable(
            'z', initializer=10.0,
            aggregation=variable_scope.VariableAggregation.SUM)
        self.assertEqual(
            device_util.canonicalize(z.device), replica_variable_device)

        with ops.control_dependencies([y_add]):
          # We add an identity here to avoid complaints about summing
          # non-distributed values.
          z_add = z.assign_add(array_ops.identity(y))
        with ops.control_dependencies([z_add]):
          f = z + c
        self.assertEqual(f.device, replica_compute_device)

        # The device scope would merge with the default worker device.
        with ops.device('/CPU:1'):
          g = e + 1.0
        self.assertEqual(g.device, device_util.canonicalize('/device:CPU:1'))

        # This ops.colocate_with will be ignored when defining a variable but not
        # for a normal tensor.
        with ops.colocate_with(x):
          u = variable_scope.get_variable('u', initializer=30.0)
          h = f + 1.0
        self.assertEqual(
            device_util.canonicalize(u.device), replica_variable_device)
        self.assertEqual(
            device_util.canonicalize(x.device),
            device_util.canonicalize(h.device))
        return y_add, z_add, f

      y, z, f = d.extended.call_for_each_replica(model_fn)
      self.assertNotEqual(y, None)
      self.assertNotEqual(z, None)
      self.assertNotEqual(f, None)

      if context.num_gpus() >= 1 and num_gpus <= 1:
        self.evaluate(variables.global_variables_initializer())
        y_val, z_val, f_val = sess.run([y, z, f])
        self.assertEqual(y_val, 33.0)
        self.assertEqual(z_val, 43.0)
        self.assertEqual(f_val, 46.0)

  def _test_simple_increment(self, task_type, task_id, num_gpus):
    d, master_target, sess_config = self._get_test_objects(
        task_type, task_id, num_gpus)
    if d.extended._cluster_spec:
      num_workers = len(d.extended._cluster_spec.as_dict().get(WORKER))
      if 'chief' in d.extended._cluster_spec.as_dict():
        num_workers += 1
    else:
      num_workers = 1
    with ops.Graph().as_default(), \
         self.cached_session(target=master_target,
                             config=sess_config) as sess, \
         d.scope():

      def model_fn():
        x = variable_scope.get_variable(
            'x', initializer=10.0,
            aggregation=variable_scope.VariableAggregation.SUM)
        y = variable_scope.get_variable(
            'y', initializer=20.0,
            aggregation=variable_scope.VariableAggregation.SUM)
        z = variable_scope.get_variable(
            'z', initializer=30.0,
            aggregation=variable_scope.VariableAggregation.ONLY_FIRST_REPLICA)

        # We explicitly make a constant tensor here to avoid complaints about
        # summing non-distributed values.
        one = constant_op.constant(1.0)
        x_add = x.assign_add(one, use_locking=True)
        y_add = y.assign_add(one, use_locking=True)
        z_add = z.assign_add(one, use_locking=True)

        train_op = control_flow_ops.group(x_add, y_add, z_add)
        return x, y, z, train_op

      x, y, z, train_op = d.extended.call_for_each_replica(model_fn)
      train_op = d.group(train_op)

      if task_id == 0:
        self.evaluate(variables.global_variables_initializer())

      # Workers wait for the chief worker to finish initializing variables.
      self._init_condition.acquire()
      self._init_reached += 1
      while self._init_reached != num_workers:
        self._init_condition.wait()
      self._init_condition.notify_all()
      self._init_condition.release()

      sess.run(train_op)

      # Wait for other workers to finish training.
      self._finish_condition.acquire()
      self._finish_reached += 1
      while self._finish_reached != num_workers:
        self._finish_condition.wait()
      self._finish_condition.notify_all()
      self._finish_condition.release()

      x_val, y_val, z_val = sess.run([x, y, z])
      self.assertEqual(x_val, 10.0 + 1.0 * num_workers * d.num_replicas_in_sync)
      self.assertEqual(y_val, 20.0 + 1.0 * num_workers * d.num_replicas_in_sync)
      self.assertEqual(z_val, 30.0 + 1.0 * num_workers)

  def _test_minimize_loss_graph(self, task_type, task_id, num_gpus):
    d, master_target, sess_config = self._get_test_objects(
        task_type, task_id, num_gpus)
    if task_type:
      # Multi-worker
      assert hasattr(d.extended, '_cluster_spec') and d.extended._cluster_spec
      num_workers = len(d.extended._cluster_spec.as_dict().get(WORKER))
      if CHIEF in d.extended._cluster_spec.as_dict():
        num_workers += 1
    else:
      # local
      num_workers = 1

    with ops.Graph().as_default(), \
         self.cached_session(target=master_target,
                             config=sess_config) as sess, \
         d.scope():
      kernel = strategy_test_lib.create_variable_like_keras_layer(
          'kernel', (1, 1), dtypes.float32,)

      def loss_fn(x):
        y = array_ops.reshape(
            math_ops.matmul(x, kernel), []) - constant_op.constant(1.)
        return y * y

      # TODO(yuefengz, apassos): eager.backprop.implicit_grad is not safe for
      # multiple graphs (b/111216820).
      def grad_fn(x):
        loss = loss_fn(x)
        var_list = (
            variables.trainable_variables() + ops.get_collection(
                ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
        grads = gradients.gradients(loss, var_list)
        ret = list(zip(grads, var_list))
        return ret

      def update(v, g):
        return v.assign_sub(0.05 * g, use_locking=True)

      one = constant_op.constant([[1.]])

      def step():
        """Perform one optimization step."""
        # Run forward & backward to get gradients, variables list.
        g_v = d.extended.call_for_each_replica(grad_fn, args=(one,))
        # Update the variables using the gradients and the update() function.
        before_list = []
        after_list = []
        for g, v in g_v:
          fetched = d.extended.read_var(v)
          before_list.append(fetched)
          with ops.control_dependencies([fetched]):
            # TODO(yuefengz): support non-Mirrored variable as destinations.
            g = d.extended.reduce_to(
                reduce_util.ReduceOp.SUM, g, destinations=v)
            with ops.control_dependencies(
                d.extended.update(v, update, args=(g,), group=False)):
              after_list.append(d.extended.read_var(v))
        return before_list, after_list

      before_out, after_out = step()

      if (not task_type or
          multi_worker_util.is_chief(
              d.extended._cluster_spec, task_type, task_id)):
        self.evaluate(variables.global_variables_initializer())

      # Workers wait for the chief worker to finish initializing variables.
      self._init_condition.acquire()
      self._init_reached += 1
      while self._init_reached != num_workers:
        self._init_condition.wait()
      self._init_condition.notify_all()
      self._init_condition.release()

      for i in range(10):
        b, a = sess.run((before_out, after_out))
        if i == 0:
          before, = b
        after, = a

      error_before = abs(before - 1)
      error_after = abs(after - 1)
      # Error should go down.
      self.assertLess(error_after, error_before)

  def _test_input_fn_iterator(self,
                              task_type,
                              task_id,
                              num_gpus,
                              input_fn,
                              expected_values,
                              test_reinitialize=True,
                              ignore_order=False):
    distribution, master_target, config = self._get_test_objects(
        task_type, task_id, num_gpus)
    devices = distribution.extended.worker_devices

    with ops.Graph().as_default(), \
         self.cached_session(config=config,
                             target=master_target) as sess:
      iterator = distribution.make_input_fn_iterator(input_fn)
      sess.run(iterator.initializer)

      for expected_value in expected_values:
        next_element = iterator.get_next()
        computed_value = sess.run([distribute_utils.select_replica(
            r, next_element) for r in range(len(devices))])
        if ignore_order:
          self.assertCountEqual(expected_value, computed_value)
        else:
          self.assertEqual(expected_value, computed_value)

      with self.assertRaises(errors.OutOfRangeError):
        next_element = iterator.get_next()
        sess.run([distribute_utils.select_replica(r, next_element)
                  for r in range(len(devices))])

      # After re-initializing the iterator, we should be able to iterate again.
      if test_reinitialize:
        sess.run(iterator.initializer)

        for expected_value in expected_values:
          next_element = iterator.get_next()
          computed_value = sess.run([distribute_utils.select_replica(
              r, next_element) for r in range(len(devices))])
          if ignore_order:
            self.assertCountEqual(expected_value, computed_value)
          else:
            self.assertEqual(expected_value, computed_value)


class ParameterServerStrategyTest(
    ParameterServerStrategyTestBase,
    strategy_test_lib.DistributionTestBase,
    strategy_test_lib.TwoDeviceDistributionTestBase,
    parameterized.TestCase):

  @classmethod
  def setUpClass(cls):
    cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
        num_workers=3, num_ps=2)
    cls._default_target = 'grpc://' + cls._cluster_spec[WORKER][0]

  @combinations.generate(combinations.combine(mode=['graph']))
  def test_num_replicas_in_sync(self):
    strategy, _, _ = create_test_objects(num_gpus=2)
    # All the devices on a given worker are in sync, so num_replicas_in_sync
    # equals the number of GPUs on each worker.
    self.assertEqual(2, strategy.num_replicas_in_sync)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testDeviceAssignmentLocalCPU(self):
    strategy, _, _ = create_test_objects(num_gpus=0)
    self._test_device_assignment_local(
        strategy, compute_device='CPU', variable_device='CPU', num_gpus=0)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testDeviceAssignmentLocalOneGPU(self):
    strategy, _, _ = create_test_objects(num_gpus=1)
    self._test_device_assignment_local(
        strategy, compute_device='GPU', variable_device='GPU', num_gpus=1)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testDeviceAssignmentLocalTwoGPUs(self):
    strategy, _, _ = create_test_objects(num_gpus=2)
    self._test_device_assignment_local(
        strategy, compute_device='GPU', variable_device='CPU', num_gpus=2)

  @combinations.generate(
      combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
  def testDeviceAssignmentDistributed(self, num_gpus):
    self._test_device_assignment_distributed('worker', 1, num_gpus)

  @combinations.generate(
      combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
  def testDeviceAssignmentDistributedEnablePartitioner(self, num_gpus):
    self._test_device_assignment_distributed_enable_partitioner(
        'worker', 1, num_gpus)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testSimpleBetweenGraph(self):
    self._run_between_graph_clients(self._test_simple_increment,
                                    self._cluster_spec, context.num_gpus())

  @combinations.generate(
      combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
  def testLocalSimpleIncrement(self, required_gpus):
    self._test_simple_increment(None, 0, required_gpus)

  @combinations.generate(
      combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
  def testMinimizeLossGraphDistributed(self, required_gpus):
    self._run_between_graph_clients(self._test_minimize_loss_graph,
                                    self._cluster_spec, required_gpus)

  @combinations.generate(
      combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
  def testMinimizeLossGraphLocal(self, required_gpus):
    self._test_minimize_loss_graph(None, None, required_gpus)

  # TODO(priyag): Refactor this and other multi worker tests.
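  # The next two tests exercise make_input_fn_iterator. The distributed case
  # runs as worker 1 of the 3-worker cluster (input pipeline id 1 of 3); the
  # local case uses a single input pipeline.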
  @combinations.generate(
      combinations.combine(
          mode=['graph'], required_gpus=[1, 2], use_dataset=[True, False]))
  def testMakeInputFnIteratorDistributed(self, required_gpus, use_dataset):
    if use_dataset:
      fn = lambda: dataset_ops.Dataset.range(100)
    else:
      def fn():
        dataset = dataset_ops.Dataset.range(100)
        it = dataset_ops.make_one_shot_iterator(dataset)
        return it.get_next

    expected_values = [[i + j
                        for j in range(required_gpus)]
                       for i in range(0, 100, required_gpus)]

    input_fn = self._input_fn_to_test_input_context(
        fn,
        expected_num_replicas_in_sync=required_gpus,
        expected_num_input_pipelines=3,
        expected_input_pipeline_id=1)  # because task_id = 1
    self._test_input_fn_iterator(
        'worker',
        1,
        required_gpus,
        input_fn,
        expected_values,
        test_reinitialize=use_dataset,
        ignore_order=not use_dataset)

  @combinations.generate(
      combinations.combine(
          mode=['graph'], required_gpus=[1, 2], use_dataset=[True, False]))
  def testMakeInputFnIteratorLocal(self, required_gpus, use_dataset):
    if use_dataset:
      fn = lambda: dataset_ops.Dataset.range(100)
    else:

      def fn():
        dataset = dataset_ops.Dataset.range(100)
        it = dataset_ops.make_one_shot_iterator(dataset)
        return it.get_next

    expected_values = [[i + j
                        for j in range(required_gpus)]
                       for i in range(0, 100, required_gpus)]

    input_fn = self._input_fn_to_test_input_context(
        fn,
        expected_num_replicas_in_sync=required_gpus,
        expected_num_input_pipelines=1,
        expected_input_pipeline_id=0)  # only one worker and pipeline for local.
    self._test_input_fn_iterator(
        None,
        None,
        required_gpus,
        input_fn,
        expected_values,
        test_reinitialize=use_dataset,
        ignore_order=not use_dataset)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testGlobalStepUpdate(self):
    strategy, _, _ = create_test_objects()
    self._test_global_step_update(strategy)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testUpdateConfigProtoMultiWorker(self):
    strategy, _, _ = create_test_objects(
        cluster_spec=self._cluster_spec,
        task_type='worker',
        task_id=1,
        num_gpus=2)

    config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden'])

    new_config = strategy.update_config_proto(config_proto)

    # Verify device filters.
    self.assertEqual(['/job:worker/task:1', '/job:ps'],
                     new_config.device_filters)

    # Verify isolate_session_state
    self.assertFalse(new_config.isolate_session_state)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testUpdateConfigProtoLocal(self):
    strategy, _, _ = create_test_objects(num_gpus=2)

    config_proto = config_pb2.ConfigProto()
    new_config = strategy.update_config_proto(config_proto)

    # Verify isolate_session_state
    self.assertTrue(new_config.isolate_session_state)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testInMultiWorkerMode(self):
    strategy, _, _ = create_test_objects(
        cluster_spec=self._cluster_spec,
        task_type='worker',
        task_id=1,
        num_gpus=0)
    self.assertTrue(strategy.extended._in_multi_worker_mode())

  @combinations.generate(combinations.combine(mode=['eager']))
  def testEagerCustomTrainingUnimplementedError(self):
    cluster_spec = multi_worker_test_base.create_in_process_cluster(
        num_workers=3, num_ps=2)
    cluster_resolver = SimpleClusterResolver(
        cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
        task_type='worker',
        task_id=1,
        num_accelerators={'GPU': 0})
    strategy = parameter_server_strategy.ParameterServerStrategyV1(
        cluster_resolver)
    dataset = dataset_ops.DatasetV2.from_tensor_slices([5., 6., 7., 8.])

    def train_step(data):
      return math_ops.square(data)

    self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*',
                           strategy.experimental_distribute_dataset,
                           dataset.batch(2))

    self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*',
                           strategy.distribute_datasets_from_function,
                           lambda _: dataset)

    self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*',
                           strategy.scope)

    self.assertRaisesRegex(NotImplementedError, 'ParameterServerStrategy*',
                           strategy.run, train_step)

  @combinations.generate(combinations.combine(
      mode=['graph'],
      prefetch_to_device=[None, True]))
  def test_prefetch_to_device_dataset(self, prefetch_to_device):
    distribution, _, _ = create_test_objects(
        cluster_spec=self._cluster_spec,
        task_type='worker',
        task_id=0,
        num_gpus=2)
    if prefetch_to_device is None:
      input_options = None
    else:
      input_options = distribute_lib.InputOptions(
          experimental_prefetch_to_device=prefetch_to_device)
    dataset = dataset_ops.Dataset.range(100)
    dataset = dataset.batch(distribution.num_replicas_in_sync)
    dataset = distribution.experimental_distribute_dataset(
        dataset, options=input_options)
    if isinstance(dataset, input_lib.DistributedDatasetV1):
      item = dataset.make_initializable_iterator().get_next()
    else:
      self.skipTest('unsupported test combination')
    device_types = {
        tf_device.DeviceSpec.from_string(tensor.device).device_type for
        tensor in item.values}
    self.assertAllEqual(list(device_types), ['GPU'])

  @combinations.generate(combinations.combine(mode=['graph']))
  def test_prefetch_to_host_dataset(self):
    distribution, _, _ = create_test_objects(
        cluster_spec=self._cluster_spec,
        task_type='worker',
        task_id=0,
        num_gpus=2)
    input_options = distribute_lib.InputOptions(
        experimental_prefetch_to_device=False)
    dataset = dataset_ops.Dataset.range(100)
    dataset = dataset.batch(distribution.num_replicas_in_sync)
    dataset = distribution.experimental_distribute_dataset(
        dataset, options=input_options)
    if isinstance(dataset, input_lib.DistributedDatasetV1):
      item = dataset.make_initializable_iterator().get_next()
    else:
      self.skipTest('unsupported test combination')
    device_types = {
        tf_device.DeviceSpec.from_string(tensor.device).device_type for
        tensor in item.values}
    self.assertAllEqual(list(device_types), ['CPU'])


class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
                                           parameterized.TestCase):

  @classmethod
  def setUpClass(cls):
    cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
        num_workers=3, num_ps=2, has_chief=True)
    cls._default_target = 'grpc://' + cls._cluster_spec[CHIEF][0]

  @combinations.generate(
      combinations.combine(mode=['graph'], required_gpus=[0, 1, 2]))
  def testSimpleBetweenGraph(self, required_gpus):
    self._run_between_graph_clients(self._test_simple_increment,
                                    self._cluster_spec, required_gpus)

  @combinations.generate(
      combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
  def testMinimizeLossGraph(self, num_gpus):
    self._run_between_graph_clients(self._test_minimize_loss_graph,
                                    self._cluster_spec, num_gpus)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testGlobalStepIsWrappedOnTwoGPUs(self):
    strategy, _, _ = create_test_objects(num_gpus=2)
    with ops.Graph().as_default(), strategy.scope():
      created_step = training_util.create_global_step()
      get_step = training_util.get_global_step()
      self.assertEqual(created_step, get_step,
                       msg=('created_step %s type %s vs. get_step %s type %s' %
                            (id(created_step), created_step.__class__.__name__,
                             id(get_step), get_step.__class__.__name__)))
      self.assertIs(ps_values.AggregatingVariable, type(created_step))
      self.assertIs(ps_values.AggregatingVariable, type(get_step))
      self.assertIs(strategy, created_step.distribute_strategy)

  @combinations.generate(combinations.combine(mode=['graph']))
  def testGlobalStepIsNotWrappedOnOneGPU(self):
    strategy, _, _ = create_test_objects(num_gpus=1)
    with ops.Graph().as_default(), strategy.scope():
      created_step = training_util.create_global_step()
      get_step = training_util.get_global_step()
      self.assertEqual(created_step, get_step,
                       msg=('created_step %s type %s vs. get_step %s type %s' %
                            (id(created_step), created_step.__class__.__name__,
                             id(get_step), get_step.__class__.__name__)))
      self.assertIs(resource_variable_ops.ResourceVariable, type(created_step))
      self.assertIs(resource_variable_ops.ResourceVariable, type(get_step))
      # All variables have a _distribute_strategy attribute. Only the variable
      # subclasses used by distribution strategies expose it publicly.
      self.assertFalse(hasattr(strategy, 'distribute_strategy'))
      self.assertIs(strategy, created_step._distribute_strategy)

  @combinations.generate(combinations.combine(mode=['graph'], required_gpus=2))
  def testValueContainer(self):
    strategy, _, _ = create_test_objects(num_gpus=2)
    with ops.Graph().as_default(), strategy.scope():

      def f():
        with backprop.GradientTape() as tape:
          v = variable_scope.get_variable('v', initializer=10.0)
          _ = v * v
        v, = tape.watched_variables()
        w = strategy.extended.value_container(v)
        self.assertIs(ps_values.AggregatingVariable, type(w))

      strategy.extended.call_for_each_replica(f)


class CentralStorageStrategyTest(strategy_test_lib.DistributionTestBase,
                                 parameterized.TestCase):

  @combinations.generate(combinations.combine(mode=['graph', 'eager'],
                                              required_gpus=2))
  def testNumpyDataset(self):
    strategy, _, _ = create_test_objects(num_gpus=2)
    self._test_numpy_dataset(strategy)

  @combinations.generate(combinations.combine(mode=['graph', 'eager']))
  def testInMultiWorkerMode(self):
    strategy, _, _ = create_test_objects(num_gpus=0)
    self.assertFalse(strategy.extended._in_multi_worker_mode())


if __name__ == '__main__':
  test.main()