# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the distributed values library."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import os

from absl.testing import parameterized
import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util as ds_test_util
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import save_context
from tensorflow.python.saved_model import save_options
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.types import core
from tensorflow.python.util import nest


def _device_str(d):
  return "/device:GPU:" + str(d)


def _nested_value(d):
  return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d)


def _make_mirrored_val(init_val=5.0):
  v = []
  devices = ["/device:GPU:0", "/device:CPU:0"]
  for d, _ in zip(devices, ["v", "v/replica"]):
    with ops.device(d):
      v.append(constant_op.constant(init_val))
  return values_lib.Mirrored(v)


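# The helper below builds one component variable per device and wraps them in
# a MirroredVariable (or TPUMirroredVariable when a TPU strategy is passed),
# mimicking the per-replica layout that the strategies under test create.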
def _make_mirrored(distribution=None):
  v = []
  if distribution:
    devices = distribution.extended.worker_devices
  else:
    devices = ["/device:GPU:0", "/device:CPU:0"]
  for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
    with ops.device(d):
      v.append(
          variable_scope.get_variable(
              name=n, initializer=init, use_resource=True))

  if (distribution is not None) and isinstance(distribution, _TPU_STRATEGIES):
    var_cls = tpu_values.TPUMirroredVariable
  else:
    var_cls = values_lib.MirroredVariable
  mirrored = var_cls(distribution, v, variable_scope.VariableAggregation.SUM)
  return mirrored


def mirrored_and_tpu_strategy_combinations():
  return combinations.combine(
      distribution=[
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations.tpu_strategy,
          strategy_combinations.tpu_strategy_packed_var,
      ],
      mode=["graph", "eager"])


class DistributedValuesTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueFromTensor(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    single_value = constant_op.constant(1)
    def value_fn(ctx):
      del ctx
      return single_value

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    self.assertAllEqual(
        ds_test_util.gather(distribution, distributed_values),
        constant_op.constant(1., shape=(distribution.num_replicas_in_sync)))

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueSingleNumpyArrayConstant(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    array_value = np.array([1., 2., 3.])
    def value_fn(ctx):
      del ctx
      return array_value

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    self.assertAllEqual(
        ds_test_util.gather(distribution, distributed_values).numpy(),
        [[1., 2., 3.]] * distribution.num_replicas_in_sync)

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueTupleConstant(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    tuple_value = (1., 2., 3.)
    def value_fn(ctx):
      del ctx
      return tuple_value
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    distributed_values = ds_test_util.gather(distribution, distributed_values)

    # Expected output for 2 replicas:
    # ([1.0, 1.0], [2.0, 2.0], [3.0, 3.0])
    expected = tuple([v for i in range(distribution.num_replicas_in_sync)]
                     for v in tuple_value)
    self.assertAllEqual(distributed_values, expected)

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueNestedStructurePerReplica(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    tuple_value = (1., 2., 3.)
    def value_fn(ctx):
      per_replica = []
      for val in tuple_value:
        per_replica.append(val * ctx.replica_id_in_sync_group)
      return tuple(per_replica)
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    distributed_values = ds_test_util.gather(distribution, distributed_values)

    # Expected output for 2 replicas:
    # ([0.0, 1.0], [0.0, 2.0], [0.0, 3.0])
    expected = tuple([v * i for i in range(distribution.num_replicas_in_sync)]
                     for v in tuple_value)
    self.assertAllEqual(distributed_values, expected)

  # NOTE(priyag): Cannot test this with MultiWorkerMirroredStrategy because
  # collective ops do not support SparseTensors.
  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies_minus_default,
          mode=["eager"]
      ))
  def testMakeDistributedValueSparseTensor(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    def value_fn(ctx):
      del ctx
      return sparse_tensor.SparseTensor(
          indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    local_results = distribution.experimental_local_results(distributed_values)
    for i in range(distribution.num_replicas_in_sync):
      self.assertAllEqual(
          sparse_ops.sparse_tensor_to_dense(local_results[i]),
          [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]])

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueExtractFromArray(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    multiple_values = range(distribution.num_replicas_in_sync)
    def value_fn(ctx):
      return multiple_values[ctx.replica_id_in_sync_group]
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    distributed_values = ds_test_util.gather(distribution, distributed_values)
    expected = range(distribution.num_replicas_in_sync)
    self.assertAllEqual(distributed_values, expected)

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueAndRun(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    @def_function.function
    def run():
      multiple_values = range(distribution.num_replicas_in_sync)
      def value_fn(ctx):
        return multiple_values[ctx.replica_id_in_sync_group]
      distributed_values = (
          distribution.experimental_distribute_values_from_function(value_fn))

      def computation(x):
        return math_ops.square(x)

      outputs = ds_test_util.gather(
          distribution,
          distribution.run(computation, args=(distributed_values,)))
      return outputs

    results = run()

    expected = [i**2 for i in range(distribution.num_replicas_in_sync)]
    self.assertAllEqual(results, expected)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ] + strategy_combinations.multiworker_strategies,
          mode=["eager"]))
  def testMakeDistributedValueDefaultDevicePlacement(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    def value_fn(ctx):
      del ctx
      return constant_op.constant(1.0)
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    for i in range(len(distribution.extended.worker_devices)):
      self.assertAllEqual(distributed_values._values[i].device,
                          "/job:localhost/replica:0/task:0/device:CPU:0")

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ] + strategy_combinations.multiworker_strategies,
          mode=["eager"]))
  def testMakeDistributedValueExplicitDevicePlacement(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    worker_devices = distribution.extended.worker_devices
    def value_fn(ctx):
      # In a multi-client setup, worker_devices is just the devices on this
      # worker.
      worker_device_id = ctx.replica_id_in_sync_group % len(worker_devices)
      with ops.device(worker_devices[worker_device_id]):
        return array_ops.identity(1.0)
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    for i in range(len(distribution.extended.worker_devices)):
      self.assertAllEqual(distributed_values._values[i].device,
                          worker_devices[i])


class DistributedDelegateTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def testGetAttr(self):
    class Foo(object):

      def __init__(self, x):
        self.x = x

    v = values_lib.DistributedDelegate((Foo(7), Foo(8)))
    self.assertEqual(7, v.x)
    with self.assertRaises(AttributeError):
      _ = v.y

  @test_util.run_in_graph_and_eager_modes
  def testOperatorOverride(self):
    v = values_lib.DistributedDelegate((7, 8))
    # v should act like int(7).
    self.assertEqual(8, v + 1)
    self.assertEqual(10, 3 + v)
    self.assertEqual(14, v + v)
    self.assertEqual(5, v - 2)
    self.assertEqual(6, 13 - v)
    self.assertEqual(0, v - v)
    self.assertEqual(14, v * 2)
    self.assertEqual(21, 3 * v)
    self.assertEqual(49, v * v)
    self.assertEqual(3.5, v / 2)
    self.assertEqual(1.5, 10.5 / v)
    self.assertEqual(3, v // 2)
    self.assertEqual(2, 15 // v)
    self.assertEqual(1, v % 2)
    self.assertEqual(2, 16 % v)
    # pylint: disable=g-generic-assert
    self.assertTrue(v < 12)
    self.assertTrue(v <= 12)
    self.assertFalse(v > 12)
    self.assertFalse(v >= 12)
    self.assertFalse(12 < v)
    self.assertFalse(12 <= v)
    self.assertTrue(12 > v)
    self.assertTrue(12 >= v)
    # pylint: enable=g-generic-assert
    self.assertEqual(3, v & 3)
    self.assertEqual(3, 11 & v)
    self.assertEqual(15, v | 8)
    self.assertEqual(23, 16 | v)
    self.assertEqual(4, v ^ 3)
    self.assertEqual(12, 11 ^ v)
    self.assertEqual(343, pow(v, 3))
    self.assertEqual(3, pow(v, 3, 10))
    self.assertEqual(128, pow(2, v))
    self.assertEqual(-7, -v)
    self.assertEqual(~7, ~v)
    self.assertEqual(7, abs(v))
    with self.assertRaises(TypeError):
      _ = v[2]

  @test_util.run_in_graph_and_eager_modes
  def testCopy(self):

    class Foo(object):

      def __init__(self, x):
        self.x = x

    v = values_lib.DistributedDelegate((Foo(7), Foo(8)))
    v_shallow_copy = copy.copy(v)
    self.assertEqual(v.x, v_shallow_copy.x)
    v_deep_copy = copy.deepcopy(v)
    self.assertEqual(v.x, v_deep_copy.x)


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.tpu_strategy,
            strategy_combinations.tpu_strategy_packed_var,
            strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu
        ],
        synchronization=[
            variables_lib.VariableSynchronization.ON_READ,
            variables_lib.VariableSynchronization.ON_WRITE,
        ],
        aggregation=[
            variables_lib.VariableAggregation.MEAN,
            variables_lib.VariableAggregation.SUM,
            variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
        ],
        mode=["graph", "eager"],
        use_var_policy=[True, False]))
class DistributedVariableTest(test.TestCase, parameterized.TestCase):

  def testExtendsVariable(self, distribution, synchronization, aggregation):
    with distribution.scope():
      v = variables_lib.Variable(
          1., synchronization=synchronization, aggregation=aggregation)
    self.assertIsInstance(v, variables_lib.Variable)

  def testCheckpointing(self, distribution, synchronization, aggregation,
                        mode):
    if (isinstance(distribution,
                   collective_all_reduce_strategy.CollectiveAllReduceStrategy)
        and mode == "graph"):
      self.skipTest("MWMS combinations tests do not work well in graph mode.")

    with distribution.scope():
      v = variables_lib.Variable(
          constant_op.constant([1., 2., 3., 4]),
          synchronization=synchronization,
          aggregation=aggregation)

    self.evaluate(v.initializer)
    before_save = self.evaluate(v.read_value())

    # Save random weights into checkpoint.
    checkpoint = trackable_utils.Checkpoint(v=v)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    with self.test_session():
      save_path = checkpoint.save(prefix)

    # Assign inverted value.
    self.evaluate(v.assign(constant_op.constant([4., 3., 2., 1.])))
    after_assign = self.evaluate(v.read_value())
    self.assertNotAllClose(before_save, after_assign)

    # Restore from the checkpoint.
    with self.test_session():
      checkpoint.restore(save_path).assert_consumed().run_restore_ops()
    after_restore = self.evaluate(v)
    self.assertAllClose(before_save, after_restore)

  def testTraceback(self, distribution, synchronization, aggregation):
    if context.executing_eagerly():
      self.skipTest("does not apply to eager")
    with distribution.scope():
      variable_scope.get_variable(
          name="testVar",
          initializer=1.,
          use_resource=True,
          synchronization=synchronization,
          aggregation=aggregation)
      with self.assertRaisesRegex(ValueError,
                                  "Variable testVar already exists"):
        variable_scope.get_variable(
            name="testVar",
            initializer=1.,
            use_resource=True,
            synchronization=synchronization,
            aggregation=aggregation)

  def testSelectReplica(self, distribution, synchronization, aggregation):
    with distribution.scope():
      v = variables_lib.Variable(
          1., synchronization=synchronization, aggregation=aggregation)
    self.assertIs(v, distribute_utils.select_replica(0, v))

  def testIsTensorLike(self, distribution, synchronization, aggregation):
    if isinstance(distribution.extended,
                  tpu_strategy.TPUExtended) and context.executing_eagerly():
      self.skipTest("TPU doesn't support pure eager")

    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)
      # In cross replica context.
      self.assertIsInstance(v, core.Tensor)
      # In replica context.
      distribution.run(
          lambda v: self.assertIsInstance(v, core.Tensor), args=(v,))

  def testAssignReturnValueIsTensorLike(self, distribution, synchronization,
                                        aggregation):
    if isinstance(distribution.extended, tpu_strategy.TPUExtended):
      if context.executing_eagerly():
        self.skipTest("TPU doesn't support pure eager")
      else:
        self.skipTest("b/152076846")

    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)

    def assert_is_tensor_like(v):
      # We can't use Python literals because they are treated as
      # non-distributed values, which are not allowed when aggregation is SUM.
      # See `cross_device_ops.reduce_non_distributed_value`.
      delta = array_ops.identity(1.)
      self.assertIsInstance(v.assign(delta), core.Tensor)
      self.assertIsInstance(v.assign_sub(delta), core.Tensor)
      self.assertIsInstance(v.assign_add(delta), core.Tensor)

    # In cross-replica context we return a PerReplica, which is not always
    # tensor-like yet.
    if (synchronization == variables_lib.VariableSynchronization.ON_READ and
        aggregation != variables_lib.VariableAggregation.SUM):
      assert_is_tensor_like(v)

    # In replica context.
    distribution.run(assert_is_tensor_like, args=(v,))

  def testDeepCopy(self, distribution, synchronization, aggregation):
    if not context.executing_eagerly():
      self.skipTest("deepcopy only supported in eager mode")

    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)
      in_dist_copy = copy.deepcopy(v)

    out_dist_copy = copy.deepcopy(v)

    def assert_is_deep_copy(v1, v2):
      self.assertIsInstance(v2, type(v1))
      self.assertEqual(v1.aggregation, v2.aggregation)
      self.assertEqual(v1.distribute_strategy, v2.distribute_strategy)
      if isinstance(v1, ps_values.AggregatingVariable):
        self.assertIsInstance(v2.get(), type(v1.get()))
        self.assertNotEqual(id(v1.get()), id(v2.get()))
      else:
        if v1._policy:
          self.assertNotEqual(id(v1._policy), id(v2._policy))  # pylint: disable=protected-access
        else:
          self.assertEqual(id(v1._policy), id(v2._policy))  # pylint: disable=protected-access
        self.assertEqual(len(v1.values), len(v2.values))
        for (v1v, v2v) in zip(v1.values, v2.values):
          self.assertEqual(v1v.device, v2v.device)
          self.assertNotEqual(id(v1v), id(v2v))
          self.assertAllEqual(self.evaluate(v1.values),
                              self.evaluate(v2.values))

    self.evaluate(variables_lib.global_variables_initializer())
    if not isinstance(distribution.extended, tpu_strategy.TPUExtended):
      distribution.run(assert_is_deep_copy, args=(v, in_dist_copy))
      distribution.run(assert_is_deep_copy, args=(v, out_dist_copy))

  def testAssignSignature(self, distribution, synchronization, aggregation):
    # This test verifies assign*() can be called in the same way as normal
    # variables.
    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)

      def assign():
        one = constant_op.constant(1.)
        v.assign(one, True, "assign", False)
        # TODO(b/154017756): SyncOnReadVariable.assign() doesn't support
        # passing value as a keyword argument.
        v.assign(one, use_locking=True, name="assign", read_value=False)
        v.assign_add(one, True, "assign", False)
        v.assign_add(one, use_locking=True, name="assign", read_value=False)
        v.assign_sub(one, True, "assign", False)
        v.assign_sub(one, use_locking=True, name="assign", read_value=False)
        # Return something for graph mode to fetch.
        return constant_op.constant(1)

      self.evaluate(variables_lib.global_variables_initializer())
      if not (synchronization == variables_lib.VariableSynchronization.ON_READ
              and aggregation == variables_lib.VariableAggregation.SUM):
        self.evaluate(distribution.experimental_local_results(assign()))
      if not (isinstance(distribution.extended, tpu_strategy.TPUExtended) and
              context.executing_eagerly()):
        self.evaluate(
            distribution.experimental_local_results(distribution.run(assign)))

  def testStrategyExtendedUpdate(self, distribution, synchronization,
                                 aggregation):
    if len(distribution.extended.parameter_devices) != 2:
      self.skipTest("n/a: needs exactly two parameter devices")
    if (synchronization == variables_lib.VariableSynchronization.ON_WRITE and
        aggregation != variables_lib.VariableAggregation.NONE):
      self.skipTest("n/a: doesn't apply to ON_WRITE variable with aggregation")
    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)
    value = values_lib.PerReplica([1., 2.])

    assign_fn = lambda var, value: var.assign(value)
    self.evaluate(distribution.extended.update(v, assign_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [1., 2.])

    assign_add_fn = lambda var, value: var.assign_add(value)
    self.evaluate(distribution.extended.update(v, assign_add_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [2., 4.])

    assign_sub_fn = lambda var, value: var.assign_sub(value)
    self.evaluate(distribution.extended.update(v, assign_sub_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [1., 2.])

    read_assign_fn = lambda var, value: var.assign_add(var.value() +
                                                       var.read_value())
    self.evaluate(
        distribution.extended.update(v, read_assign_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [3., 6.])

  def testSaveNonDistributed(self, distribution, synchronization, aggregation):
    # This test verifies that DistributedVariable behaves like the primary
    # variable when saving a non-distributed version of the model (the
    # default). It asserts that the function traced under SaveContext has no
    # device annotations and only references the primary component of the
    # variable. Note: avoid capturing other eager tensors in this test to keep
    # the assertion simple.

    if isinstance(distribution.extended,
                  parameter_server_strategy.ParameterServerStrategyExtended):
      self.skipTest("b/148689177: AggregatingVariable doesn't "
                    "conform to Variable interface well")

    # tf.function requires the return value to be Tensors, which is not
    # always the case for properties and methods of Variable, so we simply
    # discard the return values.
    def _discard_return(f):
      f()
      return

    def _test(f, v):
      # This verifies that the function under SaveContext:
      #   - contains no device annotations.
      #   - only references the primary component of the variable.
      g = def_function.function(lambda: _discard_return(f))
      options = save_options.SaveOptions(
          experimental_variable_policy=save_options.VariablePolicy.NONE)
      with save_context.save_context(options):
        # The graph should contain no device.
        graph = g.get_concrete_function().graph
        for op in graph.get_operations():
          self.assertEqual(op.device, "", msg=str(op))
        # The function should only capture the primary variable. Note that it
        # may not have captures, e.g. v.aggregation.
        captures = list(graph.captures)
        self.assertLessEqual(len(captures), 1)
        if graph.captures:
          self.assertIs(captures[0][0], v._primary.handle)

    def _assert(cond):
      return control_flow_ops.Assert(cond, [cond])

    with distribution.scope():
      # We use three variables for convenience. They have no special meaning.
      # - v is used whenever possible.
      # - w is used for scatter and gather, which require the variable to be
      #   non-scalar.
      # - y is used when the dtype needs to be integer. Note that aggregation
      #   cannot be MEAN for integers.
      v = variables_lib.Variable(
          0.,
          synchronization=synchronization,
          aggregation=aggregation,
          trainable=True)
      w = variables_lib.Variable([0., 0., 0.],
                                 synchronization=synchronization,
                                 aggregation=aggregation,
                                 trainable=True)
      if aggregation != variables_lib.VariableAggregation.MEAN:
        y = variables_lib.Variable(
            0,
            synchronization=synchronization,
            aggregation=aggregation)

    # pylint: disable=g-long-lambda

    # tf.Variable properties.
    _test(lambda: self.assertEqual(v.aggregation, aggregation), v)
    _test(lambda: self.assertIs(v.constraint, None), v)
    # TODO(crccw): should we raise an error instead?
    _test(lambda: self.assertEqual(v.device, v._primary.device), v)
    _test(lambda: self.assertEqual(v.dtype, dtypes.float32), v)
    if not context.executing_eagerly():
      _test(lambda: self.assertIs(v.graph, v._primary.graph), v)
    if not context.executing_eagerly():
      _test(lambda: _assert(v.initial_value == 0), v)
    _test(lambda: self.assertIs(v.initializer, v._primary.initializer), v)
    _test(lambda: self.assertEqual(v.name, "Variable:0"), v)
    if not context.executing_eagerly():
      _test(lambda: self.assertIs(v.op, v._primary.op), v)
    _test(lambda: self.assertEqual(v.shape, tensor_shape.TensorShape(())), v)
    _test(lambda: self.assertEqual(v.synchronization, synchronization), v)
    _test(lambda: self.assertTrue(v.trainable, True), v)

    # tf.Variable methods.
    _test(lambda: check_ops.assert_equal_v2(v.assign(1.), 1.), v)
    _test(lambda: check_ops.assert_equal_v2(v.assign_add(1.), 2.), v)
    _test(lambda: check_ops.assert_equal_v2(v.assign_sub(1.), 1.), v)
    # TODO(b/148689177): Implement batch_scatter_update.
    # count_up_to() is skipped since it's deprecated.
    # eval() is skipped since it shouldn't be called in a tf.function.
    # experimental_ref() is skipped since it's deprecated.
    # from_proto() is skipped since it shouldn't be called in a tf.function.
    # TODO(b/148689177): Implement gather_nd.
    _test(
        lambda: check_ops.assert_equal_v2(v.get_shape(),
                                          tensor_shape.TensorShape(())), v)
    # initialized_value() is skipped since it shouldn't be called in a
    # tf.function.
    # load() is skipped since it shouldn't be called in a tf.function.
    _test(lambda: check_ops.assert_equal_v2(v.read_value(), 1.), v)
    # ref() is skipped since it shouldn't be called in a tf.function.
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_add(_make_index_slices(values=[1., 2.], indices=[0, 2])),
            [1., 0., 2.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_div(_make_index_slices(values=[4., 2.], indices=[0, 2])),
            [0.25, 0., 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_max(_make_index_slices(values=[1., 0.5], indices=[1, 2])),
            [0.25, 1., 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_min(_make_index_slices(values=[1., 0.5], indices=[0, 1])),
            [0.25, 0.5, 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_mul(_make_index_slices(values=[2., 0.5], indices=[0, 1])),
            [0.5, 0.25, 1.]), w)
    # TODO(b/148689177): Implement scatter_nd_*
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_sub(_make_index_slices(values=[2., 0.5], indices=[0, 1])),
            [-1.5, -0.25, 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_update(
                _make_index_slices(values=[2., 0.5], indices=[0, 1])),
            [2., 0.5, 1.]), w)
    # set_shape() is skipped since ResourceVariable doesn't implement it.
    # to_proto() is skipped since it shouldn't be called in a tf.function.
    _test(lambda: check_ops.assert_equal_v2(v.value(), 1.), v)

    # DistributedVariable should be treated as ResourceVariable, so it needs
    # to conform to the ResourceVariable interface as well.
    _test(lambda: self.assertIs(v.handle, v._primary.handle), v)

    # Convert to tensor.
    _test(lambda: check_ops.assert_equal_v2(ops.convert_to_tensor(v), 1.), v)

    # Control dependency.
    def _with_control_dep():
      with ops.control_dependencies([v.assign(1.)]):
        return array_ops.identity(1)

    _test(_with_control_dep, v)

    # Operator overloads.
    _test(lambda: check_ops.assert_equal_v2(v.assign(7.), 7.), v)
    _test(lambda: check_ops.assert_equal_v2(v + 1., 8.), v)
    _test(lambda: check_ops.assert_equal_v2(3 + v, 10.), v)
    _test(lambda: check_ops.assert_equal_v2(v + v, 14.), v)
    _test(lambda: check_ops.assert_equal_v2(v - 2., 5.), v)
    _test(lambda: check_ops.assert_equal_v2(v - v, 0.), v)
    _test(lambda: check_ops.assert_equal_v2(v * 2., 14.), v)
    _test(lambda: check_ops.assert_equal_v2(3 * v, 21.), v)
    _test(lambda: check_ops.assert_equal_v2(v * v, 49.), v)
    _test(
        lambda: check_ops.assert_equal_v2(
            math_ops.cast(v / 2., dtypes.float32), 3.5), v)
    _test(
        lambda: check_ops.assert_equal_v2(
            math_ops.cast(14. / v, dtypes.float32), 2.), v)
    _test(lambda: _assert(v < 12.), v)
    _test(lambda: _assert(v <= 12.), v)
    _test(lambda: _assert(not v > 12.), v)
    _test(lambda: _assert(not v >= 12.), v)
    _test(lambda: _assert(not 12. < v), v)
    _test(lambda: _assert(not 12. <= v), v)
    _test(lambda: _assert(12. > v), v)
    _test(lambda: _assert(12. >= v), v)
    _test(lambda: check_ops.assert_near_v2(pow(v, 3.), 343.), v)
    _test(lambda: check_ops.assert_near_v2(pow(2., v), 128.), v)
    _test(lambda: check_ops.assert_equal_v2(abs(v), 7.), v)

    # Operator overloads that only work for integers.
    if aggregation != variables_lib.VariableAggregation.MEAN:
      _test(lambda: check_ops.assert_equal_v2(y.assign(7), 7), y)
      _test(lambda: check_ops.assert_equal_v2(y // 2, 3), y)
      _test(lambda: check_ops.assert_equal_v2(15 // y, 2), y)
      _test(lambda: check_ops.assert_equal_v2(y % 2, 1), y)
      _test(lambda: check_ops.assert_equal_v2(16 % y, 2), y)
      _test(lambda: check_ops.assert_equal_v2(y & 3, 3), y)
      _test(lambda: check_ops.assert_equal_v2(3 & y, 3), y)
      _test(lambda: check_ops.assert_equal_v2(y | 8, 15), y)
      _test(lambda: check_ops.assert_equal_v2(16 | y, 23), y)
      _test(lambda: check_ops.assert_equal_v2(y ^ 3, 4), y)
      _test(lambda: check_ops.assert_equal_v2(11 ^ y, 12), y)
      _test(lambda: check_ops.assert_equal_v2(-y, -7), y)
      _test(lambda: check_ops.assert_equal_v2(~y, ~7), y)

    # Index.
    if isinstance(distribution.extended, tpu_strategy.TPUExtended):
      # TODO(b/161572567): slice assignment doesn't work for TPU.
      _test(lambda: check_ops.assert_equal_v2(w[0], 2.), w)
    else:
      _test(lambda: check_ops.assert_equal_v2(w[0].assign(1.), [1., 0.5, 1.]),
            w)
      _test(lambda: check_ops.assert_equal_v2(w[0], 1.), w)

    # pylint: enable=g-long-lambda

  def testUnsaveable(self, distribution, synchronization, aggregation, mode):
    if isinstance(distribution.extended,
                  parameter_server_strategy.ParameterServerStrategyExtended):
      self.skipTest("n/a: not applicable to AggregatingVariable")
    if (isinstance(distribution,
                   collective_all_reduce_strategy.CollectiveAllReduceStrategy)
        and mode == "graph"):
      self.skipTest("MWMS combinations tests do not work well in graph mode.")
    with distribution.scope():
      v = variables_lib.Variable([1., 1.],
                                 synchronization=synchronization,
                                 aggregation=aggregation)

    with self.cached_session():
      self.evaluate(variables_lib.global_variables_initializer())

    export_dir = self.get_temp_dir()

    def _assert_unsaveable(f):
      # Ignore if it cannot be traced. Certain combinations are not supported
      # yet or not allowed.
      try:
        f = def_function.function(f).get_concrete_function()
      except (NotImplementedError, ValueError):
        return
      with self.assertRaisesRegex(ValueError, "f_with_input_signature"):
        save.save(v, export_dir, signatures=f)

    _assert_unsaveable(lambda: v.assign(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.assign_add(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.assign_sub(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.scatter_add(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_sub(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_mul(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_div(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_min(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_max(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_update(_make_index_slices([1.], [0])))
    # Reading an ON_READ variable should be unsaveable if either:
    # 1) CollectiveAllReduceStrategy, and aggregation is MEAN/SUM.
    # 2) aggregation is SUM.
    if (synchronization == variables_lib.VariableSynchronization.ON_READ and
        (aggregation == variables_lib.VariableAggregation.SUM or
         (isinstance(distribution.extended,
                     collective_all_reduce_strategy.CollectiveAllReduceExtended)
          and aggregation == variables_lib.VariableAggregation.MEAN))):
      _assert_unsaveable(v.read_value)
      _assert_unsaveable(v.value)
      _assert_unsaveable(lambda: ops.convert_to_tensor(v))
    else:
      # Otherwise reading a variable should be saveable.

      @def_function.function
      def f():
        v.read_value()
        v.value()
        return ops.convert_to_tensor(v)

      with self.cached_session():
        save.save(v, export_dir, signatures=f.get_concrete_function())


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.tpu_strategy,
        ],
        mode=["eager"]))
class PackedDistributedVariableTest(test.TestCase, parameterized.TestCase):

  def testPackedVariable(self, distribution):
    with distribution.scope():
      v0 = variables_lib.Variable(0.)
    self.assertIsNone(v0._packed_var)

    distribution._enable_packed_variable_in_eager_mode = True
    with distribution.scope():
      v1 = variables_lib.Variable(0)
    self.assertIsInstance(v1._packed_var, packed.PackedDistributedVariable)

    devices = v1._devices
    for i in range(1, len(devices)):
      with distribute_lib.ReplicaContext(distribution, i):
        v1.assign(i)
    val = v1._get()
    self.assertIsInstance(val, packed.PackedVarAndDevice)
    self.assertEqual(val.device, devices[0])
    self.assertEqual(self.evaluate(val.read_value()), 0)
    for i in range(0, len(devices)):
      with distribute_lib.ReplicaContext(distribution, i):
        val = v1._get()
        self.assertIsInstance(val, packed.PackedVarAndDevice)
        self.assertEqual(val.device, devices[i])
        self.assertEqual(self.evaluate(val.read_value()), i)

  def testIgnorePackedVariableInSaveContext(self, distribution):
    distribution._enable_packed_variable_in_eager_mode = True
    with distribution.scope():
      v = variables_lib.Variable(0)
    self.assertIsInstance(
        v._packed_variable, packed.PackedDistributedVariable)

    options = save_options.SaveOptions()
    with save_context.save_context(options):
      self.assertIsNone(v._packed_variable)


class MirroredVariableTest(test.TestCase, parameterized.TestCase):

  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  @test_util.run_in_graph_and_eager_modes(config=config)
  def testProperties(self):
    if context.num_gpus() < 1 and context.executing_eagerly():
      self.skipTest("A GPU is not available for this test in eager mode.")

    mirrored = _make_mirrored()
    v = mirrored.values[0]
    self.assertEqual(v.name, mirrored.name)
    self.assertEqual(v.dtype, mirrored.dtype)
    self.assertEqual(v.shape, mirrored.shape)

  @test_util.run_in_graph_and_eager_modes(config=config)
  def testVariableOnAnotherDevice(self):
    v = variable_scope.get_variable(
        name="v", initializer=[1.], use_resource=True)
    mirrored = values_lib.MirroredVariable(
        None, (v,), variable_scope.VariableAggregation.MEAN)

    self.assertEqual(v.name, mirrored.name)
    self.assertEqual(v.dtype, mirrored.dtype)
    self.assertEqual(v.shape, mirrored.shape)


class MirroredVariableSaveRestoreTest(test.TestCase, parameterized.TestCase):

  def _assign_mirrored(self, v, new):
    for var, n in zip(v.values, new):
      self.evaluate(var.assign(n))

  def _save_return_saver(self, sess, var):
    saver = saver_lib.Saver(var_list=[var])
    test_dir = self.get_temp_dir()
    prefix = os.path.join(test_dir, "ckpt")
    return saver.save(sess, prefix), saver

  def _save(self, sess, var):
    save_path, _ = self._save_return_saver(sess, var)
    return save_path

  def _save_mirrored(self, distribution):
    """Save variables with mirroring, returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      mirrored = _make_mirrored(distribution)

      # Overwrite the initial values.
      self._assign_mirrored(mirrored, [3., 4.])

      # Saves the current value of v[0], 3.
      save_path = self._save(sess, mirrored)

      # Change the values between save and restore.
      self._assign_mirrored(mirrored, [5., 6.])
      return save_path

  def _save_normal(self):
    """Save variables without mirroring, returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      var = variable_scope.get_variable(
          name="v", initializer=1., use_resource=True)

      # Overwrite the initial value.
      self.evaluate(var.assign(3.))

      # Saves the current value of var, 3.
      save_path = self._save(sess, var)

      # Change the values between save and restore.
      self.evaluate(var.assign(5.))
      return save_path

  def _restore_normal(self, save_path):
    """Restore to variables without mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      var = variable_scope.get_variable(
          name="v", initializer=7., use_resource=True)

      # Overwrite the initial value.
      self.evaluate(var.assign(8.))

      # Restores the saved value of 3. to `var`.
      saver = saver_lib.Saver(var_list=[var])
      saver.restore(sess, save_path)
      self.assertEqual(3., self.evaluate(var))

  def _restore_mirrored(self, save_path, distribution):
    """Restore to variables with mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      mirrored = _make_mirrored(distribution)
      v = mirrored.values

      # Overwrite the initial values.
      self._assign_mirrored(mirrored, [7., 8.])

      # Restores the saved value of 3. to both variables.
      saver = saver_lib.Saver(var_list=[mirrored])
      saver.restore(sess, save_path)
      self.assertEqual([3., 3.], self.evaluate([v[0], v[1]]))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveAndRestoreMirroredOneGraph(self, distribution):
    with self.cached_session() as sess:
      mirrored = _make_mirrored(distribution)
      v = mirrored.values

      # Overwrite the initial values.
      self._assign_mirrored(mirrored, [3., 4.])

      # Saves the current value of v[0], 3.
      save_path, saver = self._save_return_saver(sess, mirrored)

      # Change the values between save and restore.
      self._assign_mirrored(mirrored, [5., 6.])

      # Restores the saved value of 3. to both variables.
      saver.restore(sess, save_path)
      self.assertEqual([3., 3.], self.evaluate([v[0], v[1]]))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveMirroredRestoreMirrored(self, distribution):
    if context.num_gpus() < 1 and context.executing_eagerly():
      # Graph mode can work without GPU because the Placer "moves" the
      # variable to a CPU. In other words, if there is no GPU available, but
      # the user requested to create a variable on GPU, Placer will ignore the
      # user request and assign the VarHandleOp to CPU. This requires
      # soft_placement, which is on by default.
      self.skipTest("A GPU is not available for this test in eager mode.")

    save_path = self._save_mirrored(distribution)
    self._restore_mirrored(save_path, distribution)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveMirroredRestoreNormal(self, distribution):
    if context.num_gpus() < 1 and context.executing_eagerly():
      # Graph mode can work without GPU because the Placer "moves" the
      # variable to a CPU. In other words, if there is no GPU available, but
      # the user requested to create a variable on GPU, Placer will ignore the
      # user request and assign the VarHandleOp to CPU. This requires
      # soft_placement, which is on by default.
      self.skipTest("A GPU is not available for this test in eager mode.")

    save_path = self._save_mirrored(distribution)
    self._restore_normal(save_path)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveNormalRestoreMirrored(self, distribution):
    if context.num_gpus() < 1 and context.executing_eagerly():
      # Graph mode can work without GPU because the Placer "moves" the
      # variable to a CPU. In other words, if there is no GPU available, but
      # the user requested to create a variable on GPU, Placer will ignore the
      # user request and assign the VarHandleOp to CPU. This requires
      # soft_placement, which is on by default.
      self.skipTest("A GPU is not available for this test in eager mode.")

    save_path = self._save_normal()
    self._restore_mirrored(save_path, distribution)


_TPU_STRATEGIES = (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)


def _make_replica_local(method, strategy=None):
  if strategy is None:
    devices = ("/device:GPU:0", "/device:CPU:0")
  else:
    devices = strategy.extended.worker_devices

  v = []
  for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
    with ops.device(d):
      v.append(variable_scope.get_variable(
          name=n, initializer=init, use_resource=True))

  if (strategy is not None) and isinstance(strategy, _TPU_STRATEGIES):
    var_cls = tpu_values.TPUSyncOnReadVariable
  else:
    var_cls = values_lib.SyncOnReadVariable
  replica_local = var_cls(strategy, v, method)
  return v, replica_local


class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):

  def _assign_replica_local(self, v, new):
    for var, n in zip(v, new):
      with ops.device(var.device):
        self.evaluate(var.assign(n))

  def _save_return_saver(self, sess, var):
    saver = saver_lib.Saver(var_list=[var])
    test_dir = self.get_temp_dir()
    prefix = os.path.join(test_dir, "ckpt")
    return saver.save(sess, prefix), saver

  def _save(self, sess, var):
    save_path, _ = self._save_return_saver(sess, var)
    return save_path

  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  @test_util.run_in_graph_and_eager_modes(config=config)
  def testProperties(self):
    if context.num_gpus() < 1 and context.executing_eagerly():
      self.skipTest("A GPU is not available for this test in eager mode.")
    v, replica_local = _make_replica_local(
        variable_scope.VariableAggregation.SUM)

    self.assertEqual(v[0].constraint, replica_local.constraint)
    self.assertEqual(v[0].name, replica_local.name)
    self.assertEqual(v[0].dtype, replica_local.dtype)
    self.assertEqual(v[0].shape, replica_local.shape)
    self.assertEqual(variable_scope.VariableAggregation.SUM,
                     replica_local.aggregation)

  @test_util.run_v2_only
  def testCanPassToDefFun(self):
    @def_function.function
    def add1(x):
      return x + 1

    v = variable_scope.get_variable(
        name="v", initializer=[1.], use_resource=True)
    replica_local = values_lib.SyncOnReadVariable(
        None, (v,), variable_scope.VariableAggregation.MEAN)
    self.assertEqual(2., self.evaluate(add1(replica_local)))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testTensorConversion(self, distribution):
    with context.graph_mode():
      _, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.SUM, distribution)
      converted = ops.convert_to_tensor(replica_local, as_ref=False)
      self.assertIsInstance(converted, ops.Tensor)
      self.assertEqual(converted.dtype, replica_local.dtype)

      converted = ops.convert_to_tensor(replica_local, as_ref=True)
      # Resource variables are converted to tensors as well when as_ref is
      # True.
      self.assertIsInstance(converted, ops.Tensor)
      self.assertEqual(converted.dtype, replica_local.dtype)

  @combinations.generate(combinations.combine(
      distribution=[
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations.tpu_strategy,
          strategy_combinations.tpu_strategy_packed_var,
      ], mode=["eager"]))
  def testValueInCrossReplicaContext(self, distribution):
    value_list, replica_local = _make_replica_local(
        variable_scope.VariableAggregation.ONLY_FIRST_REPLICA, distribution)

    self.assertIsInstance(replica_local.value(), ops.Tensor)
    self.assertEqual(self.evaluate(replica_local.value()),
                     self.evaluate(value_list[0].value()))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveAndRestoreReplicaLocalSumOneGraph(self, distribution):
    with self.cached_session() as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.SUM, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [3., 4.])

      with distribution.scope():
        # Saves the current value of v[0] + v[1], 7.
        save_path, saver = self._save_return_saver(sess, replica_local)

        # Change the values between save and restore.
        self._assign_replica_local(v, [5., 6.])

        # Restores the saved value of 7. which gets divided equally
        # between the variables.
        saver.restore(sess, save_path)
        self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveAndRestoreReplicaLocalMeanOneGraph(self, distribution):
    if context.num_gpus() < 1 and context.executing_eagerly():
      self.skipTest("A GPU is not available for this test in eager mode.")

    with self.cached_session() as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.MEAN, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [3., 4.])

      with distribution.scope():
        # Saves the current value of (v[0] + v[1])/2, 3.5.
        save_path, saver = self._save_return_saver(sess, replica_local)

        # Change the values between save and restore.
        self._assign_replica_local(v, [5., 6.])

        # Restores the saved value of 3.5 to both variables.
        saver.restore(sess, save_path)
        self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))

  def _save_replica_local_mean(self, distribution):
    """Save variables with mirroring, returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.MEAN, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [3., 4.])

      with distribution.scope():
        # Saves the current value of (v[0] + v[1])/2, 3.5
        save_path = self._save(sess, replica_local)

        # Change the values between save and restore.
        self._assign_replica_local(v, [5., 6.])
      return save_path

  def _save_replica_local_sum(self, distribution):
    """Save variables with mirroring, returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.SUM, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [1.5, 2.])

      with distribution.scope():
        # Saves the current value of v[0] + v[1], 3.5
        save_path = self._save(sess, replica_local)

        # Change the values between save and restore.
        self._assign_replica_local(v, [5., 6.])
      return save_path

  def _save_normal(self):
    """Save variables without mirroring, returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      var = variable_scope.get_variable(
          name="v", initializer=1., use_resource=True)

      # Overwrite the initial value.
      self.evaluate(var.assign(3.5))

      # Saves the current value of var, 3.5.
      save_path = self._save(sess, var)

      # Change the values between save and restore.
      self.evaluate(var.assign(5.))
      return save_path

  def _restore_normal(self, save_path):
    """Restore to variables without mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      var = variable_scope.get_variable(
          name="v", initializer=7., use_resource=True)

      # Overwrite the initial value.
      self.evaluate(var.assign(8.))

      # Restores the saved value of 3.5 to `var`.
      saver = saver_lib.Saver(var_list=[var])
      saver.restore(sess, save_path)
      self.assertEqual(3.5, self.evaluate(var))

  def _restore_replica_local_mean(self, save_path, distribution):
    """Restore to variables with mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.MEAN, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [7., 8.])

      with distribution.scope():
        # Restores the saved value of 3.5 to both variables.
        saver = saver_lib.Saver(var_list=[replica_local])
        saver.restore(sess, save_path)
        self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))

  def _restore_replica_local_sum(self, save_path, distribution):
    """Restore to variables with mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.SUM, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [7., 8.])

      with distribution.scope():
        # Restores the saved value of 3.5 to both variables.
        saver = saver_lib.Saver(var_list=[replica_local])
        saver.restore(sess, save_path)
        self.assertEqual([1.75, 1.75], self.evaluate([v[0], v[1]]))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveReplicaLocalRestoreReplicaLocalMean(self, distribution):
    save_path = self._save_replica_local_mean(distribution)
    self._restore_replica_local_mean(save_path, distribution)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveReplicaLocalRestoreReplicaLocalSum(self, distribution):
    save_path = self._save_replica_local_sum(distribution)
    self._restore_replica_local_sum(save_path, distribution)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveReplicaLocalMeanRestoreNormal(self, distribution):
    save_path = self._save_replica_local_mean(distribution)
    self._restore_normal(save_path)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveReplicaLocalSumRestoreNormal(self, distribution):
    save_path = self._save_replica_local_sum(distribution)
    self._restore_normal(save_path)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveNormalRestoreReplicaLocalMean(self, distribution):
    save_path = self._save_normal()
    self._restore_replica_local_mean(save_path, distribution)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveNormalRestoreReplicaLocalSum(self, distribution):
    save_path = self._save_normal()
    self._restore_replica_local_sum(save_path, distribution)


class MirroredTest(test.TestCase):

  def testAddOp(self):
    if context.num_gpus() < 1:
      self.skipTest("A GPU is not available for this test.")
    mirrored_val = _make_mirrored_val(init_val=3.)

    self.assertEqual(self.evaluate(constant_op.constant(6.)),
                     self.evaluate(mirrored_val + mirrored_val))
    self.assertEqual(self.evaluate(constant_op.constant(4.)),
                     self.evaluate(mirrored_val + 1))
    self.assertEqual(self.evaluate(mirrored_val + 1),
                     self.evaluate(math_ops.add(mirrored_val, 1)))
    self.assertEqual(type(mirrored_val + 1),
                     type(math_ops.add(mirrored_val, 1)))


class PerReplicaTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testTypeSpec(self):
    vals = (constant_op.constant(1.),)
    per_replica = values_lib.PerReplica(vals)

    spec = per_replica._type_spec
    self.assertEqual(spec._value_specs,
                     (tensor_spec.TensorSpec([], dtypes.float32),))

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testTypeSpecRoundTrip(self):
    vals = (constant_op.constant(1.),)
    per_replica = values_lib.PerReplica(vals)

    spec = per_replica._type_spec
    tensor_list = spec._to_components(per_replica)
    reconstructed = spec._from_components(tensor_list)

    self.assertAllEqual(per_replica.values, reconstructed.values)

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testTypeSpecNest(self):
    vals = (constant_op.constant(1.), constant_op.constant([5., 6.0]),)
    per_replica = values_lib.PerReplica(vals)

    # Note: nest.map_structure exercises nest.flatten and
    # nest.pack_sequence_as.
    result = nest.map_structure(
        lambda t: t + 10, per_replica, expand_composites=True)

    self.assertLen(result.values, 2)
    self.assertAllEqual(result.values[0], 11.)
    self.assertAllEqual(result.values[1], [15., 16.0])

  @test_util.run_in_graph_and_eager_modes
  def testIsGraphTensor(self):
    per_replica = values_lib.PerReplica((constant_op.constant(1.),))
    for t in nest.flatten(per_replica, expand_composites=True):
      self.assertEqual(hasattr(t, "graph"), not context.executing_eagerly())

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testDoesNotTriggerFunctionTracing(self):
    traces = []

    @def_function.function
    def f(x):
      traces.append(None)  # Only happens on trace.
      return x

    per_replica = values_lib.PerReplica((constant_op.constant(1.),))

    # Trace once.
    f(per_replica)
    self.assertNotEmpty(traces)
    del traces[:]

    per_replica_spec = per_replica._type_spec
    for _ in range(5):
      vals = per_replica_spec._to_components(per_replica)
      vals = [v * 2 for v in vals]
      per_replica = per_replica_spec._from_components(vals)

      output = f(per_replica)
      self.assertIsInstance(output, values_lib.PerReplica)
      self.assertAllEqual(output._values, per_replica._values)
      self.assertEmpty(traces)  # Make sure we're not re-tracing `f`.

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testFunctionCanReturnPerReplica(self):
    f = def_function.function(lambda x: x)
    x = values_lib.PerReplica((constant_op.constant(1.),))
    y = f(x)
    self.assertIsNot(x, y)
    nest.map_structure(self.assertAllEqual, x, y, expand_composites=True)
    self.assertEqual(x._type_spec, y._type_spec)

  @test_util.run_in_graph_and_eager_modes
  def testCondWithTensorValues(self):
    per_replica_1 = values_lib.PerReplica((constant_op.constant("a"),))
    per_replica_2 = values_lib.PerReplica((constant_op.constant(["b", "c"]),))
    condition = array_ops.placeholder_with_default(True, [])

    result = control_flow_ops.cond(
        condition, lambda: per_replica_1, lambda: per_replica_2)

    self.assertLen(result.values, 1)
    self.assertAllEqual(result.values[0], "a")

  @test_util.run_in_graph_and_eager_modes
  def testCondWithValuesConvertibleToTensor(self):
    per_replica_1 = values_lib.PerReplica(("a",))
    per_replica_2 = values_lib.PerReplica(("b",))
    condition = array_ops.placeholder_with_default(True, [])

    result = control_flow_ops.cond(
        condition, lambda: per_replica_1, lambda: per_replica_2)

    self.assertLen(result.values, 1)
    self.assertAllEqual(result.values[0], "a")

  @test_util.build_as_function_and_v1_graph
  def testCondWithValuesNotConvertibleToTensor(self):
    per_replica_1 = values_lib.PerReplica(({"a"},))
    per_replica_2 = values_lib.PerReplica(({"b", "c"},))
    condition = array_ops.placeholder(dtypes.bool, [])

    with self.assertRaisesRegex(TypeError, "Could not build a TypeSpec for"):
      control_flow_ops.cond(
          condition, lambda: per_replica_1, lambda: per_replica_2)


def _make_index_slices(values, indices, dense_shape=None):
  if dense_shape:
    dense_shape = array_ops.identity(dense_shape)
  return indexed_slices.IndexedSlices(
      array_ops.identity(values), array_ops.identity(indices), dense_shape)


if __name__ == "__main__":
  ds_test_util.main()