# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the distributed values library."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import os

from absl.testing import parameterized
import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util as ds_test_util
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import save_context
from tensorflow.python.saved_model import save_options
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.types import core
from tensorflow.python.util import nest


def _device_str(d):
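  """Returns the GPU device string for device index `d`."""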
  return "/device:GPU:" + str(d)


def _nested_value(d):
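  """Returns a nested structure of strings tagged with `d`."""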
  return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d)


def _make_mirrored_val(init_val=5.0):
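  """Returns a Mirrored value with `init_val` constants on GPU:0 and CPU:0."""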
  v = []
  devices = ["/device:GPU:0", "/device:CPU:0"]
  for d, _ in zip(devices, ["v", "v/replica"]):
    with ops.device(d):
      v.append(constant_op.constant(init_val))
  return values_lib.Mirrored(v)


def _make_mirrored(distribution=None):
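  """Returns a MirroredVariable (TPUMirroredVariable for TPU strategies) with
  SUM aggregation and one component per device of `distribution`, or on GPU:0
  and CPU:0 when no distribution is given."""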
  v = []
  if distribution:
    devices = distribution.extended.worker_devices
  else:
    devices = ["/device:GPU:0", "/device:CPU:0"]
  for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
    with ops.device(d):
      v.append(
          variable_scope.get_variable(
              name=n, initializer=init, use_resource=True))

  if (distribution is not None) and isinstance(distribution, _TPU_STRATEGIES):
    var_cls = tpu_values.TPUMirroredVariable
  else:
    var_cls = values_lib.MirroredVariable
  mirrored = var_cls(distribution, v, variable_scope.VariableAggregation.SUM)
  return mirrored


def mirrored_and_tpu_strategy_combinations():
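  """Returns graph/eager test combinations for mirrored and TPU strategies."""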
  return combinations.combine(
      distribution=[
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations.tpu_strategy,
          strategy_combinations.tpu_strategy_packed_var,
      ],
      mode=["graph", "eager"])


class DistributedValuesTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueFromTensor(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    single_value = constant_op.constant(1)
    def value_fn(ctx):
      del ctx
      return single_value

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    self.assertAllEqual(
        ds_test_util.gather(distribution, distributed_values),
        constant_op.constant(1., shape=(distribution.num_replicas_in_sync)))

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueSingleNumpyArrayConstant(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    array_value = np.array([1., 2., 3.])
    def value_fn(ctx):
      del ctx
      return array_value

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    self.assertAllEqual(
        ds_test_util.gather(distribution, distributed_values).numpy(),
        [[1., 2., 3.]] * distribution.num_replicas_in_sync)

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueTupleConstant(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    tuple_value = (1., 2., 3.)
    def value_fn(ctx):
      del ctx
      return tuple_value
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    distributed_values = ds_test_util.gather(distribution, distributed_values)

    # Expected output for 2 replicas:
    # ([1.0, 1.0], [2.0, 2.0], [3.0, 3.0])
    expected = tuple([v for i in range(distribution.num_replicas_in_sync)]
                     for v in tuple_value)
    self.assertAllEqual(distributed_values, expected)

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueNestedStructurePerReplica(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    tuple_value = (1., 2., 3.)
    def value_fn(ctx):
      per_replica = []
      for val in tuple_value:
        per_replica.append(val * ctx.replica_id_in_sync_group)
      return tuple(per_replica)
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    distributed_values = ds_test_util.gather(distribution, distributed_values)

    # Expected output for 2 replicas:
    # ([0.0, 1.0], [0.0, 2.0], [0.0, 3.0])
    expected = tuple([v * i for i in range(distribution.num_replicas_in_sync)]
                     for v in tuple_value)
    self.assertAllEqual(distributed_values, expected)

  # NOTE(priyag): Cannot test this with MultiWorkerMirroredStrategy because
  # collective ops do not support SparseTensors.
  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies_minus_default,
          mode=["eager"]
      ))
  def testMakeDistributedValueSparseTensor(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    def value_fn(ctx):
      del ctx
      return sparse_tensor.SparseTensor(
          indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    local_results = distribution.experimental_local_results(distributed_values)
    for i in range(distribution.num_replicas_in_sync):
      self.assertAllEqual(
          sparse_ops.sparse_tensor_to_dense(local_results[i]),
          [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]])

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueExtractFromArray(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    multiple_values = range(distribution.num_replicas_in_sync)
    def value_fn(ctx):
      return multiple_values[ctx.replica_id_in_sync_group]
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    distributed_values = ds_test_util.gather(distribution, distributed_values)
    expected = range(distribution.num_replicas_in_sync)
    self.assertAllEqual(distributed_values, expected)

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueAndRun(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    @def_function.function
    def run():
      multiple_values = range(distribution.num_replicas_in_sync)
      def value_fn(ctx):
        return multiple_values[ctx.replica_id_in_sync_group]
      distributed_values = (
          distribution.experimental_distribute_values_from_function(value_fn))

      def computation(x):
        return math_ops.square(x)

      outputs = ds_test_util.gather(
          distribution,
          distribution.run(computation, args=(distributed_values,)))
      return outputs

    results = run()

    expected = [i**2 for i in range(distribution.num_replicas_in_sync)]
    self.assertAllEqual(results, expected)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ] + strategy_combinations.multiworker_strategies,
          mode=["eager"]))
  def testMakeDistributedValueDefaultDevicePlacement(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    def value_fn(ctx):
      del ctx
      return constant_op.constant(1.0)
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    for i in range(len(distribution.extended.worker_devices)):
      self.assertAllEqual(distributed_values._values[i].device,
                          "/job:localhost/replica:0/task:0/device:CPU:0")

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ] + strategy_combinations.multiworker_strategies,
          mode=["eager"]))
  def testMakeDistributedValueExplicitDevicePlacement(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    worker_devices = distribution.extended.worker_devices
    def value_fn(ctx):
      # In a multi-client setup, worker_devices contains only the devices on
      # this worker.
      worker_device_id = ctx.replica_id_in_sync_group % len(worker_devices)
      with ops.device(worker_devices[worker_device_id]):
        return array_ops.identity(1.0)
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    for i in range(len(distribution.extended.worker_devices)):
      self.assertAllEqual(distributed_values._values[i].device,
                          worker_devices[i])


class DistributedDelegateTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes
  def testGetAttr(self):
    class Foo(object):

      def __init__(self, x):
        self.x = x

    v = values_lib.DistributedDelegate((Foo(7), Foo(8)))
    self.assertEqual(7, v.x)
    with self.assertRaises(AttributeError):
      _ = v.y

  @test_util.run_in_graph_and_eager_modes
  def testOperatorOverride(self):
    v = values_lib.DistributedDelegate((7, 8))
    # v should act like int(7).
    self.assertEqual(8, v + 1)
    self.assertEqual(10, 3 + v)
    self.assertEqual(14, v + v)
    self.assertEqual(5, v - 2)
    self.assertEqual(6, 13 - v)
    self.assertEqual(0, v - v)
    self.assertEqual(14, v * 2)
    self.assertEqual(21, 3 * v)
    self.assertEqual(49, v * v)
    self.assertEqual(3.5, v / 2)
    self.assertEqual(1.5, 10.5 / v)
    self.assertEqual(3, v // 2)
    self.assertEqual(2, 15 // v)
    self.assertEqual(1, v % 2)
    self.assertEqual(2, 16 % v)
    # pylint: disable=g-generic-assert
    self.assertTrue(v < 12)
    self.assertTrue(v <= 12)
    self.assertFalse(v > 12)
    self.assertFalse(v >= 12)
    self.assertFalse(12 < v)
    self.assertFalse(12 <= v)
    self.assertTrue(12 > v)
    self.assertTrue(12 >= v)
    # pylint: enable=g-generic-assert
    self.assertEqual(3, v & 3)
    self.assertEqual(3, 11 & v)
    self.assertEqual(15, v | 8)
    self.assertEqual(23, 16 | v)
    self.assertEqual(4, v ^ 3)
    self.assertEqual(12, 11 ^ v)
    self.assertEqual(343, pow(v, 3))
    self.assertEqual(3, pow(v, 3, 10))
    self.assertEqual(128, pow(2, v))
    self.assertEqual(-7, -v)
    self.assertEqual(~7, ~v)
    self.assertEqual(7, abs(v))
    with self.assertRaises(TypeError):
      _ = v[2]

  @test_util.run_in_graph_and_eager_modes
  def testCopy(self):

    class Foo(object):

      def __init__(self, x):
        self.x = x

    v = values_lib.DistributedDelegate((Foo(7), Foo(8)))
    v_shallow_copy = copy.copy(v)
    self.assertEqual(v.x, v_shallow_copy.x)
    v_deep_copy = copy.deepcopy(v)
    self.assertEqual(v.x, v_deep_copy.x)


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.tpu_strategy,
            strategy_combinations.tpu_strategy_packed_var,
            strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu
        ],
        synchronization=[
            variables_lib.VariableSynchronization.ON_READ,
            variables_lib.VariableSynchronization.ON_WRITE,
        ],
        aggregation=[
            variables_lib.VariableAggregation.MEAN,
            variables_lib.VariableAggregation.SUM,
            variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
        ],
        mode=["graph", "eager"],
        use_var_policy=[True, False]))
class DistributedVariableTest(test.TestCase, parameterized.TestCase):

  def testExtendsVariable(self, distribution, synchronization, aggregation):
    with distribution.scope():
      v = variables_lib.Variable(
          1., synchronization=synchronization, aggregation=aggregation)
    self.assertIsInstance(v, variables_lib.Variable)

  def testCheckpointing(self, distribution, synchronization, aggregation, mode):

    if (isinstance(distribution,
                   collective_all_reduce_strategy.CollectiveAllReduceStrategy)
        and mode == "graph"):
      self.skipTest("MWMS combinations tests do not work well in graph mode.")

    with distribution.scope():
      v = variables_lib.Variable(
          constant_op.constant([1., 2., 3., 4]),
          synchronization=synchronization,
          aggregation=aggregation)

    self.evaluate(v.initializer)
    before_save = self.evaluate(v.read_value())

    # Save the initial weights into a checkpoint.
    checkpoint = trackable_utils.Checkpoint(v=v)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    with self.test_session():
      save_path = checkpoint.save(prefix)

    # Assign inverted value.
    self.evaluate(v.assign(constant_op.constant([4., 3., 2., 1.])))
    after_assign = self.evaluate(v.read_value())
    self.assertNotAllClose(before_save, after_assign)

    # Restore from the checkpoint.
    with self.test_session():
      checkpoint.restore(save_path).assert_consumed().run_restore_ops()
    after_restore = self.evaluate(v)
    self.assertAllClose(before_save, after_restore)

  def testTraceback(self, distribution, synchronization, aggregation):
    if context.executing_eagerly():
      self.skipTest("does not apply to eager")
    with distribution.scope():
      variable_scope.get_variable(
          name="testVar",
          initializer=1.,
          use_resource=True,
          synchronization=synchronization,
          aggregation=aggregation)
      with self.assertRaisesRegex(ValueError,
                                  "Variable testVar already exists"):
        variable_scope.get_variable(
            name="testVar",
            initializer=1.,
            use_resource=True,
            synchronization=synchronization,
            aggregation=aggregation)

  def testSelectReplica(self, distribution, synchronization, aggregation):
    with distribution.scope():
      v = variables_lib.Variable(
          1., synchronization=synchronization, aggregation=aggregation)
    self.assertIs(v, distribute_utils.select_replica(0, v))

  def testIsTensorLike(self, distribution, synchronization, aggregation):
    if isinstance(distribution.extended,
                  tpu_strategy.TPUExtended) and context.executing_eagerly():
      self.skipTest("TPU doesn't support pure eager")

    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)
    # In cross replica context.
    self.assertIsInstance(v, core.Tensor)
    # In replica context.
    distribution.run(
        lambda v: self.assertIsInstance(v, core.Tensor), args=(v,))

  def testAssignReturnValueIsTensorLike(self, distribution, synchronization,
                                        aggregation):
    if isinstance(distribution.extended, tpu_strategy.TPUExtended):
      if context.executing_eagerly():
        self.skipTest("TPU doesn't support pure eager")
      else:
        self.skipTest("b/152076846")

    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)

    def assert_is_tensor_like(v):
      # We can't use Python literals because they are treated as
      # non-distributed values, which is not allowed when aggregation is SUM.
      # See `cross_device_ops.reduce_non_distributed_value`.
      delta = array_ops.identity(1.)
      self.assertIsInstance(v.assign(delta), core.Tensor)
      self.assertIsInstance(v.assign_sub(delta), core.Tensor)
      self.assertIsInstance(v.assign_add(delta), core.Tensor)

    # In cross-replica context assign*() returns a PerReplica value, which is
    # not always tensor-like yet.
    if (synchronization == variables_lib.VariableSynchronization.ON_READ and
        aggregation != variables_lib.VariableAggregation.SUM):
      assert_is_tensor_like(v)

    # In replica context.
    distribution.run(assert_is_tensor_like, args=(v,))

  def testDeepCopy(self, distribution, synchronization, aggregation):
    if not context.executing_eagerly():
      self.skipTest("deepcopy only supported in eager mode")

    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)
      in_dist_copy = copy.deepcopy(v)

    out_dist_copy = copy.deepcopy(v)

    def assert_is_deep_copy(v1, v2):
      self.assertIsInstance(v2, type(v1))
      self.assertEqual(v1.aggregation, v2.aggregation)
      self.assertEqual(v1.distribute_strategy, v2.distribute_strategy)
      if isinstance(v1, ps_values.AggregatingVariable):
        self.assertIsInstance(v2.get(), type(v1.get()))
        self.assertNotEqual(id(v1.get()), id(v2.get()))
      else:
        if v1._policy:
          self.assertNotEqual(id(v1._policy), id(v2._policy))  # pylint: disable=protected-access
        else:
          self.assertEqual(id(v1._policy), id(v2._policy))  # pylint: disable=protected-access
        self.assertEqual(len(v1.values), len(v2.values))
        for (v1v, v2v) in zip(v1.values, v2.values):
          self.assertEqual(v1v.device, v2v.device)
          self.assertNotEqual(id(v1v), id(v2v))
          self.assertAllEqual(self.evaluate(v1.values),
                              self.evaluate(v2.values))

    self.evaluate(variables_lib.global_variables_initializer())
    if not isinstance(distribution.extended, tpu_strategy.TPUExtended):
      distribution.run(assert_is_deep_copy, args=(v, in_dist_copy))
      distribution.run(assert_is_deep_copy, args=(v, out_dist_copy))

  def testAssignSignature(self, distribution, synchronization, aggregation):
    # This test verifies assign*() can be called in the same way as normal
    # variables.
    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)

      def assign():
        one = constant_op.constant(1.)
        v.assign(one, True, "assign", False)
        # TODO(b/154017756): SyncOnReadVariable.assign() doesn't support passing
        # value as a keyword argument.
        v.assign(one, use_locking=True, name="assign", read_value=False)
        v.assign_add(one, True, "assign", False)
        v.assign_add(one, use_locking=True, name="assign", read_value=False)
        v.assign_sub(one, True, "assign", False)
        v.assign_sub(one, use_locking=True, name="assign", read_value=False)
        # Return something for graph mode to fetch.
        return constant_op.constant(1)

      self.evaluate(variables_lib.global_variables_initializer())
      if not (synchronization == variables_lib.VariableSynchronization.ON_READ
              and aggregation == variables_lib.VariableAggregation.SUM):
        self.evaluate(distribution.experimental_local_results(assign()))
      if not (isinstance(distribution.extended, tpu_strategy.TPUExtended) and
              context.executing_eagerly()):
        self.evaluate(
            distribution.experimental_local_results(distribution.run(assign)))

  def testStrategyExtendedUpdate(self, distribution, synchronization,
                                 aggregation):
    if len(distribution.extended.parameter_devices) != 2:
      self.skipTest("n/a: needs exactly two parameter devices")
    if (synchronization == variables_lib.VariableSynchronization.ON_WRITE and
        aggregation != variables_lib.VariableAggregation.NONE):
      self.skipTest("n/a: doesn't apply to ON_WRITE variable with aggregation")
    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)
    value = values_lib.PerReplica([1., 2.])

    assign_fn = lambda var, value: var.assign(value)
    self.evaluate(distribution.extended.update(v, assign_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [1., 2.])

    assign_add_fn = lambda var, value: var.assign_add(value)
    self.evaluate(distribution.extended.update(v, assign_add_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [2., 4.])

    assign_sub_fn = lambda var, value: var.assign_sub(value)
    self.evaluate(distribution.extended.update(v, assign_sub_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [1., 2.])

    read_assign_fn = lambda var, value: var.assign_add(var.value() +
                                                       var.read_value())
    self.evaluate(
        distribution.extended.update(v, read_assign_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [3., 6.])

  def testSaveNonDistributed(self, distribution, synchronization, aggregation):
    # This test verifies that DistributedVariable behaves like the primary
    # variable when saving a non-distributed version of the model (the default).
    # It asserts that the function traced under SaveContext has no device
    # annotations and references only the primary component of the variable.
    # Note: avoid capturing other eager tensors in this test to keep the
    # assertion simple.

    if isinstance(distribution.extended,
                  parameter_server_strategy.ParameterServerStrategyExtended):
      self.skipTest("b/148689177: AggregatingVariable doesn't "
                    "conform to Variable interface well")

    # tf.function requires the return value to be Tensors, which is not always
    # the case for properties and methods of Variable, so we simply discard the
    # return values.
    def _discard_return(f):
      f()
      return

    def _test(f, v):
      # This verifies that the function under SaveContext:
      #   - contains no device annotations.
      #   - only references the primary component of the variable.
      g = def_function.function(lambda: _discard_return(f))
      options = save_options.SaveOptions(
          experimental_variable_policy=save_options.VariablePolicy.NONE)
      with save_context.save_context(options):
        # The graph should contain no device.
        graph = g.get_concrete_function().graph
      for op in graph.get_operations():
        self.assertEqual(op.device, "", msg=str(op))
      # The function should only capture the primary variable. Note that it
      # may have no captures at all, e.g. for v.aggregation.
      captures = list(graph.captures)
      self.assertLessEqual(len(captures), 1)
      if graph.captures:
        self.assertIs(captures[0][0], v._primary.handle)

    def _assert(cond):
      return control_flow_ops.Assert(cond, [cond])

    with distribution.scope():
      # We use three variables for convenience. They have no special meaning.
      # - v is used whenever possible.
      # - w is used for scatter and gather, which require the variable to be
      #   non-scalar.
      # - y is used when the dtype needs to be integer. Note that aggregation
      #   cannot be MEAN for integers.
      v = variables_lib.Variable(
          0.,
          synchronization=synchronization,
          aggregation=aggregation,
          trainable=True)
      w = variables_lib.Variable([0., 0., 0.],
                                 synchronization=synchronization,
                                 aggregation=aggregation,
                                 trainable=True)
      if aggregation != variables_lib.VariableAggregation.MEAN:
        y = variables_lib.Variable(
            0,
            synchronization=synchronization,
            aggregation=aggregation)

    # pylint: disable=g-long-lambda

    # tf.Variable properties.
    _test(lambda: self.assertEqual(v.aggregation, aggregation), v)
    _test(lambda: self.assertIs(v.constraint, None), v)
    # TODO(crccw): should we raise an error instead?
    _test(lambda: self.assertEqual(v.device, v._primary.device), v)
    _test(lambda: self.assertEqual(v.dtype, dtypes.float32), v)
    if not context.executing_eagerly():
      _test(lambda: self.assertIs(v.graph, v._primary.graph), v)
    if not context.executing_eagerly():
      _test(lambda: _assert(v.initial_value == 0), v)
    _test(lambda: self.assertIs(v.initializer, v._primary.initializer), v)
    _test(lambda: self.assertEqual(v.name, "Variable:0"), v)
    if not context.executing_eagerly():
      _test(lambda: self.assertIs(v.op, v._primary.op), v)
    _test(lambda: self.assertEqual(v.shape, tensor_shape.TensorShape(())), v)
    _test(lambda: self.assertEqual(v.synchronization, synchronization), v)
    _test(lambda: self.assertTrue(v.trainable, True), v)

    # tf.Variable methods.
    _test(lambda: check_ops.assert_equal_v2(v.assign(1.), 1.), v)
    _test(lambda: check_ops.assert_equal_v2(v.assign_add(1.), 2.), v)
    _test(lambda: check_ops.assert_equal_v2(v.assign_sub(1.), 1.), v)
    # TODO(b/148689177): Implement batch_scatter_update.
    # count_up_to() is skipped since it's deprecated.
    # eval() is skipped since it shouldn't be called in a tf.function.
    # experimental_ref() is skipped since it's deprecated.
    # from_proto() is skipped since it shouldn't be called in a tf.function.
    # TODO(b/148689177): Implement gather_nd.
    _test(
        lambda: check_ops.assert_equal_v2(v.get_shape(),
                                          tensor_shape.TensorShape(())), v)
    # initialized_value() is skipped since it shouldn't be called in a
    # tf.function.
    # load() is skipped since it shouldn't be called in a tf.function.
    _test(lambda: check_ops.assert_equal_v2(v.read_value(), 1.), v)
    # ref() is skipped since it shouldn't be called in a tf.function.
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_add(_make_index_slices(values=[1., 2.], indices=[0, 2])),
            [1., 0., 2.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_div(_make_index_slices(values=[4., 2.], indices=[0, 2])),
            [0.25, 0., 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_max(_make_index_slices(values=[1., 0.5], indices=[1, 2])),
            [0.25, 1., 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_min(_make_index_slices(values=[1., 0.5], indices=[0, 1])),
            [0.25, 0.5, 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_mul(_make_index_slices(values=[2., 0.5], indices=[0, 1])),
            [0.5, 0.25, 1.]), w)
    # TODO(b/148689177): Implement scatter_nd_*
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_sub(_make_index_slices(values=[2., 0.5], indices=[0, 1])),
            [-1.5, -0.25, 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_update(
                _make_index_slices(values=[2., 0.5], indices=[0, 1])),
            [2., 0.5, 1.]), w)
    # set_shape() is skipped since ResourceVariable doesn't implement it.
    # to_proto() is skipped since it shouldn't be called in a tf.function.
    _test(lambda: check_ops.assert_equal_v2(v.value(), 1.), v)

    # DistributedVariable should be treated as a ResourceVariable, so it needs
    # to conform to the ResourceVariable interface as well.
    _test(lambda: self.assertIs(v.handle, v._primary.handle), v)

    # Convert to tensor.
    _test(lambda: check_ops.assert_equal_v2(ops.convert_to_tensor(v), 1.), v)

    # Control dependency.
    def _with_control_dep():
      with ops.control_dependencies([v.assign(1.)]):
        return array_ops.identity(1)

    _test(_with_control_dep, v)

    # Operator overloads.
    _test(lambda: check_ops.assert_equal_v2(v.assign(7.), 7.), v)
    _test(lambda: check_ops.assert_equal_v2(v + 1., 8.), v)
    _test(lambda: check_ops.assert_equal_v2(3 + v, 10.), v)
    _test(lambda: check_ops.assert_equal_v2(v + v, 14.), v)
    _test(lambda: check_ops.assert_equal_v2(v - 2., 5.), v)
    _test(lambda: check_ops.assert_equal_v2(v - v, 0.), v)
    _test(lambda: check_ops.assert_equal_v2(v * 2., 14.), v)
    _test(lambda: check_ops.assert_equal_v2(3 * v, 21.), v)
    _test(lambda: check_ops.assert_equal_v2(v * v, 49.), v)
    _test(
        lambda: check_ops.assert_equal_v2(
            math_ops.cast(v / 2., dtypes.float32), 3.5), v)
    _test(
        lambda: check_ops.assert_equal_v2(
            math_ops.cast(14. / v, dtypes.float32), 2.), v)
    _test(lambda: _assert(v < 12.), v)
    _test(lambda: _assert(v <= 12.), v)
    _test(lambda: _assert(not v > 12.), v)
    _test(lambda: _assert(not v >= 12.), v)
    _test(lambda: _assert(not 12. < v), v)
    _test(lambda: _assert(not 12. <= v), v)
    _test(lambda: _assert(12. > v), v)
    _test(lambda: _assert(12. >= v), v)
    _test(lambda: check_ops.assert_near_v2(pow(v, 3.), 343.), v)
    _test(lambda: check_ops.assert_near_v2(pow(2., v), 128.), v)
    _test(lambda: check_ops.assert_equal_v2(abs(v), 7.), v)

    # Operator overloads that only work for integers.
    if aggregation != variables_lib.VariableAggregation.MEAN:
      _test(lambda: check_ops.assert_equal_v2(y.assign(7), 7), y)
      _test(lambda: check_ops.assert_equal_v2(y // 2, 3), y)
      _test(lambda: check_ops.assert_equal_v2(15 // y, 2), y)
      _test(lambda: check_ops.assert_equal_v2(y % 2, 1), y)
      _test(lambda: check_ops.assert_equal_v2(16 % y, 2), y)
      _test(lambda: check_ops.assert_equal_v2(y & 3, 3), y)
      _test(lambda: check_ops.assert_equal_v2(3 & y, 3), y)
      _test(lambda: check_ops.assert_equal_v2(y | 8, 15), y)
      _test(lambda: check_ops.assert_equal_v2(16 | y, 23), y)
      _test(lambda: check_ops.assert_equal_v2(y ^ 3, 4), y)
      _test(lambda: check_ops.assert_equal_v2(11 ^ y, 12), y)
      _test(lambda: check_ops.assert_equal_v2(-y, -7), y)
      _test(lambda: check_ops.assert_equal_v2(~y, ~7), y)

    # Index.
    if isinstance(distribution.extended, tpu_strategy.TPUExtended):
      # TODO(b/161572567): slice assignment doesn't work for TPU.
      _test(lambda: check_ops.assert_equal_v2(w[0], 2.), w)
    else:
      _test(lambda: check_ops.assert_equal_v2(w[0].assign(1.), [1., 0.5, 1.]),
            w)
      _test(lambda: check_ops.assert_equal_v2(w[0], 1.), w)

    # pylint: enable=g-long-lambda

  def testUnsaveable(self, distribution, synchronization, aggregation, mode):
    if isinstance(distribution.extended,
                  parameter_server_strategy.ParameterServerStrategyExtended):
      self.skipTest("n/a: not applicable to AggregatingVariable")
    if (isinstance(distribution,
                   collective_all_reduce_strategy.CollectiveAllReduceStrategy)
        and mode == "graph"):
      self.skipTest("MWMS combinations tests do not work well in graph mode.")
    with distribution.scope():
      v = variables_lib.Variable([1., 1.],
                                 synchronization=synchronization,
                                 aggregation=aggregation)

    with self.cached_session():
      self.evaluate(variables_lib.global_variables_initializer())

    export_dir = self.get_temp_dir()

    def _assert_unsaveable(f):
      # Ignore if it cannot be traced. Certain combinations are not supported
      # yet or not allowed.
      try:
        f = def_function.function(f).get_concrete_function()
      except (NotImplementedError, ValueError):
        return
      with self.assertRaisesRegex(ValueError, "f_with_input_signature"):
        save.save(v, export_dir, signatures=f)

    _assert_unsaveable(lambda: v.assign(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.assign_add(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.assign_sub(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.scatter_add(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_sub(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_mul(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_div(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_min(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_max(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_update(_make_index_slices([1.], [0])))
    # Reading an ON_READ variable should be unsaveable if either:
    # 1) CollectiveAllReduceStrategy is used and aggregation is MEAN/SUM, or
    # 2) aggregation is SUM.
    if (synchronization == variables_lib.VariableSynchronization.ON_READ and
        (aggregation == variables_lib.VariableAggregation.SUM or
         (isinstance(distribution.extended,
                     collective_all_reduce_strategy.CollectiveAllReduceExtended)
          and aggregation == variables_lib.VariableAggregation.MEAN))):
      _assert_unsaveable(v.read_value)
      _assert_unsaveable(v.value)
      _assert_unsaveable(lambda: ops.convert_to_tensor(v))
    else:
      # Otherwise reading a variable should be saveable.

      @def_function.function
      def f():
        v.read_value()
        v.value()
        return ops.convert_to_tensor(v)

      with self.cached_session():
        save.save(v, export_dir, signatures=f.get_concrete_function())


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.tpu_strategy,
        ],
        mode=["eager"]))
class PackedDistributedVariableTest(test.TestCase, parameterized.TestCase):

  def testPackedVariable(self, distribution):
    with distribution.scope():
      v0 = variables_lib.Variable(0.)
    self.assertIsNone(v0._packed_var)

    distribution._enable_packed_variable_in_eager_mode = True
    with distribution.scope():
      v1 = variables_lib.Variable(0)
      self.assertIsInstance(v1._packed_var, packed.PackedDistributedVariable)

    devices = v1._devices
    for i in range(1, len(devices)):
      with distribute_lib.ReplicaContext(distribution, i):
        v1.assign(i)
    val = v1._get()
    self.assertIsInstance(val, packed.PackedVarAndDevice)
    self.assertEqual(val.device, devices[0])
    self.assertEqual(self.evaluate(val.read_value()), 0)
    for i in range(0, len(devices)):
      with distribute_lib.ReplicaContext(distribution, i):
        val = v1._get()
        self.assertIsInstance(val, packed.PackedVarAndDevice)
        self.assertEqual(val.device, devices[i])
        self.assertEqual(self.evaluate(val.read_value()), i)

  def testIgnorePackedVariableInSaveContext(self, distribution):
    distribution._enable_packed_variable_in_eager_mode = True
    with distribution.scope():
      v = variables_lib.Variable(0)
      self.assertIsInstance(
          v._packed_variable, packed.PackedDistributedVariable)

    options = save_options.SaveOptions()
    with save_context.save_context(options):
      self.assertIsNone(v._packed_variable)


class MirroredVariableTest(test.TestCase, parameterized.TestCase):

  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  @test_util.run_in_graph_and_eager_modes(config=config)
  def testProperties(self):
    if context.num_gpus() < 1 and context.executing_eagerly():
      self.skipTest("A GPU is not available for this test in eager mode.")

    mirrored = _make_mirrored()
    v = mirrored.values[0]
    self.assertEqual(v.name, mirrored.name)
    self.assertEqual(v.dtype, mirrored.dtype)
    self.assertEqual(v.shape, mirrored.shape)

  @test_util.run_in_graph_and_eager_modes(config=config)
  def testVariableOnAnotherDevice(self):
    v = variable_scope.get_variable(
        name="v", initializer=[1.], use_resource=True)
    mirrored = values_lib.MirroredVariable(
        None, (v,), variable_scope.VariableAggregation.MEAN)

    self.assertEqual(v.name, mirrored.name)
    self.assertEqual(v.dtype, mirrored.dtype)
    self.assertEqual(v.shape, mirrored.shape)


class MirroredVariableSaveRestoreTest(test.TestCase, parameterized.TestCase):

  def _assign_mirrored(self, v, new):
    for var, n in zip(v.values, new):
      self.evaluate(var.assign(n))

  def _save_return_saver(self, sess, var):
    saver = saver_lib.Saver(var_list=[var])
    test_dir = self.get_temp_dir()
    prefix = os.path.join(test_dir, "ckpt")
    return saver.save(sess, prefix), saver

  def _save(self, sess, var):
    save_path, _ = self._save_return_saver(sess, var)
    return save_path

  def _save_mirrored(self, distribution):
    """Save variables with mirroring, returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      mirrored = _make_mirrored(distribution)

      # Overwrite the initial values.
      self._assign_mirrored(mirrored, [3., 4.])

      # Saves the current value of v[0], 3.
      save_path = self._save(sess, mirrored)

      # Change the values between save and restore.
      self._assign_mirrored(mirrored, [5., 6.])
    return save_path

  def _save_normal(self):
    """Save variables without mirroring, returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      var = variable_scope.get_variable(
          name="v", initializer=1., use_resource=True)

      # Overwrite the initial value.
      self.evaluate(var.assign(3.))

      # Saves the current value of var, 3.
      save_path = self._save(sess, var)

      # Change the values between save and restore.
      self.evaluate(var.assign(5.))
    return save_path

  def _restore_normal(self, save_path):
    """Restore to variables without mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      var = variable_scope.get_variable(
          name="v", initializer=7., use_resource=True)

      # Overwrite the initial value.
      self.evaluate(var.assign(8.))

      # Restores the saved value of 3. to `var`.
      saver = saver_lib.Saver(var_list=[var])
      saver.restore(sess, save_path)
      self.assertEqual(3., self.evaluate(var))

  def _restore_mirrored(self, save_path, distribution):
    """Restore to variables with mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      mirrored = _make_mirrored(distribution)
      v = mirrored.values

      # Overwrite the initial values.
      self._assign_mirrored(mirrored, [7., 8.])

      # Restores the saved value of 3. to both variables.
      saver = saver_lib.Saver(var_list=[mirrored])
      saver.restore(sess, save_path)
      self.assertEqual([3., 3.], self.evaluate([v[0], v[1]]))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveAndRestoreMirroredOneGraph(self, distribution):
    with self.cached_session() as sess:
      mirrored = _make_mirrored(distribution)
      v = mirrored.values

      # Overwrite the initial values.
      self._assign_mirrored(mirrored, [3., 4.])

      # Saves the current value of v[0], 3.
      save_path, saver = self._save_return_saver(sess, mirrored)

      # Change the values between save and restore.
      self._assign_mirrored(mirrored, [5., 6.])

      # Restores the saved value of 3. to both variables.
      saver.restore(sess, save_path)
      self.assertEqual([3., 3.], self.evaluate([v[0], v[1]]))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveMirroredRestoreMirrored(self, distribution):
    if context.num_gpus() < 1 and context.executing_eagerly():
      # Graph mode can work without GPU because the Placer "moves" the
      # variable to a CPU. In other words, if there is no GPU available, but
      # user requested to create a variable on GPU, Placer will ignore the
      # user request and assign the VarHandleOp to CPU. This requires
      # soft_placement, which is on by default.
      self.skipTest("A GPU is not available for this test in eager mode.")

    save_path = self._save_mirrored(distribution)
    self._restore_mirrored(save_path, distribution)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveMirroredRestoreNormal(self, distribution):
    if context.num_gpus() < 1 and context.executing_eagerly():
      # Graph mode can work without GPU because the Placer "moves" the
      # variable to a CPU. In other words, if there is no GPU available, but
      # user requested to create a variable on GPU, Placer will ignore the
      # user request and assign the VarHandleOp to CPU. This requires
      # soft_placement, which is on by default.
      self.skipTest("A GPU is not available for this test in eager mode.")

    save_path = self._save_mirrored(distribution)
    self._restore_normal(save_path)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveNormalRestoreMirrored(self, distribution):
    if context.num_gpus() < 1 and context.executing_eagerly():
      # Graph mode can work without GPU because the Placer "moves" the
      # variable to a CPU. In other words, if there is no GPU available, but
      # user requested to create a variable on GPU, Placer will ignore the
      # user request and assign the VarHandleOp to CPU. This requires
      # soft_placement, which is on by default.
      self.skipTest("A GPU is not available for this test in eager mode.")

    save_path = self._save_normal()
    self._restore_mirrored(save_path, distribution)


_TPU_STRATEGIES = (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)


def _make_replica_local(method, strategy=None):
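  """Returns the component variables and a SyncOnReadVariable wrapping them,
  built with aggregation `method` (TPUSyncOnReadVariable on TPU strategies)."""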
  if strategy is None:
    devices = ("/device:GPU:0", "/device:CPU:0")
  else:
    devices = strategy.extended.worker_devices

  v = []
  for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
    with ops.device(d):
      v.append(variable_scope.get_variable(
          name=n, initializer=init, use_resource=True))

  if (strategy is not None) and isinstance(strategy, _TPU_STRATEGIES):
    var_cls = tpu_values.TPUSyncOnReadVariable
  else:
    var_cls = values_lib.SyncOnReadVariable
  replica_local = var_cls(strategy, v, method)
  return v, replica_local


class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):

  def _assign_replica_local(self, v, new):
    for var, n in zip(v, new):
      with ops.device(var.device):
        self.evaluate(var.assign(n))

  def _save_return_saver(self, sess, var):
    saver = saver_lib.Saver(var_list=[var])
    test_dir = self.get_temp_dir()
    prefix = os.path.join(test_dir, "ckpt")
    return saver.save(sess, prefix), saver

  def _save(self, sess, var):
    save_path, _ = self._save_return_saver(sess, var)
    return save_path

  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  @test_util.run_in_graph_and_eager_modes(config=config)
  def testProperties(self):
    if context.num_gpus() < 1 and context.executing_eagerly():
      self.skipTest("A GPU is not available for this test in eager mode.")
    v, replica_local = _make_replica_local(
        variable_scope.VariableAggregation.SUM)

    self.assertEqual(v[0].constraint, replica_local.constraint)
    self.assertEqual(v[0].name, replica_local.name)
    self.assertEqual(v[0].dtype, replica_local.dtype)
    self.assertEqual(v[0].shape, replica_local.shape)
    self.assertEqual(variable_scope.VariableAggregation.SUM,
                     replica_local.aggregation)

  @test_util.run_v2_only
  def testCanPassToDefFun(self):
    @def_function.function
    def add1(x):
      return x + 1

    v = variable_scope.get_variable(
        name="v", initializer=[1.], use_resource=True)
    replica_local = values_lib.SyncOnReadVariable(
        None, (v,), variable_scope.VariableAggregation.MEAN)
    self.assertEqual(2., self.evaluate(add1(replica_local)))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testTensorConversion(self, distribution):
    with context.graph_mode():
      _, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.SUM, distribution)
      converted = ops.convert_to_tensor(replica_local, as_ref=False)
      self.assertIsInstance(converted, ops.Tensor)
      self.assertEqual(converted.dtype, replica_local.dtype)

      converted = ops.convert_to_tensor(replica_local, as_ref=True)
      # Resource variables are converted to tensors as well when as_ref is True.
      self.assertIsInstance(converted, ops.Tensor)
      self.assertEqual(converted.dtype, replica_local.dtype)

1183  @combinations.generate(combinations.combine(
1184      distribution=[
1185          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1186          strategy_combinations.tpu_strategy,
1187          strategy_combinations.tpu_strategy_packed_var,
1188      ], mode=["eager"]))
1189  def testValueInCrossReplicaContext(self, distribution):
1190    value_list, replica_local = _make_replica_local(
1191        variable_scope.VariableAggregation.ONLY_FIRST_REPLICA, distribution)
1192
1193    self.assertIsInstance(replica_local.value(), ops.Tensor)
1194    self.assertEqual(self.evaluate(replica_local.value()),
1195                     self.evaluate(value_list[0].value()))
1196
1197  @combinations.generate(mirrored_and_tpu_strategy_combinations())
1198  def testSaveAndRestoreReplicaLocalSumOneGraph(self, distribution):
1199    with self.cached_session() as sess:
1200      v, replica_local = _make_replica_local(
1201          variable_scope.VariableAggregation.SUM, distribution)
1202
1203      # Overwrite the initial values.
1204      self._assign_replica_local(v, [3., 4.])
1205
1206      with distribution.scope():
1207        # Saves the current value of v[0] + v[1], 7.
1208        save_path, saver = self._save_return_saver(sess, replica_local)
1209
1210        # Change the values between save and restore.
1211        self._assign_replica_local(v, [5., 6.])
1212
1213        # Restores the saved value of 7. which gets divided equally
1214        # between the variables.
1215        saver.restore(sess, save_path)
1216        self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))
1217
1218  @combinations.generate(mirrored_and_tpu_strategy_combinations())
1219  def testSaveAndRestoreReplicaLocalMeanOneGraph(self, distribution):
1220    if context.num_gpus() < 1 and context.executing_eagerly():
1221      self.skipTest("A GPU is not available for this test in eager mode.")
1222
1223    with self.cached_session() as sess:
1224      v, replica_local = _make_replica_local(
1225          variable_scope.VariableAggregation.MEAN, distribution)
1226
1227      # Overwrite the initial values.
1228      self._assign_replica_local(v, [3., 4.])
1229
1230      with distribution.scope():
1231        # Saves the current value of (v[0] + v[1])/2, 3.5.
1232        save_path, saver = self._save_return_saver(sess, replica_local)
1233
1234        # Change the values between save and restore.
1235        self._assign_replica_local(v, [5., 6.])
1236
1237        # Restores the saved value of 3.5 to both variables.
1238        saver.restore(sess, save_path)
1239        self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))
1240
1241  def _save_replica_local_mean(self, distribution):
1242    """Save variables with mirroring, returns save_path."""
1243    with self.session(graph=ops.Graph()) as sess:
1244      v, replica_local = _make_replica_local(
1245          variable_scope.VariableAggregation.MEAN, distribution)
1246
1247      # Overwrite the initial values.
1248      self._assign_replica_local(v, [3., 4.])
1249
1250      with distribution.scope():
1251        # Saves the current value of (v[0] + v[1])/2, 3.5
1252        save_path = self._save(sess, replica_local)
1253
1254        # Change the values between save and restore.
1255        self._assign_replica_local(v, [5., 6.])
1256    return save_path
1257
1258  def _save_replica_local_sum(self, distribution):
1259    """Save variables with mirroring, returns save_path."""
1260    with self.session(graph=ops.Graph()) as sess:
1261      v, replica_local = _make_replica_local(
1262          variable_scope.VariableAggregation.SUM, distribution)
1263
1264      # Overwrite the initial values.
1265      self._assign_replica_local(v, [1.5, 2.])
1266
1267      with distribution.scope():
1268        # Saves the current value of v[0] + v[1], 3.5
1269        save_path = self._save(sess, replica_local)
1270
1271        # Change the values between save and restore.
1272        self._assign_replica_local(v, [5., 6.])
1273    return save_path
1274
1275  def _save_normal(self):
1276    """Save variables without mirroring, returns save_path."""
1277    with self.session(graph=ops.Graph()) as sess:
1278      var = variable_scope.get_variable(
1279          name="v", initializer=1., use_resource=True)
1280
1281      # Overwrite the initial value.
1282      self.evaluate(var.assign(3.5))
1283
1284      # Saves the current value of var, 3.5.
1285      save_path = self._save(sess, var)
1286
1287      # Change the values between save and restore.
1288      self.evaluate(var.assign(5.))
1289    return save_path
1290
1291  def _restore_normal(self, save_path):
1292    """Restore to variables without mirroring in a fresh graph."""
1293    with self.session(graph=ops.Graph()) as sess:
1294      var = variable_scope.get_variable(
1295          name="v", initializer=7., use_resource=True)
1296
1297      # Overwrite the initial value.
1298      self.evaluate(var.assign(8.))
1299
1300      # Restores the saved value of 3.5 to `var`.
1301      saver = saver_lib.Saver(var_list=[var])
1302      saver.restore(sess, save_path)
1303      self.assertEqual(3.5, self.evaluate(var))
1304
1305  def _restore_replica_local_mean(self, save_path, distribution):
1306    """Restores to variables with mirroring in a fresh graph."""
1307    with self.session(graph=ops.Graph()) as sess:
1308      v, replica_local = _make_replica_local(
1309          variable_scope.VariableAggregation.MEAN, distribution)
1310
1311      # Overwrite the initial values.
1312      self._assign_replica_local(v, [7., 8.])
1313
1314      with distribution.scope():
1315        # Restores the saved value of 3.5 to both variables.
1316        saver = saver_lib.Saver(var_list=[replica_local])
1317        saver.restore(sess, save_path)
1318        self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))
1319
1320  def _restore_replica_local_sum(self, save_path, distribution):
1321    """Restores to variables with mirroring in a fresh graph."""
1322    with self.session(graph=ops.Graph()) as sess:
1323      v, replica_local = _make_replica_local(
1324          variable_scope.VariableAggregation.SUM, distribution)
1325
1326      # Overwrite the initial values.
1327      self._assign_replica_local(v, [7., 8.])
1328
1329      with distribution.scope():
1330        # Restores the saved value of 3.5, which gets divided equally
        # between the variables (1.75 each).
1331        saver = saver_lib.Saver(var_list=[replica_local])
1332        saver.restore(sess, save_path)
1333        self.assertEqual([1.75, 1.75], self.evaluate([v[0], v[1]]))
1334
1335  @combinations.generate(mirrored_and_tpu_strategy_combinations())
1336  def testSaveReplicaLocalRestoreReplicaLocalMean(self, distribution):
1337    save_path = self._save_replica_local_mean(distribution)
1338    self._restore_replica_local_mean(save_path, distribution)
1339
1340  @combinations.generate(mirrored_and_tpu_strategy_combinations())
1341  def testSaveReplicaLocalRestoreReplicaLocalSum(self, distribution):
1342    save_path = self._save_replica_local_sum(distribution)
1343    self._restore_replica_local_sum(save_path, distribution)
1344
1345  @combinations.generate(mirrored_and_tpu_strategy_combinations())
1346  def testSaveReplicaLocalMeanRestoreNormal(self, distribution):
1347    save_path = self._save_replica_local_mean(distribution)
1348    self._restore_normal(save_path)
1349
1350  @combinations.generate(mirrored_and_tpu_strategy_combinations())
1351  def testSaveReplicaLocalSumRestoreNormal(self, distribution):
1352    save_path = self._save_replica_local_sum(distribution)
1353    self._restore_normal(save_path)
1354
1355  @combinations.generate(mirrored_and_tpu_strategy_combinations())
1356  def testSaveNormalRestoreReplicaLocalMean(self, distribution):
1357    save_path = self._save_normal()
1358    self._restore_replica_local_mean(save_path, distribution)
1359
1360  @combinations.generate(mirrored_and_tpu_strategy_combinations())
1361  def testSaveNormalRestoreReplicaLocalSum(self, distribution):
1362    save_path = self._save_normal()
1363    self._restore_replica_local_sum(save_path, distribution)
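
  # Taken together, the six tests above cover the compatibility matrix:
  # a replica-local variable saves a single aggregated value, so its
  # checkpoint round-trips back into replica-local variables and also restores
  # into a plain variable, and a plain variable's checkpoint restores into
  # replica-local variables under either aggregation.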
1364
1365
1366class MirroredTest(test.TestCase):
1367
1368  def testAddOp(self):
1369    if context.num_gpus() < 1:
1370      self.skipTest("A GPU is not available for this test.")
1371    mirrored_val = _make_mirrored_val(init_val=3.)
1372
1373    self.assertEqual(self.evaluate(constant_op.constant(6.)),
1374                     self.evaluate(mirrored_val + mirrored_val))
1375    self.assertEqual(self.evaluate(constant_op.constant(4.)),
1376                     self.evaluate(mirrored_val + 1))
1377    self.assertEqual(self.evaluate(mirrored_val + 1),
1378                     self.evaluate(math_ops.add(mirrored_val, 1)))
1379    self.assertEqual(type(mirrored_val + 1),
1380                     type(math_ops.add(mirrored_val, 1)))
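
    # For reference, `_make_mirrored_val` (defined near the top of this file)
    # builds one constant per device from `init_val`, so the additions above
    # act elementwise on those per-device values (3. + 3. == 6., 3. + 1 == 4.),
    # and the last two assertions check that `+` and `math_ops.add` agree on
    # both value and result type.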
1381
1382
1383class PerReplicaTest(test.TestCase, parameterized.TestCase):
1384
1385  @combinations.generate(combinations.combine(mode=["eager"]))
1386  def testTypeSpec(self):
1387    vals = (constant_op.constant(1.),)
1388    per_replica = values_lib.PerReplica(vals)
1389
1390    spec = per_replica._type_spec
1391    self.assertEqual(spec._value_specs,
1392                     (tensor_spec.TensorSpec([], dtypes.float32),))
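
    # `PerReplica` is a composite tensor, and its `_type_spec` records the
    # spec of each per-replica component.  That spec is what lets these values
    # flow through `nest` with `expand_composites=True`, `def_function`
    # tracing, and `control_flow_ops.cond` in the tests below.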
1393
1394  @combinations.generate(combinations.combine(mode=["eager"]))
1395  def testTypeSpecRoundTrip(self):
1396    vals = (constant_op.constant(1.),)
1397    per_replica = values_lib.PerReplica(vals)
1398
1399    spec = per_replica._type_spec
1400    tensor_list = spec._to_components(per_replica)
1401    reconstructed = spec._from_components(tensor_list)
1402
1403    self.assertAllEqual(per_replica.values, reconstructed.values)
1404
1405  @combinations.generate(combinations.combine(mode=["eager"]))
1406  def testTypeSpecNest(self):
1407    vals = (constant_op.constant(1.), constant_op.constant([5., 6.0]),)
1408    per_replica = values_lib.PerReplica(vals)
1409
1410    # Note: nest.map_structure exercises nest.flatten and
1411    # nest.pack_sequence_as.
1412    result = nest.map_structure(
1413        lambda t: t + 10, per_replica, expand_composites=True)
1414
1415    self.assertLen(result.values, 2)
1416    self.assertAllEqual(result.values[0], 11.)
1417    self.assertAllEqual(result.values[1], [15., 16.0])
1418
1419  @test_util.run_in_graph_and_eager_modes
1420  def testIsGraphTensor(self):
1421    per_replica = values_lib.PerReplica((constant_op.constant(1.),))
1422    for t in nest.flatten(per_replica, expand_composites=True):
1423      self.assertEqual(hasattr(t, "graph"), not context.executing_eagerly())
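
    # Flattening with `expand_composites=True` yields the component tensors
    # themselves; per the assertion above, those components expose a `.graph`
    # attribute only when the test builds a graph rather than running eagerly.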
1424
1425  @combinations.generate(combinations.combine(mode=["eager"]))
1426  def testDoesNotTriggerFunctionTracing(self):
1427    traces = []
1428
1429    @def_function.function
1430    def f(x):
1431      traces.append(None)  # Only happens on trace.
1432      return x
1433
1434    per_replica = values_lib.PerReplica((constant_op.constant(1.),))
1435
1436    # Trace once.
1437    f(per_replica)
1438    self.assertNotEmpty(traces)
1439    del traces[:]
1440
1441    per_replica_spec = per_replica._type_spec
1442    for _ in range(5):
1443      vals = per_replica_spec._to_components(per_replica)
1444      vals = [v * 2 for v in vals]
1445      per_replica = per_replica_spec._from_components(vals)
1446
1447      output = f(per_replica)
1448      self.assertIsInstance(output, values_lib.PerReplica)
1449      self.assertAllEqual(output._values, per_replica._values)
1450      self.assertEmpty(traces)  # Make sure we're not re-tracing `f`.
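
    # The reuse above relies on `def_function.function` keying its trace cache
    # on the input's type spec: every `PerReplica` rebuilt from
    # `per_replica_spec` has identical component specs, so `f` is traced once
    # and then reused.  A value with a different spec (for example a different
    # dtype or number of components) would be expected to trigger a new trace.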
1451
1452  @combinations.generate(combinations.combine(mode=["eager"]))
1453  def testFunctionCanReturnPerReplica(self):
1454    f = def_function.function(lambda x: x)
1455    x = values_lib.PerReplica((constant_op.constant(1.),))
1456    y = f(x)
1457    self.assertIsNot(x, y)
1458    nest.map_structure(self.assertAllEqual, x, y, expand_composites=True)
1459    self.assertEqual(x._type_spec, y._type_spec)
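
    # `f` returns a new `PerReplica` rebuilt from the traced function's
    # outputs, hence `assertIsNot`: the result is a distinct object that
    # carries equal component values and an identical type spec.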
1460
1461  @test_util.run_in_graph_and_eager_modes
1462  def testCondWithTensorValues(self):
1463    per_replica_1 = values_lib.PerReplica((constant_op.constant("a"),))
1464    per_replica_2 = values_lib.PerReplica((constant_op.constant(["b", "c"]),))
1465    condition = array_ops.placeholder_with_default(True, [])
1466
1467    result = control_flow_ops.cond(
1468        condition, lambda: per_replica_1, lambda: per_replica_2)
1469
1470    self.assertLen(result.values, 1)
1471    self.assertAllEqual(result.values[0], "a")
1472
1473  @test_util.run_in_graph_and_eager_modes
1474  def testCondWithValuesConvertibleToTensor(self):
1475    per_replica_1 = values_lib.PerReplica(("a",))
1476    per_replica_2 = values_lib.PerReplica(("b",))
1477    condition = array_ops.placeholder_with_default(True, [])
1478
1479    result = control_flow_ops.cond(
1480        condition, lambda: per_replica_1, lambda: per_replica_2)
1481
1482    self.assertLen(result.values, 1)
1483    self.assertAllEqual(result.values[0], "a")
1484
1485  @test_util.build_as_function_and_v1_graph
1486  def testCondWithValuesNotConvertibleToTensor(self):
1487    per_replica_1 = values_lib.PerReplica(({"a"},))
1488    per_replica_2 = values_lib.PerReplica(({"b", "c"},))
1489    condition = array_ops.placeholder(dtypes.bool, [])
1490
1491    with self.assertRaisesRegex(TypeError, "Could not build a TypeSpec for"):
1492      control_flow_ops.cond(
1493          condition, lambda: per_replica_1, lambda: per_replica_2)
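
    # The error above is expected because `cond` has to build a type spec for
    # each branch output, and Python sets have no registered tensor
    # conversion, so no `TypeSpec` can be constructed for those `PerReplica`
    # components.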
1494
1495
1496def _make_index_slices(values, indices, dense_shape=None):
1497  if dense_shape:
1498    dense_shape = array_ops.identity(dense_shape)
1499  return indexed_slices.IndexedSlices(
1500      array_ops.identity(values), array_ops.identity(indices), dense_shape)
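
# A hypothetical usage sketch of the helper above (the values, indices and
# shape are illustrative only, not taken from a test in this file):
#   grad = _make_index_slices(
#       values=[[1., 2.]], indices=[3], dense_shape=[5, 2])
# The `array_ops.identity` calls convert the Python lists to tensors, so the
# resulting `IndexedSlices` carries ordinary tensors for its components.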
1501
1502
1503if __name__ == "__main__":
1504  ds_test_util.main()
1505