# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the distributed values library."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import os

from absl.testing import parameterized
import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util as ds_test_util
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import save_context
from tensorflow.python.saved_model import save_options
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.types import core
from tensorflow.python.util import nest


def _device_str(d):
  return "/device:GPU:" + str(d)


def _nested_value(d):
  return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d)


def _make_mirrored_val(init_val=5.0):
  v = []
  devices = ["/device:GPU:0", "/device:CPU:0"]
  for d, _ in zip(devices, ["v", "v/replica"]):
    with ops.device(d):
      v.append(constant_op.constant(init_val))
  return values_lib.Mirrored(v)


def _make_mirrored(distribution=None):
  v = []
  if distribution:
    devices = distribution.extended.worker_devices
  else:
    devices = ["/device:GPU:0", "/device:CPU:0"]
  for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
    with ops.device(d):
      v.append(
          variable_scope.get_variable(
              name=n, initializer=init, use_resource=True))

  if (distribution is not None) and isinstance(distribution, _TPU_STRATEGIES):
    var_cls = tpu_values.TPUMirroredVariable
  else:
    var_cls = values_lib.MirroredVariable
  mirrored = var_cls(distribution, v, variable_scope.VariableAggregation.SUM)
  return mirrored


def mirrored_and_tpu_strategy_combinations():
  return combinations.combine(
      distribution=[
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call,
          strategy_combinations.tpu_strategy,
          strategy_combinations.tpu_strategy_packed_var,
      ],
      mode=["graph", "eager"])


class DistributedValuesTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueFromTensor(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    single_value = constant_op.constant(1)
    def value_fn(ctx):
      del ctx
      return single_value

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    self.assertAllEqual(
        ds_test_util.gather(distribution, distributed_values),
        constant_op.constant(1., shape=(distribution.num_replicas_in_sync)))

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueSingleNumpyArrayConstant(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    array_value = np.array([1., 2., 3.])
    def value_fn(ctx):
      del ctx
      return array_value

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    self.assertAllEqual(
        ds_test_util.gather(distribution, distributed_values).numpy(),
        [[1., 2., 3.]] * distribution.num_replicas_in_sync)

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueTupleConstant(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    tuple_value = (1., 2., 3.)
    def value_fn(ctx):
      del ctx
      return tuple_value
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    distributed_values = ds_test_util.gather(distribution, distributed_values)

    # Expected output for 2 replicas:
    # ([1.0, 1.0], [2.0, 2.0], [3.0, 3.0])
    expected = tuple([v for i in range(distribution.num_replicas_in_sync)]
                     for v in tuple_value)
    self.assertAllEqual(distributed_values, expected)

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueNestedStructurePerReplica(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    tuple_value = (1., 2., 3.)
    def value_fn(ctx):
      per_replica = []
      for val in tuple_value:
        per_replica.append(val * ctx.replica_id_in_sync_group)
      return tuple(per_replica)
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    distributed_values = ds_test_util.gather(distribution, distributed_values)

    # Expected output for 2 replicas:
    # ([0.0, 1.0], [0.0, 2.0], [0.0, 3.0])
    expected = tuple([v * i for i in range(distribution.num_replicas_in_sync)]
                     for v in tuple_value)
    self.assertAllEqual(distributed_values, expected)

  # NOTE(priyag): Cannot test this with MultiWorkerMirroredStrategy because
  # collective ops do not support SparseTensors.
  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies_minus_default,
          mode=["eager"]
      ))
  def testMakeDistributedValueSparseTensor(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    def value_fn(ctx):
      del ctx
      return sparse_tensor.SparseTensor(
          indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    local_results = distribution.experimental_local_results(distributed_values)
    for i in range(distribution.num_replicas_in_sync):
      self.assertAllEqual(
          sparse_ops.sparse_tensor_to_dense(local_results[i]),
          [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]])

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueExtractFromArray(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    multiple_values = range(distribution.num_replicas_in_sync)
    def value_fn(ctx):
      return multiple_values[ctx.replica_id_in_sync_group]
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    distributed_values = ds_test_util.gather(distribution, distributed_values)
    expected = range(distribution.num_replicas_in_sync)
    self.assertAllEqual(distributed_values, expected)

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueAndRun(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    @def_function.function
    def run():
      multiple_values = range(distribution.num_replicas_in_sync)
      def value_fn(ctx):
        return multiple_values[ctx.replica_id_in_sync_group]
      distributed_values = (
          distribution.experimental_distribute_values_from_function(value_fn))

      def computation(x):
        return math_ops.square(x)

      outputs = ds_test_util.gather(
          distribution,
          distribution.run(computation, args=(distributed_values,)))
      return outputs

    results = run()

    expected = [i**2 for i in range(distribution.num_replicas_in_sync)]
    self.assertAllEqual(results, expected)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ] + strategy_combinations.multiworker_strategies,
          mode=["eager"]))
  def testMakeDistributedValueDefaultDevicePlacement(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    def value_fn(ctx):
      del ctx
      return constant_op.constant(1.0)
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    default_device = array_ops.identity(constant_op.constant(1.0)).device
    for i in range(len(distribution.extended.worker_devices)):
      self.assertAllEqual(distributed_values._values[i].device, default_device)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ] + strategy_combinations.multiworker_strategies,
          mode=["eager"],
          op_type=[constant_op.constant, array_ops.identity]))
  def testMakeDistributedValueExplicitDevicePlacement(self, distribution,
                                                      op_type):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    worker_devices = distribution.extended.worker_devices
    def value_fn(ctx):
      # In a multi-client setup, worker_devices is just the devices on that
      # worker.
      worker_device_id = ctx.replica_id_in_sync_group % len(worker_devices)
      with ops.device(worker_devices[worker_device_id]):
        return op_type(1.0)

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    for i in range(len(distribution.extended.worker_devices)):
      self.assertAllEqual(distributed_values._values[i].device,
                          worker_devices[i])


class PerWorkerResourceTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(dataset_fn_as_tf_function=[True, False]))
  def testMapFnTracing(self, dataset_fn_as_tf_function):
    # For a PerWorkerResource to behave correctly when used in dataset.map,
    # the map_fn must not be traced only once, so that
    # PerWorkerResource.local_table can return the correct resource. This test
    # can detect a potential breakage of this behavior on TAP.
    self._traced_once = 0

    def map_fn(x):
      self._traced_once += 1
      return x

    def dataset_fn():
      dataset = dataset_ops.DatasetV2.from_tensors([0, 1, 2]).repeat().batch(
          2, drop_remainder=True)
      dataset = dataset.map(map_fn)
      return dataset

    datasets = []
    number_of_input_pipelines = 5

    if dataset_fn_as_tf_function:
      dataset_fn = def_function.function(dataset_fn)
      expected_tracing_times = 1
    else:
      expected_tracing_times = number_of_input_pipelines

    for _ in range(number_of_input_pipelines):
      datasets.append(dataset_fn())

    self.assertEqual(self._traced_once, expected_tracing_times)


class DistributedDelegateTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def testGetAttr(self):
    class Foo(object):

      def __init__(self, x):
        self.x = x

    v = values_lib.DistributedDelegate((Foo(7), Foo(8)))
    self.assertEqual(7, v.x)
    with self.assertRaises(AttributeError):
      _ = v.y

  @test_util.run_in_graph_and_eager_modes
  def testOperatorOverride(self):
    v = values_lib.DistributedDelegate((7, 8))
    # v should act like int(7).
    self.assertEqual(8, v + 1)
    self.assertEqual(10, 3 + v)
    self.assertEqual(14, v + v)
    self.assertEqual(5, v - 2)
    self.assertEqual(6, 13 - v)
    self.assertEqual(0, v - v)
    self.assertEqual(14, v * 2)
    self.assertEqual(21, 3 * v)
    self.assertEqual(49, v * v)
    self.assertEqual(3.5, v / 2)
    self.assertEqual(1.5, 10.5 / v)
    self.assertEqual(3, v // 2)
    self.assertEqual(2, 15 // v)
    self.assertEqual(1, v % 2)
    self.assertEqual(2, 16 % v)
    # pylint: disable=g-generic-assert
    self.assertTrue(v < 12)
    self.assertTrue(v <= 12)
    self.assertFalse(v > 12)
    self.assertFalse(v >= 12)
    self.assertFalse(12 < v)
    self.assertFalse(12 <= v)
    self.assertTrue(12 > v)
    self.assertTrue(12 >= v)
    # pylint: enable=g-generic-assert
    self.assertEqual(3, v & 3)
    self.assertEqual(3, 11 & v)
    self.assertEqual(15, v | 8)
    self.assertEqual(23, 16 | v)
    self.assertEqual(4, v ^ 3)
    self.assertEqual(12, 11 ^ v)
    self.assertEqual(343, pow(v, 3))
    self.assertEqual(3, pow(v, 3, 10))
    self.assertEqual(128, pow(2, v))
    self.assertEqual(-7, -v)
    self.assertEqual(~7, ~v)
    self.assertEqual(7, abs(v))
    with self.assertRaises(TypeError):
      _ = v[2]

  @test_util.run_in_graph_and_eager_modes
  def testCopy(self):

    class Foo(object):

      def __init__(self, x):
        self.x = x

    v = values_lib.DistributedDelegate((Foo(7), Foo(8)))
    v_shallow_copy = copy.copy(v)
    self.assertEqual(v.x, v_shallow_copy.x)
    v_deep_copy = copy.deepcopy(v)
    self.assertEqual(v.x, v_deep_copy.x)


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call,
            strategy_combinations.tpu_strategy,
            strategy_combinations.tpu_strategy_packed_var,
            strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
        ],
        synchronization=[
            variables_lib.VariableSynchronization.ON_READ,
            variables_lib.VariableSynchronization.ON_WRITE,
        ],
        aggregation=[
            variables_lib.VariableAggregation.MEAN,
            variables_lib.VariableAggregation.SUM,
            variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
        ],
        mode=["graph", "eager"],
        use_var_policy=[True, False]))
class DistributedVariableTest(test.TestCase, parameterized.TestCase):

  def testExtendsVariable(self, distribution, synchronization, aggregation):
    with distribution.scope():
      v = variables_lib.Variable(
          1., synchronization=synchronization, aggregation=aggregation)
    self.assertIsInstance(v, variables_lib.Variable)

  def testCheckpointing(self, distribution, synchronization, aggregation,
                        mode):

    if (isinstance(distribution,
                   collective_all_reduce_strategy.CollectiveAllReduceStrategy)
        and mode == "graph"):
      self.skipTest("MWMS combinations tests do not work well in graph mode.")

    with distribution.scope():
      v = variables_lib.Variable(
          constant_op.constant([1., 2., 3., 4]),
          synchronization=synchronization,
          aggregation=aggregation)

    self.evaluate(v.initializer)
    before_save = self.evaluate(v.read_value())

    # Save random weights into checkpoint.
    checkpoint = trackable_utils.Checkpoint(v=v)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    with self.test_session():
      save_path = checkpoint.save(prefix)

    # Assign inverted value.
    self.evaluate(v.assign(constant_op.constant([4., 3., 2., 1.])))
    after_assign = self.evaluate(v.read_value())
    self.assertNotAllClose(before_save, after_assign)

    # Restore from the checkpoint.
    with self.test_session():
      checkpoint.restore(save_path).assert_consumed().run_restore_ops()
    after_restore = self.evaluate(v)
    self.assertAllClose(before_save, after_restore)

  def testTraceback(self, distribution, synchronization, aggregation):
    if context.executing_eagerly():
      self.skipTest("does not apply to eager")
    with distribution.scope():
      variable_scope.get_variable(
          name="testVar",
          initializer=1.,
          use_resource=True,
          synchronization=synchronization,
          aggregation=aggregation)
      with self.assertRaisesRegex(ValueError,
                                  "Variable testVar already exists"):
        variable_scope.get_variable(
            name="testVar",
            initializer=1.,
            use_resource=True,
            synchronization=synchronization,
            aggregation=aggregation)

  def testSelectReplica(self, distribution, synchronization, aggregation):
    with distribution.scope():
      v = variables_lib.Variable(
          1., synchronization=synchronization, aggregation=aggregation)
    self.assertIs(v, distribute_utils.select_replica(0, v))

  def testIsTensorLike(self, distribution, synchronization, aggregation):
    if isinstance(distribution.extended,
                  tpu_strategy.TPUExtended) and context.executing_eagerly():
      self.skipTest("TPU doesn't support pure eager")

    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)
    # In cross replica context.
    self.assertIsInstance(v, core.Tensor)
    # In replica context.
    distribution.run(
        lambda v: self.assertIsInstance(v, core.Tensor), args=(v,))

  def testAssignReturnValueIsTensorLike(self, distribution, synchronization,
                                        aggregation):
    if isinstance(distribution.extended, tpu_strategy.TPUExtended):
      if context.executing_eagerly():
        self.skipTest("TPU doesn't support pure eager")
      else:
        self.skipTest("b/152076846")

    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)

    def assert_is_tensor_like(v):
      # We can't use Python literals because they are treated as
      # non-distributed values, which is not allowed when aggregation is SUM.
      # See `cross_device_ops.reduce_non_distributed_value`.
      delta = array_ops.identity(1.)
      self.assertIsInstance(v.assign(delta), core.Tensor)
      self.assertIsInstance(v.assign_sub(delta), core.Tensor)
      self.assertIsInstance(v.assign_add(delta), core.Tensor)

    # In cross replica context we return a PerReplica which is not Tensor like
    # all the time yet.
    if (synchronization == variables_lib.VariableSynchronization.ON_READ and
        aggregation != variables_lib.VariableAggregation.SUM):
      assert_is_tensor_like(v)

    # In replica context.
    distribution.run(assert_is_tensor_like, args=(v,))

  def testDeepCopy(self, distribution, synchronization, aggregation):
    if not context.executing_eagerly():
      self.skipTest("deepcopy only supported in eager mode")

    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)
      in_dist_copy = copy.deepcopy(v)

    out_dist_copy = copy.deepcopy(v)

    def assert_is_deep_copy(v1, v2):
      self.assertIsInstance(v2, type(v1))
      self.assertEqual(v1.aggregation, v2.aggregation)
      self.assertEqual(v1.distribute_strategy, v2.distribute_strategy)
      if isinstance(v1, ps_values.AggregatingVariable):
        self.assertIsInstance(v2.get(), type(v1.get()))
        self.assertNotEqual(id(v1.get()), id(v2.get()))
      else:
        if v1._policy:
          self.assertNotEqual(id(v1._policy), id(v2._policy))  # pylint: disable=protected-access
        else:
          self.assertEqual(id(v1._policy), id(v2._policy))  # pylint: disable=protected-access
        self.assertEqual(len(v1.values), len(v2.values))
        for (v1v, v2v) in zip(v1.values, v2.values):
          self.assertEqual(v1v.device, v2v.device)
          self.assertNotEqual(id(v1v), id(v2v))
          self.assertAllEqual(self.evaluate(v1.values),
                              self.evaluate(v2.values))

    self.evaluate(variables_lib.global_variables_initializer())
    if not isinstance(distribution.extended, tpu_strategy.TPUExtended):
      distribution.run(assert_is_deep_copy, args=(v, in_dist_copy))
      distribution.run(assert_is_deep_copy, args=(v, out_dist_copy))

  def testAssignSignature(self, distribution, synchronization, aggregation):
    # This test verifies assign*() can be called in the same way as normal
    # variables.
    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)

    def assign():
      one = constant_op.constant(1.)
      v.assign(one, True, "assign", False)
      # TODO(b/154017756): SyncOnReadVariable.assign() doesn't support
      # passing value as a keyword argument.
      v.assign(one, use_locking=True, name="assign", read_value=False)
      v.assign_add(one, True, "assign", False)
      v.assign_add(one, use_locking=True, name="assign", read_value=False)
      v.assign_sub(one, True, "assign", False)
      v.assign_sub(one, use_locking=True, name="assign", read_value=False)
      # Return something for graph mode to fetch.
      return constant_op.constant(1)

    self.evaluate(variables_lib.global_variables_initializer())
    if not (synchronization == variables_lib.VariableSynchronization.ON_READ
            and aggregation == variables_lib.VariableAggregation.SUM):
      self.evaluate(distribution.experimental_local_results(assign()))
    if not (isinstance(distribution.extended, tpu_strategy.TPUExtended) and
            context.executing_eagerly()):
      self.evaluate(
          distribution.experimental_local_results(distribution.run(assign)))

  def testStrategyExtendedUpdate(self, distribution, synchronization,
                                 aggregation):
    if len(distribution.extended.parameter_devices) != 2:
      self.skipTest("n/a: needs exactly two parameter devices")
    if (synchronization == variables_lib.VariableSynchronization.ON_WRITE and
        aggregation != variables_lib.VariableAggregation.NONE):
      self.skipTest("n/a: doesn't apply to ON_WRITE variable with aggregation")
    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)
    value = values_lib.PerReplica([1., 2.])

    assign_fn = lambda var, value: var.assign(value)
    self.evaluate(distribution.extended.update(v, assign_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [1., 2.])

    assign_add_fn = lambda var, value: var.assign_add(value)
    self.evaluate(distribution.extended.update(v, assign_add_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [2., 4.])

    assign_sub_fn = lambda var, value: var.assign_sub(value)
    self.evaluate(distribution.extended.update(v, assign_sub_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [1., 2.])

    read_assign_fn = lambda var, value: var.assign_add(var.value() +
                                                       var.read_value())
    self.evaluate(
        distribution.extended.update(v, read_assign_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [3., 6.])

  def testSaveNonDistributed(self, distribution, synchronization, aggregation):
    # This test verifies that the DistributedVariable behaves like the primary
    # variable when saving a non-distributed version of the model (the
    # default). The test asserts that the function traced under SaveContext has
    # no device annotations and only references the primary component of the
    # variable. Note: avoid capturing other eager tensors in this test to keep
    # the assertion simple.

    if isinstance(distribution.extended,
                  parameter_server_strategy.ParameterServerStrategyExtended):
      self.skipTest("b/148689177: AggregatingVariable doesn't "
                    "conform to Variable interface well")

    # tf.function requires the return value to be Tensors, which is not always
    # the case for properties and methods of Variable, so we simply discard the
    # return values.
    def _discard_return(f):
      f()
      return

    def _test(f, v):
      # This verifies that the function under SaveContext:
      # - contains no device annotations.
      # - only references the primary component of the variable.
      g = def_function.function(lambda: _discard_return(f))
      options = save_options.SaveOptions(
          experimental_variable_policy=save_options.VariablePolicy.NONE)
      with save_context.save_context(options):
        # The graph should contain no device.
        graph = g.get_concrete_function().graph
        for op in graph.get_operations():
          self.assertEqual(op.device, "", msg=str(op))
        # The function should only capture the primary variable. Note that it
        # may not have captures, e.g. v.aggregation.
        captures = list(graph.captures)
        self.assertLessEqual(len(captures), 1)
        if graph.captures:
          self.assertIs(captures[0][0], v._primary.handle)

    def _assert(cond):
      return control_flow_ops.Assert(cond, [cond])

    with distribution.scope():
      # We use four variables for convenience reasons. They have no special
      # meaning.
      # - v is used whenever possible.
      # - w is used for scatter and gather, which require the variable to be
      #   non-scalar.
      # - y is used when the dtype needs to be integer. Note that aggregation
      #   cannot be MEAN for integers.
      v = variables_lib.Variable(
          0.,
          synchronization=synchronization,
          aggregation=aggregation,
          trainable=True)
      w = variables_lib.Variable([0., 0., 0.],
                                 synchronization=synchronization,
                                 aggregation=aggregation,
                                 trainable=True)
      if aggregation != variables_lib.VariableAggregation.MEAN:
        y = variables_lib.Variable(
            0,
            synchronization=synchronization,
            aggregation=aggregation)

    # pylint: disable=g-long-lambda

    # tf.Variable properties.
    _test(lambda: self.assertEqual(v.aggregation, aggregation), v)
    _test(lambda: self.assertIs(v.constraint, None), v)
    # TODO(crccw): should we raise an error instead?
    _test(lambda: self.assertEqual(v.device, v._primary.device), v)
    _test(lambda: self.assertEqual(v.dtype, dtypes.float32), v)
    if not context.executing_eagerly():
      _test(lambda: self.assertIs(v.graph, v._primary.graph), v)
    if not context.executing_eagerly():
      _test(lambda: _assert(v.initial_value == 0), v)
    _test(lambda: self.assertIs(v.initializer, v._primary.initializer), v)
    _test(lambda: self.assertEqual(v.name, "Variable:0"), v)
    if not context.executing_eagerly():
      _test(lambda: self.assertIs(v.op, v._primary.op), v)
    _test(lambda: self.assertEqual(v.shape, tensor_shape.TensorShape(())), v)
    _test(lambda: self.assertEqual(v.synchronization, synchronization), v)
    _test(lambda: self.assertTrue(v.trainable, True), v)

    # tf.Variable methods.
    _test(lambda: check_ops.assert_equal_v2(v.assign(1.), 1.), v)
    _test(lambda: check_ops.assert_equal_v2(v.assign_add(1.), 2.), v)
    _test(lambda: check_ops.assert_equal_v2(v.assign_sub(1.), 1.), v)
    # TODO(b/148689177): Implement batch_scatter_update.
    # count_up_to() is skipped since it's deprecated.
    # eval() is skipped since it shouldn't be called in a tf.function.
    # experimental_ref() is skipped since it's deprecated.
    # from_proto() is skipped since it shouldn't be called in a tf.function.
    # TODO(b/148689177): Implement gather_nd.
    _test(
        lambda: check_ops.assert_equal_v2(v.get_shape(),
                                          tensor_shape.TensorShape(())), v)
    # initialized_value() is skipped since it shouldn't be called in a
    # tf.function.
    # load() is skipped since it shouldn't be called in a tf.function.
    _test(lambda: check_ops.assert_equal_v2(v.read_value(), 1.), v)
    # ref() is skipped since it shouldn't be called in a tf.function.
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_add(_make_index_slices(values=[1., 2.], indices=[0, 2])),
            [1., 0., 2.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_div(_make_index_slices(values=[4., 2.], indices=[0, 2])),
            [0.25, 0., 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_max(_make_index_slices(values=[1., 0.5], indices=[1, 2])),
            [0.25, 1., 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_min(_make_index_slices(values=[1., 0.5], indices=[0, 1])),
            [0.25, 0.5, 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_mul(_make_index_slices(values=[2., 0.5], indices=[0, 1])),
            [0.5, 0.25, 1.]), w)
    # TODO(b/148689177): Implement scatter_nd_*
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_sub(_make_index_slices(values=[2., 0.5], indices=[0, 1])),
            [-1.5, -0.25, 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_update(
                _make_index_slices(values=[2., 0.5], indices=[0, 1])),
            [2., 0.5, 1.]), w)
    # set_shape() is skipped since ResourceVariable doesn't implement it.
    # to_proto() is skipped since it shouldn't be called in a tf.function.
    _test(lambda: check_ops.assert_equal_v2(v.value(), 1.), v)

    # DistributedVariable should be treated as ResourceVariable, so it needs to
    # conform to ResourceVariable interface as well.
    _test(lambda: self.assertIs(v.handle, v._primary.handle), v)

    # Convert to tensor.
    _test(lambda: check_ops.assert_equal_v2(ops.convert_to_tensor(v), 1.), v)

    # Control dependency.
    def _with_control_dep():
      with ops.control_dependencies([v.assign(1.)]):
        return array_ops.identity(1)

    _test(_with_control_dep, v)

    # Operator overloads.
    _test(lambda: check_ops.assert_equal_v2(v.assign(7.), 7.), v)
    _test(lambda: check_ops.assert_equal_v2(v + 1., 8.), v)
    _test(lambda: check_ops.assert_equal_v2(3 + v, 10.), v)
    _test(lambda: check_ops.assert_equal_v2(v + v, 14.), v)
    _test(lambda: check_ops.assert_equal_v2(v - 2., 5.), v)
    _test(lambda: check_ops.assert_equal_v2(v - v, 0.), v)
    _test(lambda: check_ops.assert_equal_v2(v * 2., 14.), v)
    _test(lambda: check_ops.assert_equal_v2(3 * v, 21.), v)
    _test(lambda: check_ops.assert_equal_v2(v * v, 49.), v)
    _test(
        lambda: check_ops.assert_equal_v2(
            math_ops.cast(v / 2., dtypes.float32), 3.5), v)
    _test(
        lambda: check_ops.assert_equal_v2(
            math_ops.cast(14. / v, dtypes.float32), 2.), v)
    _test(lambda: _assert(v < 12.), v)
    _test(lambda: _assert(v <= 12.), v)
    _test(lambda: _assert(not v > 12.), v)
    _test(lambda: _assert(not v >= 12.), v)
    _test(lambda: _assert(not 12. < v), v)
    _test(lambda: _assert(not 12. <= v), v)
    _test(lambda: _assert(12. > v), v)
    _test(lambda: _assert(12. >= v), v)
    _test(lambda: check_ops.assert_near_v2(pow(v, 3.), 343.), v)
    _test(lambda: check_ops.assert_near_v2(pow(2., v), 128.), v)
    _test(lambda: check_ops.assert_equal_v2(abs(v), 7.), v)

    # Operator overloads that only work for integers.
    if aggregation != variables_lib.VariableAggregation.MEAN:
      _test(lambda: check_ops.assert_equal_v2(y.assign(7), 7), y)
      _test(lambda: check_ops.assert_equal_v2(y // 2, 3), y)
      _test(lambda: check_ops.assert_equal_v2(15 // y, 2), y)
      _test(lambda: check_ops.assert_equal_v2(y % 2, 1), y)
      _test(lambda: check_ops.assert_equal_v2(16 % y, 2), y)
      _test(lambda: check_ops.assert_equal_v2(y & 3, 3), y)
      _test(lambda: check_ops.assert_equal_v2(3 & y, 3), y)
      _test(lambda: check_ops.assert_equal_v2(y | 8, 15), y)
      _test(lambda: check_ops.assert_equal_v2(16 | y, 23), y)
      _test(lambda: check_ops.assert_equal_v2(y ^ 3, 4), y)
      _test(lambda: check_ops.assert_equal_v2(11 ^ y, 12), y)
      _test(lambda: check_ops.assert_equal_v2(-y, -7), y)
      _test(lambda: check_ops.assert_equal_v2(~y, ~7), y)

    # Index.
    if isinstance(distribution.extended, tpu_strategy.TPUExtended):
      # TODO(b/161572567): slice assignment doesn't work for TPU.
      _test(lambda: check_ops.assert_equal_v2(w[0], 2.), w)
    else:
      _test(lambda: check_ops.assert_equal_v2(w[0].assign(1.), [1., 0.5, 1.]),
            w)
      _test(lambda: check_ops.assert_equal_v2(w[0], 1.), w)

    # pylint: enable=g-long-lambda

  def testUnsaveable(self, distribution, synchronization, aggregation, mode):
    if isinstance(distribution.extended,
                  parameter_server_strategy.ParameterServerStrategyExtended):
      self.skipTest("n/a: not applicable to AggregatingVariable")
    if (isinstance(distribution,
                   collective_all_reduce_strategy.CollectiveAllReduceStrategy)
        and mode == "graph"):
      self.skipTest("MWMS combinations tests do not work well in graph mode.")
    if not distribution.extended._use_merge_call():
      self.skipTest("Unsupported combination.")
    with distribution.scope():
      v = variables_lib.Variable([1., 1.],
                                 synchronization=synchronization,
                                 aggregation=aggregation)

    with self.cached_session():
      self.evaluate(variables_lib.global_variables_initializer())

    export_dir = self.get_temp_dir()

    def _assert_unsaveable(f):
      # Ignore if it cannot be traced. Certain combinations are not supported
      # yet or not allowed.
      try:
        f = def_function.function(f).get_concrete_function()
      except (NotImplementedError, ValueError):
        return
      with self.assertRaisesRegex(ValueError, "f_with_input_signature"):
        save.save(v, export_dir, signatures=f)

    _assert_unsaveable(lambda: v.assign(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.assign_add(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.assign_sub(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.scatter_add(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_sub(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_mul(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_div(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_min(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_max(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_update(_make_index_slices([1.], [0])))
    # Reading an ON_READ variable should be unsaveable if either:
    # 1) CollectiveAllReduceStrategy, and aggregation is MEAN/SUM.
    # 2) aggregation is SUM.
    if (synchronization == variables_lib.VariableSynchronization.ON_READ and
        (aggregation == variables_lib.VariableAggregation.SUM or
         (not distribution.extended._use_merge_call()) or
         (isinstance(distribution.extended,
                     collective_all_reduce_strategy.CollectiveAllReduceExtended)
          and aggregation == variables_lib.VariableAggregation.MEAN))):
      _assert_unsaveable(v.read_value)
      _assert_unsaveable(v.value)
      _assert_unsaveable(lambda: ops.convert_to_tensor(v))
    else:
      # Otherwise reading a variable should be saveable.

      @def_function.function
      def f():
        v.read_value()
        v.value()
        return ops.convert_to_tensor(v)

      with self.cached_session():
        save.save(v, export_dir, signatures=f.get_concrete_function())


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.tpu_strategy,
        ],
        mode=["eager"]))
class PackedDistributedVariableTest(test.TestCase, parameterized.TestCase):

  def testPackedVariable(self, distribution):
    with distribution.scope():
      v0 = variables_lib.Variable(0.)
    self.assertIsNone(v0._packed_var)

    distribution._enable_packed_variable_in_eager_mode = True
    with distribution.scope():
      v1 = variables_lib.Variable(0)
    self.assertIsInstance(v1._packed_var, packed.PackedDistributedVariable)

    devices = v1._devices
    for i in range(1, len(devices)):
      with distribute_lib.ReplicaContext(distribution, i):
        v1.assign(i)
    val = v1._get()
    self.assertIsInstance(val, packed.PackedVarAndDevice)
    self.assertEqual(val.device, devices[0])
    self.assertEqual(self.evaluate(val.read_value()), 0)
    for i in range(0, len(devices)):
      with distribute_lib.ReplicaContext(distribution, i):
        val = v1._get()
        self.assertIsInstance(val, packed.PackedVarAndDevice)
        self.assertEqual(val.device, devices[i])
        self.assertEqual(self.evaluate(val.read_value()), i)

  def testIgnorePackedVariableInSaveContext(self, distribution):
    distribution._enable_packed_variable_in_eager_mode = True
    with distribution.scope():
      v = variables_lib.Variable(0)
      self.assertIsInstance(
          v._packed_variable, packed.PackedDistributedVariable)

    options = save_options.SaveOptions()
    with save_context.save_context(options):
      self.assertIsNone(v._packed_variable)


class MirroredVariableTest(test.TestCase, parameterized.TestCase):

  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  @test_util.run_in_graph_and_eager_modes(config=config)
  def testProperties(self):
    if context.num_gpus() < 1 and context.executing_eagerly():
      self.skipTest("A GPU is not available for this test in eager mode.")

    mirrored = _make_mirrored()
    v = mirrored.values[0]
    self.assertEqual(v.name, mirrored.name)
    self.assertEqual(v.dtype, mirrored.dtype)
    self.assertEqual(v.shape, mirrored.shape)

  @test_util.run_in_graph_and_eager_modes(config=config)
  def testVariableOnAnotherDevice(self):
    v = variable_scope.get_variable(
        name="v", initializer=[1.], use_resource=True)
    mirrored = values_lib.MirroredVariable(
        None, (v,), variable_scope.VariableAggregation.MEAN)

    self.assertEqual(v.name, mirrored.name)
    self.assertEqual(v.dtype, mirrored.dtype)
    self.assertEqual(v.shape, mirrored.shape)


class MirroredVariableSaveRestoreTest(test.TestCase, parameterized.TestCase):

  def _assign_mirrored(self, v, new):
    for var, n in zip(v.values, new):
      self.evaluate(var.assign(n))

  def _save_return_saver(self, sess, var):
    saver = saver_lib.Saver(var_list=[var])
    test_dir = self.get_temp_dir()
    prefix = os.path.join(test_dir, "ckpt")
    return saver.save(sess, prefix), saver

  def _save(self, sess, var):
    save_path, _ = self._save_return_saver(sess, var)
    return save_path

  def _save_mirrored(self, distribution):
    """Save variables with mirroring, returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      mirrored = _make_mirrored(distribution)

      # Overwrite the initial values.
      self._assign_mirrored(mirrored, [3., 4.])

      # Saves the current value of v[0], 3.
      save_path = self._save(sess, mirrored)

      # Change the values between save and restore.
      self._assign_mirrored(mirrored, [5., 6.])
      return save_path

  def _save_normal(self):
    """Save variables without mirroring, returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      var = variable_scope.get_variable(
          name="v", initializer=1., use_resource=True)

      # Overwrite the initial value.
      self.evaluate(var.assign(3.))

      # Saves the current value of var, 3.
      save_path = self._save(sess, var)

      # Change the values between save and restore.
      self.evaluate(var.assign(5.))
      return save_path

  def _restore_normal(self, save_path):
    """Restore to variables without mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      var = variable_scope.get_variable(
          name="v", initializer=7., use_resource=True)

      # Overwrite the initial value.
      self.evaluate(var.assign(8.))

      # Restores the saved value of 3. to `var`.
      saver = saver_lib.Saver(var_list=[var])
      saver.restore(sess, save_path)
      self.assertEqual(3., self.evaluate(var))

  def _restore_mirrored(self, save_path, distribution):
    """Restore to variables with mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      mirrored = _make_mirrored(distribution)
      v = mirrored.values

      # Overwrite the initial values.
      self._assign_mirrored(mirrored, [7., 8.])

      # Restores the saved value of 3. to both variables.
      saver = saver_lib.Saver(var_list=[mirrored])
      saver.restore(sess, save_path)
      self.assertEqual([3., 3.], self.evaluate([v[0], v[1]]))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveAndRestoreMirroredOneGraph(self, distribution):
    with self.cached_session() as sess:
      mirrored = _make_mirrored(distribution)
      v = mirrored.values

      # Overwrite the initial values.
      self._assign_mirrored(mirrored, [3., 4.])

      # Saves the current value of v[0], 3.
      save_path, saver = self._save_return_saver(sess, mirrored)

      # Change the values between save and restore.
      self._assign_mirrored(mirrored, [5., 6.])

      # Restores the saved value of 3. to both variables.
      saver.restore(sess, save_path)
      self.assertEqual([3., 3.], self.evaluate([v[0], v[1]]))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveMirroredRestoreMirrored(self, distribution):
    if context.num_gpus() < 1 and context.executing_eagerly():
      # Graph mode can work without GPU because the Placer "moves" the
      # variable to a CPU. In other words, if there is no GPU available, but
      # the user requested to create a variable on GPU, Placer will ignore the
      # user request and assign the VarHandleOp to CPU. This requires
      # soft_placement, which is on by default.
      self.skipTest("A GPU is not available for this test in eager mode.")

    save_path = self._save_mirrored(distribution)
    self._restore_mirrored(save_path, distribution)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveMirroredRestoreNormal(self, distribution):
    if context.num_gpus() < 1 and context.executing_eagerly():
      # Graph mode can work without GPU because the Placer "moves" the
      # variable to a CPU. In other words, if there is no GPU available, but
      # the user requested to create a variable on GPU, Placer will ignore the
      # user request and assign the VarHandleOp to CPU. This requires
      # soft_placement, which is on by default.
      self.skipTest("A GPU is not available for this test in eager mode.")

    save_path = self._save_mirrored(distribution)
    self._restore_normal(save_path)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveNormalRestoreMirrored(self, distribution):
    if context.num_gpus() < 1 and context.executing_eagerly():
      # Graph mode can work without GPU because the Placer "moves" the
      # variable to a CPU. In other words, if there is no GPU available, but
      # the user requested to create a variable on GPU, Placer will ignore the
      # user request and assign the VarHandleOp to CPU. This requires
      # soft_placement, which is on by default.
      self.skipTest("A GPU is not available for this test in eager mode.")

    save_path = self._save_normal()
    self._restore_mirrored(save_path, distribution)


_TPU_STRATEGIES = (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)


def _make_replica_local(method, strategy=None):
  if strategy is None:
    devices = ("/device:GPU:0", "/device:CPU:0")
  else:
    devices = strategy.extended.worker_devices

  v = []
  for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
    with ops.device(d):
      v.append(variable_scope.get_variable(
          name=n, initializer=init, use_resource=True))

  if (strategy is not None) and isinstance(strategy, _TPU_STRATEGIES):
    var_cls = tpu_values.TPUSyncOnReadVariable
  else:
    var_cls = values_lib.SyncOnReadVariable
  replica_local = var_cls(strategy, v, method)
  return v, replica_local


class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):

  def _assign_replica_local(self, v, new):
    for var, n in zip(v, new):
      with ops.device(var.device):
        self.evaluate(var.assign(n))

  def _save_return_saver(self, sess, var):
    saver = saver_lib.Saver(var_list=[var])
    test_dir = self.get_temp_dir()
    prefix = os.path.join(test_dir, "ckpt")
    return saver.save(sess, prefix), saver

  def _save(self, sess, var):
    save_path, _ = self._save_return_saver(sess, var)
    return save_path

  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  @test_util.run_in_graph_and_eager_modes(config=config)
  def testProperties(self):
    if context.num_gpus() < 1 and context.executing_eagerly():
      self.skipTest("A GPU is not available for this test in eager mode.")
    v, replica_local = _make_replica_local(
        variable_scope.VariableAggregation.SUM)

    self.assertEqual(v[0].constraint, replica_local.constraint)
    self.assertEqual(v[0].name, replica_local.name)
    self.assertEqual(v[0].dtype, replica_local.dtype)
    self.assertEqual(v[0].shape, replica_local.shape)
    self.assertEqual(variable_scope.VariableAggregation.SUM,
                     replica_local.aggregation)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu
          ],
          mode=["eager"]))
  def testCanPassToDefFun(self, distribution):

    @def_function.function
    def add1(x):
      return x + 1.

    with distribution.scope():
      v = variables_lib.Variable(
          1.,
          aggregation=variables_lib.VariableAggregation.MEAN,
          synchronization=variables_lib.VariableSynchronization.ON_READ)

    self.assertEqual(2., self.evaluate(add1(v)))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testTensorConversion(self, distribution):
    with context.graph_mode():
      _, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.SUM, distribution)
      converted = ops.convert_to_tensor(replica_local, as_ref=False)
      self.assertIsInstance(converted, ops.Tensor)
      self.assertEqual(converted.dtype, replica_local.dtype)

      converted = ops.convert_to_tensor(replica_local, as_ref=True)
      # Resource variables are converted to tensors as well when as_ref is
      # True.
      self.assertIsInstance(converted, ops.Tensor)
      self.assertEqual(converted.dtype, replica_local.dtype)

  @combinations.generate(combinations.combine(
      distribution=[
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call,
          strategy_combinations.tpu_strategy,
          strategy_combinations.tpu_strategy_packed_var,
      ], mode=["eager"]))
  def testValueInCrossReplicaContext(self, distribution):
    value_list, replica_local = _make_replica_local(
        variable_scope.VariableAggregation.ONLY_FIRST_REPLICA, distribution)

    self.assertIsInstance(replica_local.value(), ops.Tensor)
    self.assertEqual(self.evaluate(replica_local.value()),
                     self.evaluate(value_list[0].value()))

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy_packed_var,
          ],
          mode=["eager"]))
  def testValueInDefaultReplicaContext(self, distribution):
    with distribution.scope():
      v1 = variables_lib.Variable(
          0.0,
          aggregation=variables_lib.VariableAggregation.SUM,
          synchronization=variables_lib.VariableSynchronization.ON_READ)
      v2 = variables_lib.Variable(
          0.0,
          aggregation=variables_lib.VariableAggregation.SUM,
          synchronization=variables_lib.VariableSynchronization.ON_READ)

    @def_function.function
    def replica_fn():
      v1.assign_add(1.0)
      v2.assign_add(2.0)

    distribution.run(replica_fn)
    sum_v = v1 + v2
    self.assertEqual(sum_v, 6.0)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveAndRestoreReplicaLocalSumOneGraph(self, distribution):
    with self.cached_session() as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.SUM, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [3., 4.])

      with distribution.scope():
        # Saves the current value of v[0] + v[1], 7.
        save_path, saver = self._save_return_saver(sess, replica_local)

        # Change the values between save and restore.
        self._assign_replica_local(v, [5., 6.])

        # Restores the saved value of 7. which gets divided equally
        # between the variables.
        saver.restore(sess, save_path)
        self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveAndRestoreReplicaLocalMeanOneGraph(self, distribution):
    if context.num_gpus() < 1 and context.executing_eagerly():
      self.skipTest("A GPU is not available for this test in eager mode.")

    with self.cached_session() as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.MEAN, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [3., 4.])

      with distribution.scope():
        # Saves the current value of (v[0] + v[1])/2, 3.5.
        save_path, saver = self._save_return_saver(sess, replica_local)

        # Change the values between save and restore.
        self._assign_replica_local(v, [5., 6.])

        # Restores the saved value of 3.5 to both variables.
        saver.restore(sess, save_path)
        self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))

  def _save_replica_local_mean(self, distribution):
    """Save variables with mirroring, returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.MEAN, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [3., 4.])

      with distribution.scope():
        # Saves the current value of (v[0] + v[1])/2, 3.5
        save_path = self._save(sess, replica_local)

        # Change the values between save and restore.
        self._assign_replica_local(v, [5., 6.])
      return save_path

  def _save_replica_local_sum(self, distribution):
    """Save variables with mirroring, returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.SUM, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [1.5, 2.])

      with distribution.scope():
        # Saves the current value of v[0] + v[1], 3.5
        save_path = self._save(sess, replica_local)

        # Change the values between save and restore.
        self._assign_replica_local(v, [5., 6.])
      return save_path

  def _save_normal(self):
    """Save variables without mirroring, returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      var = variable_scope.get_variable(
          name="v", initializer=1., use_resource=True)

      # Overwrite the initial value.
      self.evaluate(var.assign(3.5))

      # Saves the current value of var, 3.5.
      save_path = self._save(sess, var)

      # Change the values between save and restore.
      self.evaluate(var.assign(5.))
      return save_path

  def _restore_normal(self, save_path):
    """Restore to variables without mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      var = variable_scope.get_variable(
          name="v", initializer=7., use_resource=True)

      # Overwrite the initial value.
      self.evaluate(var.assign(8.))

      # Restores the saved value of 3.5 to `var`.
      saver = saver_lib.Saver(var_list=[var])
      saver.restore(sess, save_path)
      self.assertEqual(3.5, self.evaluate(var))

  def _restore_replica_local_mean(self, save_path, distribution):
    """Restore to variables with mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.MEAN, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [7., 8.])

      with distribution.scope():
        # Restores the saved value of 3.5 to both variables.
        saver = saver_lib.Saver(var_list=[replica_local])
        saver.restore(sess, save_path)
        self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))

  def _restore_replica_local_sum(self, save_path, distribution):
    """Restore to variables with mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.SUM, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [7., 8.])

      with distribution.scope():
        # Restores the saved value of 3.5, divided equally between the
        # variables.
        saver = saver_lib.Saver(var_list=[replica_local])
        saver.restore(sess, save_path)
        self.assertEqual([1.75, 1.75], self.evaluate([v[0], v[1]]))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveReplicaLocalRestoreReplicaLocalMean(self, distribution):
    save_path = self._save_replica_local_mean(distribution)
    self._restore_replica_local_mean(save_path, distribution)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveReplicaLocalRestoreReplicaLocalSum(self, distribution):
    save_path = self._save_replica_local_sum(distribution)
    self._restore_replica_local_sum(save_path, distribution)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveReplicaLocalMeanRestoreNormal(self, distribution):
    save_path = self._save_replica_local_mean(distribution)
    self._restore_normal(save_path)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveReplicaLocalSumRestoreNormal(self, distribution):
    save_path = self._save_replica_local_sum(distribution)
    self._restore_normal(save_path)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveNormalRestoreReplicaLocalMean(self, distribution):
    save_path = self._save_normal()
    self._restore_replica_local_mean(save_path, distribution)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveNormalRestoreReplicaLocalSum(self, distribution):
    save_path = self._save_normal()
    self._restore_replica_local_sum(save_path, distribution)


class MirroredTest(test.TestCase):

  def testAddOp(self):
    if context.num_gpus() < 1:
      self.skipTest("A GPU is not available for this test.")
    mirrored_val = _make_mirrored_val(init_val=3.)

    self.assertEqual(self.evaluate(constant_op.constant(6.)),
                     self.evaluate(mirrored_val + mirrored_val))
    self.assertEqual(self.evaluate(constant_op.constant(4.)),
                     self.evaluate(mirrored_val + 1))
    self.assertEqual(self.evaluate(mirrored_val + 1),
                     self.evaluate(math_ops.add(mirrored_val, 1)))
    self.assertEqual(type(mirrored_val + 1),
                     type(math_ops.add(mirrored_val, 1)))


class PerReplicaTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testTypeSpec(self):
    vals = (constant_op.constant(1.),)
    per_replica = values_lib.PerReplica(vals)

    spec = per_replica._type_spec
    self.assertEqual(spec._value_specs,
                     (tensor_spec.TensorSpec([], dtypes.float32),))

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testTypeSpecRoundTrip(self):
    vals = (constant_op.constant(1.),)
    per_replica = values_lib.PerReplica(vals)

    spec = per_replica._type_spec
    tensor_list = spec._to_components(per_replica)
    reconstructed = spec._from_components(tensor_list)

    self.assertAllEqual(per_replica.values, reconstructed.values)

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testTypeSpecNest(self):
    vals = (constant_op.constant(1.), constant_op.constant([5., 6.0]),)
    per_replica = values_lib.PerReplica(vals)

    # Note: nest.map_structure exercises nest.flatten and
    # nest.pack_sequence_as.
    result = nest.map_structure(
        lambda t: t + 10, per_replica, expand_composites=True)

    self.assertLen(result.values, 2)
    self.assertAllEqual(result.values[0], 11.)
    self.assertAllEqual(result.values[1], [15., 16.0])

  @test_util.run_in_graph_and_eager_modes
  def testIsGraphTensor(self):
    per_replica = values_lib.PerReplica((constant_op.constant(1.),))
    for t in nest.flatten(per_replica, expand_composites=True):
      self.assertEqual(hasattr(t, "graph"), not context.executing_eagerly())

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testDoesNotTriggerFunctionTracing(self):
    traces = []

    @def_function.function
    def f(x):
      traces.append(None)  # Only happens on trace.
      return x

    per_replica = values_lib.PerReplica((constant_op.constant(1.),))

    # Trace once.
    f(per_replica)
    self.assertNotEmpty(traces)
    del traces[:]

    per_replica_spec = per_replica._type_spec
    for _ in range(5):
      vals = per_replica_spec._to_components(per_replica)
      vals = [v * 2 for v in vals]
      per_replica = per_replica_spec._from_components(vals)

      output = f(per_replica)
      self.assertIsInstance(output, values_lib.PerReplica)
      self.assertAllEqual(output._values, per_replica._values)
      self.assertEmpty(traces)  # Make sure we're not re-tracing `f`.

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testFunctionCanReturnPerReplica(self):
    f = def_function.function(lambda x: x)
    x = values_lib.PerReplica((constant_op.constant(1.),))
    y = f(x)
    self.assertIsNot(x, y)
    nest.map_structure(self.assertAllEqual, x, y, expand_composites=True)
    self.assertEqual(x._type_spec, y._type_spec)

  @test_util.run_in_graph_and_eager_modes
  def testCondWithTensorValues(self):
    per_replica_1 = values_lib.PerReplica((constant_op.constant("a"),))
    per_replica_2 = values_lib.PerReplica((constant_op.constant(["b", "c"]),))
    condition = array_ops.placeholder_with_default(True, [])

    result = control_flow_ops.cond(
        condition, lambda: per_replica_1, lambda: per_replica_2)

    self.assertLen(result.values, 1)
    self.assertAllEqual(result.values[0], "a")

  @test_util.run_in_graph_and_eager_modes
  def testCondWithValuesConvertibleToTensor(self):
    per_replica_1 = values_lib.PerReplica(("a",))
    per_replica_2 = values_lib.PerReplica(("b",))
    condition = array_ops.placeholder_with_default(True, [])

    result = control_flow_ops.cond(
        condition, lambda: per_replica_1, lambda: per_replica_2)

    self.assertLen(result.values, 1)
    self.assertAllEqual(result.values[0], "a")

  @test_util.build_as_function_and_v1_graph
  def testCondWithValuesNotConvertibleToTensor(self):
    per_replica_1 = values_lib.PerReplica(({"a"},))
    per_replica_2 = values_lib.PerReplica(({"b", "c"},))
    condition = array_ops.placeholder(dtypes.bool, [])

    with self.assertRaisesRegex(TypeError, "Could not build a TypeSpec for"):
      control_flow_ops.cond(
          condition, lambda: per_replica_1, lambda: per_replica_2)


def _make_index_slices(values, indices, dense_shape=None):
  if dense_shape:
    dense_shape = array_ops.identity(dense_shape)
  return indexed_slices.IndexedSlices(
      array_ops.identity(values), array_ops.identity(indices), dense_shape)


if __name__ == "__main__":
  ds_test_util.main()