# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the distributed values library."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import os

from absl.testing import parameterized
import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import ps_values
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util as ds_test_util
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import save_context
from tensorflow.python.saved_model import save_options
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.types import core
from tensorflow.python.util import nest


def _device_str(d):
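  """Returns a GPU device string like /device:GPU:<d>."""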
  return "/device:GPU:" + str(d)


def _nested_value(d):
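  """Returns a nested structure of strings, each suffixed with d."""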
  return ("a" + d, ["b" + d, {"c": "d" + d, "e": "f" + d}, "g" + d], "h" + d)


def _make_mirrored_val(init_val=5.0):
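  """Returns a Mirrored value with init_val constants on GPU:0 and CPU:0."""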
  v = []
  devices = ["/device:GPU:0", "/device:CPU:0"]
  for d, _ in zip(devices, ["v", "v/replica"]):
    with ops.device(d):
      v.append(constant_op.constant(init_val))
  return values_lib.Mirrored(v)


def _make_mirrored(distribution=None):
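  """Returns a MirroredVariable with one component per device.

  Uses the distribution's worker devices when given, otherwise GPU:0 and
  CPU:0; for TPU strategies a TPUMirroredVariable is created.
  """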
  v = []
  if distribution:
    devices = distribution.extended.worker_devices
  else:
    devices = ["/device:GPU:0", "/device:CPU:0"]
  for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
    with ops.device(d):
      v.append(
          variable_scope.get_variable(
              name=n, initializer=init, use_resource=True))

  if (distribution is not None) and isinstance(distribution, _TPU_STRATEGIES):
    var_cls = tpu_values.TPUMirroredVariable
  else:
    var_cls = values_lib.MirroredVariable
  mirrored = var_cls(distribution, v, variable_scope.VariableAggregation.SUM)
  return mirrored


def mirrored_and_tpu_strategy_combinations():
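  """Returns graph/eager combinations for mirrored and TPU strategies."""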
  return combinations.combine(
      distribution=[
          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call,
          strategy_combinations.tpu_strategy,
          strategy_combinations.tpu_strategy_packed_var,
      ],
      mode=["graph", "eager"])


class DistributedValuesTest(test.TestCase, parameterized.TestCase):
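  """Tests creating DistributedValues from a per-replica value_fn."""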

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueFromTensor(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    single_value = constant_op.constant(1)
    def value_fn(ctx):
      del ctx
      return single_value

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    self.assertAllEqual(
        ds_test_util.gather(distribution, distributed_values),
        constant_op.constant(1., shape=(distribution.num_replicas_in_sync)))

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueSingleNumpyArrayConstant(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    array_value = np.array([1., 2., 3.])
    def value_fn(ctx):
      del ctx
      return array_value

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    self.assertAllEqual(
        ds_test_util.gather(distribution, distributed_values).numpy(),
        [[1., 2., 3.]] * distribution.num_replicas_in_sync)

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueTupleConstant(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    tuple_value = (1., 2., 3.)
    def value_fn(ctx):
      del ctx
      return tuple_value
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    distributed_values = ds_test_util.gather(distribution, distributed_values)

    # Expected output for 2 replicas:
    # ([1.0, 1.0], [2.0, 2.0], [3.0, 3.0])
    expected = tuple([v for i in range(distribution.num_replicas_in_sync)]
                     for v in tuple_value)
    self.assertAllEqual(distributed_values, expected)

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueNestedStructurePerReplica(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    tuple_value = (1., 2., 3.)
    def value_fn(ctx):
      per_replica = []
      for val in tuple_value:
        per_replica.append(val * ctx.replica_id_in_sync_group)
      return tuple(per_replica)
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    distributed_values = ds_test_util.gather(distribution, distributed_values)

    # Expected output for 2 replicas:
    # ([0.0, 1.0], [0.0, 2.0], [0.0, 3.0])
    expected = tuple([v * i for i in range(distribution.num_replicas_in_sync)]
                     for v in tuple_value)
    self.assertAllEqual(distributed_values, expected)

  # NOTE(priyag): Cannot test this with MultiWorkerMirroredStrategy because
  # collective ops do not support SparseTensors.
  @combinations.generate(
      combinations.combine(
          distribution=strategy_combinations.all_strategies_minus_default,
          mode=["eager"]
      ))
  def testMakeDistributedValueSparseTensor(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    def value_fn(ctx):
      del ctx
      return sparse_tensor.SparseTensor(
          indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    local_results = distribution.experimental_local_results(distributed_values)
    for i in range(distribution.num_replicas_in_sync):
      self.assertAllEqual(
          sparse_ops.sparse_tensor_to_dense(local_results[i]),
          [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]])

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueExtractFromArray(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    multiple_values = range(distribution.num_replicas_in_sync)
    def value_fn(ctx):
      return multiple_values[ctx.replica_id_in_sync_group]
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    distributed_values = ds_test_util.gather(distribution, distributed_values)
    expected = range(distribution.num_replicas_in_sync)
    self.assertAllEqual(distributed_values, expected)

  @combinations.generate(
      combinations.combine(
          distribution=(strategy_combinations.all_strategies_minus_default +
                        strategy_combinations.multiworker_strategies),
          mode=["eager"]
      ))
  def testMakeDistributedValueAndRun(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")

    @def_function.function
    def run():
      multiple_values = range(distribution.num_replicas_in_sync)
      def value_fn(ctx):
        return multiple_values[ctx.replica_id_in_sync_group]
      distributed_values = (
          distribution.experimental_distribute_values_from_function(value_fn))

      def computation(x):
        return math_ops.square(x)

      outputs = ds_test_util.gather(
          distribution,
          distribution.run(computation, args=(distributed_values,)))
      return outputs

    results = run()

    expected = [i**2 for i in range(distribution.num_replicas_in_sync)]
    self.assertAllEqual(results, expected)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ] + strategy_combinations.multiworker_strategies,
          mode=["eager"]))
  def testMakeDistributedValueDefaultDevicePlacement(self, distribution):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    def value_fn(ctx):
      del ctx
      return constant_op.constant(1.0)
    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    default_device = array_ops.identity(constant_op.constant(1.0)).device
    for i in range(len(distribution.extended.worker_devices)):
      self.assertAllEqual(distributed_values._values[i].device, default_device)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations
              .mirrored_strategy_with_two_gpus_no_merge_call,
              strategy_combinations.tpu_strategy,
              strategy_combinations.tpu_strategy_packed_var,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ] + strategy_combinations.multiworker_strategies,
          mode=["eager"],
          op_type=[constant_op.constant, array_ops.identity]))
  def testMakeDistributedValueExplicitDevicePlacement(self, distribution,
                                                      op_type):
    if not tf2.enabled():
      self.skipTest("Only V2 is supported.")
    worker_devices = distribution.extended.worker_devices
    def value_fn(ctx):
      # In a multi-client setup, worker_devices contains only the devices
      # that belong to this worker.
      worker_device_id = ctx.replica_id_in_sync_group % len(worker_devices)
      with ops.device(worker_devices[worker_device_id]):
        return op_type(1.0)

    distributed_values = (
        distribution.experimental_distribute_values_from_function(value_fn))
    for i in range(len(distribution.extended.worker_devices)):
      self.assertAllEqual(distributed_values._values[i].device,
                          worker_devices[i])


class PerWorkerResourceTest(test.TestCase, parameterized.TestCase):
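  """Tests map_fn tracing behavior relevant to PerWorkerResource."""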

  @combinations.generate(
      combinations.combine(dataset_fn_as_tf_function=[True, False]))
  def testMapFnTracing(self, dataset_fn_as_tf_function):
    # For a PerWorkerResource to behave correctly when used in dataset.map,
    # map_fn must not be traced only once, so that
    # PerWorkerResource.local_table can return the correct resource. This test
    # can detect a potential breakage of this behavior on TAP.
    self._traced_once = 0

    def map_fn(x):
      self._traced_once += 1
      return x

    def dataset_fn():
      dataset = dataset_ops.DatasetV2.from_tensors([0, 1, 2]).repeat().batch(
          2, drop_remainder=True)
      dataset = dataset.map(map_fn)
      return dataset

    datasets = []
    number_of_input_pipelines = 5

    if dataset_fn_as_tf_function:
      dataset_fn = def_function.function(dataset_fn)
      expected_tracing_times = 1
    else:
      expected_tracing_times = number_of_input_pipelines

    for _ in range(number_of_input_pipelines):
      datasets.append(dataset_fn())

    self.assertEqual(self._traced_once, expected_tracing_times)


class DistributedDelegateTest(test.TestCase):
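  """Tests attribute access, operators, and copying on DistributedDelegate."""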

  @test_util.run_in_graph_and_eager_modes
  def testGetAttr(self):
    class Foo(object):

      def __init__(self, x):
        self.x = x

    v = values_lib.DistributedDelegate((Foo(7), Foo(8)))
    self.assertEqual(7, v.x)
    with self.assertRaises(AttributeError):
      _ = v.y

  @test_util.run_in_graph_and_eager_modes
  def testOperatorOverride(self):
    v = values_lib.DistributedDelegate((7, 8))
    # v should act like int(7).
    self.assertEqual(8, v + 1)
    self.assertEqual(10, 3 + v)
    self.assertEqual(14, v + v)
    self.assertEqual(5, v - 2)
    self.assertEqual(6, 13 - v)
    self.assertEqual(0, v - v)
    self.assertEqual(14, v * 2)
    self.assertEqual(21, 3 * v)
    self.assertEqual(49, v * v)
    self.assertEqual(3.5, v / 2)
    self.assertEqual(1.5, 10.5 / v)
    self.assertEqual(3, v // 2)
    self.assertEqual(2, 15 // v)
    self.assertEqual(1, v % 2)
    self.assertEqual(2, 16 % v)
    # pylint: disable=g-generic-assert
    self.assertTrue(v < 12)
    self.assertTrue(v <= 12)
    self.assertFalse(v > 12)
    self.assertFalse(v >= 12)
    self.assertFalse(12 < v)
    self.assertFalse(12 <= v)
    self.assertTrue(12 > v)
    self.assertTrue(12 >= v)
    # pylint: enable=g-generic-assert
    self.assertEqual(3, v & 3)
    self.assertEqual(3, 11 & v)
    self.assertEqual(15, v | 8)
    self.assertEqual(23, 16 | v)
    self.assertEqual(4, v ^ 3)
    self.assertEqual(12, 11 ^ v)
    self.assertEqual(343, pow(v, 3))
    self.assertEqual(3, pow(v, 3, 10))
    self.assertEqual(128, pow(2, v))
    self.assertEqual(-7, -v)
    self.assertEqual(~7, ~v)
    self.assertEqual(7, abs(v))
    with self.assertRaises(TypeError):
      _ = v[2]

  @test_util.run_in_graph_and_eager_modes
  def testCopy(self):

    class Foo(object):

      def __init__(self, x):
        self.x = x

    v = values_lib.DistributedDelegate((Foo(7), Foo(8)))
    v_shallow_copy = copy.copy(v)
    self.assertEqual(v.x, v_shallow_copy.x)
    v_deep_copy = copy.deepcopy(v)
    self.assertEqual(v.x, v_deep_copy.x)


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
            strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call,
            strategy_combinations.tpu_strategy,
            strategy_combinations.tpu_strategy_packed_var,
            strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu_no_merge_call,
        ],
        synchronization=[
            variables_lib.VariableSynchronization.ON_READ,
            variables_lib.VariableSynchronization.ON_WRITE,
        ],
        aggregation=[
            variables_lib.VariableAggregation.MEAN,
            variables_lib.VariableAggregation.SUM,
            variables_lib.VariableAggregation.ONLY_FIRST_REPLICA,
        ],
        mode=["graph", "eager"],
        use_var_policy=[True, False]))
class DistributedVariableTest(test.TestCase, parameterized.TestCase):
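  """Tests DistributedVariable behavior across strategy combinations."""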

  def testExtendsVariable(self, distribution, synchronization, aggregation):
    with distribution.scope():
      v = variables_lib.Variable(
          1., synchronization=synchronization, aggregation=aggregation)
    self.assertIsInstance(v, variables_lib.Variable)

  def testCheckpointing(self, distribution, synchronization, aggregation, mode):

    if (isinstance(distribution,
                   collective_all_reduce_strategy.CollectiveAllReduceStrategy)
        and mode == "graph"):
      self.skipTest("MWMS combinations tests do not work well in graph mode.")

    with distribution.scope():
      v = variables_lib.Variable(
          constant_op.constant([1., 2., 3., 4]),
          synchronization=synchronization,
          aggregation=aggregation)

    self.evaluate(v.initializer)
    before_save = self.evaluate(v.read_value())

    # Save random weights into checkpoint.
    checkpoint = trackable_utils.Checkpoint(v=v)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    with self.test_session():
      save_path = checkpoint.save(prefix)

    # Assign inverted value.
    self.evaluate(v.assign(constant_op.constant([4., 3., 2., 1.])))
    after_assign = self.evaluate(v.read_value())
    self.assertNotAllClose(before_save, after_assign)

    # Restore from the checkpoint.
    with self.test_session():
      checkpoint.restore(save_path).assert_consumed().run_restore_ops()
    after_restore = self.evaluate(v)
    self.assertAllClose(before_save, after_restore)

  def testTraceback(self, distribution, synchronization, aggregation):
    if context.executing_eagerly():
      self.skipTest("does not apply to eager")
    with distribution.scope():
      variable_scope.get_variable(
          name="testVar",
          initializer=1.,
          use_resource=True,
          synchronization=synchronization,
          aggregation=aggregation)
      with self.assertRaisesRegex(ValueError,
                                  "Variable testVar already exists"):
        variable_scope.get_variable(
            name="testVar",
            initializer=1.,
            use_resource=True,
            synchronization=synchronization,
            aggregation=aggregation)

  def testSelectReplica(self, distribution, synchronization, aggregation):
    with distribution.scope():
      v = variables_lib.Variable(
          1., synchronization=synchronization, aggregation=aggregation)
    self.assertIs(v, distribute_utils.select_replica(0, v))

  def testIsTensorLike(self, distribution, synchronization, aggregation):
    if isinstance(distribution.extended,
                  tpu_strategy.TPUExtended) and context.executing_eagerly():
      self.skipTest("TPU doesn't support pure eager")

    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)
    # In cross replica context.
    self.assertIsInstance(v, core.Tensor)
    # In replica context.
    distribution.run(
        lambda v: self.assertIsInstance(v, core.Tensor), args=(v,))

  def testAssignReturnValueIsTensorLike(self, distribution, synchronization,
                                        aggregation):
    if isinstance(distribution.extended, tpu_strategy.TPUExtended):
      if context.executing_eagerly():
        self.skipTest("TPU doesn't support pure eager")
      else:
        self.skipTest("b/152076846")

    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)

    def assert_is_tensor_like(v):
      # We can't use Python literals here because they are treated as
      # non-distributed values, which is not allowed when aggregation is SUM.
      # See `cross_device_ops.reduce_non_distributed_value`.
      delta = array_ops.identity(1.)
      self.assertIsInstance(v.assign(delta), core.Tensor)
      self.assertIsInstance(v.assign_sub(delta), core.Tensor)
      self.assertIsInstance(v.assign_add(delta), core.Tensor)

    # In cross-replica context we return a PerReplica, which is not always
    # tensor-like yet.
    if (synchronization == variables_lib.VariableSynchronization.ON_READ and
        aggregation != variables_lib.VariableAggregation.SUM):
      assert_is_tensor_like(v)

    # In replica context.
    distribution.run(assert_is_tensor_like, args=(v,))

  def testDeepCopy(self, distribution, synchronization,
                   aggregation):
    if not context.executing_eagerly():
      self.skipTest("deepcopy only supported in eager mode")

    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)
      in_dist_copy = copy.deepcopy(v)

    out_dist_copy = copy.deepcopy(v)

    def assert_is_deep_copy(v1, v2):
      self.assertIsInstance(v2, type(v1))
      self.assertEqual(v1.aggregation, v2.aggregation)
      self.assertEqual(v1.distribute_strategy, v2.distribute_strategy)
      if isinstance(v1, ps_values.AggregatingVariable):
        self.assertIsInstance(v2.get(), type(v1.get()))
        self.assertNotEqual(id(v1.get()), id(v2.get()))
      else:
        if v1._policy:
          self.assertNotEqual(id(v1._policy), id(v2._policy))  # pylint: disable=protected-access
        else:
          self.assertEqual(id(v1._policy), id(v2._policy))  # pylint: disable=protected-access
        self.assertEqual(len(v1.values), len(v2.values))
        for (v1v, v2v) in zip(v1.values, v2.values):
          self.assertEqual(v1v.device, v2v.device)
          self.assertNotEqual(id(v1v), id(v2v))
          self.assertAllEqual(self.evaluate(v1.values),
                              self.evaluate(v2.values))

    self.evaluate(variables_lib.global_variables_initializer())
    if not isinstance(distribution.extended, tpu_strategy.TPUExtended):
      distribution.run(assert_is_deep_copy, args=(v, in_dist_copy))
      distribution.run(assert_is_deep_copy, args=(v, out_dist_copy))

  def testAssignSignature(self, distribution, synchronization, aggregation):
    # This test verifies that assign*() can be called in the same way as on
    # normal variables.
    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)

      def assign():
        one = constant_op.constant(1.)
        v.assign(one, True, "assign", False)
        # TODO(b/154017756): SyncOnReadVariable.assign() doesn't support passing
        # value as a keyword argument.
        v.assign(one, use_locking=True, name="assign", read_value=False)
        v.assign_add(one, True, "assign", False)
        v.assign_add(one, use_locking=True, name="assign", read_value=False)
        v.assign_sub(one, True, "assign", False)
        v.assign_sub(one, use_locking=True, name="assign", read_value=False)
        # Return something for graph mode to fetch.
        return constant_op.constant(1)

      self.evaluate(variables_lib.global_variables_initializer())
      if not (synchronization == variables_lib.VariableSynchronization.ON_READ
              and aggregation == variables_lib.VariableAggregation.SUM):
        self.evaluate(distribution.experimental_local_results(assign()))
      if not (isinstance(distribution.extended, tpu_strategy.TPUExtended) and
              context.executing_eagerly()):
        self.evaluate(
            distribution.experimental_local_results(distribution.run(assign)))

  def testStrategyExtendedUpdate(self, distribution, synchronization,
                                 aggregation):
    if len(distribution.extended.parameter_devices) != 2:
      self.skipTest("n/a: needs exactly two parameter devices")
    if (synchronization == variables_lib.VariableSynchronization.ON_WRITE and
        aggregation != variables_lib.VariableAggregation.NONE):
      self.skipTest("n/a: doesn't apply to ON_WRITE variable with aggregation")
    with distribution.scope():
      v = variables_lib.Variable(
          0., synchronization=synchronization, aggregation=aggregation)
    value = values_lib.PerReplica([1., 2.])

    assign_fn = lambda var, value: var.assign(value)
    self.evaluate(distribution.extended.update(v, assign_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [1., 2.])

    assign_add_fn = lambda var, value: var.assign_add(value)
    self.evaluate(distribution.extended.update(v, assign_add_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [2., 4.])

    assign_sub_fn = lambda var, value: var.assign_sub(value)
    self.evaluate(distribution.extended.update(v, assign_sub_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [1., 2.])

    read_assign_fn = lambda var, value: var.assign_add(var.value() + var.
                                                       read_value())
    self.evaluate(
        distribution.extended.update(v, read_assign_fn, args=(value,)))
    self.assertAllEqual(self.evaluate(v.values), [3., 6.])

  def testSaveNonDistributed(self, distribution, synchronization, aggregation):
    # This test verifies that a DistributedVariable behaves like the primary
    # variable when saving a non-distributed version of the model (the
    # default). The test asserts that the function traced under SaveContext
    # has no device annotations and only references the primary component of
    # the variable. Note: please avoid capturing other eager tensors in this
    # test to keep the assertion simple.

    if isinstance(distribution.extended,
                  parameter_server_strategy.ParameterServerStrategyExtended):
      self.skipTest("b/148689177: AggregatingVariable doesn't "
                    "conform to Variable interface well")

    # tf.function requires the return value to be Tensors, which is not always
    # the case for properties and methods of Variable, so we simply discard the
    # return values.
    def _discard_return(f):
      f()
      return

    def _test(f, v):
      # This verifies that the function under SaveContext:
      #   - contains no device annotations.
      #   - only references the primary component of the variable.
      g = def_function.function(lambda: _discard_return(f))
      options = save_options.SaveOptions(
          experimental_variable_policy=save_options.VariablePolicy.NONE)
      with save_context.save_context(options):
        # The graph should contain no device.
        graph = g.get_concrete_function().graph
      for op in graph.get_operations():
        self.assertEqual(op.device, "", msg=str(op))
      # The function should only capture the primary variable. Note that it
      # may have no captures at all, e.g. for v.aggregation.
      captures = list(graph.captures)
      self.assertLessEqual(len(captures), 1)
      if graph.captures:
        self.assertIs(captures[0][0], v._primary.handle)

    def _assert(cond):
      return control_flow_ops.Assert(cond, [cond])

    with distribution.scope():
      # We use three variables for convenience reasons. They have no special
      # meaning.
      # - v is used whenever possible.
      # - w is used for scatter and gather, which require the variable to be
      #   non-scalar.
      # - y is used when the dtype needs to be integer. Note that aggregation
      #   cannot be MEAN for integers.
      v = variables_lib.Variable(
          0.,
          synchronization=synchronization,
          aggregation=aggregation,
          trainable=True)
      w = variables_lib.Variable([0., 0., 0.],
                                 synchronization=synchronization,
                                 aggregation=aggregation,
                                 trainable=True)
      if aggregation != variables_lib.VariableAggregation.MEAN:
        y = variables_lib.Variable(
            0,
            synchronization=synchronization,
            aggregation=aggregation)

    # pylint: disable=g-long-lambda

    # tf.Variable properties.
    _test(lambda: self.assertEqual(v.aggregation, aggregation), v)
    _test(lambda: self.assertIs(v.constraint, None), v)
    # TODO(crccw): should we raise an error instead?
    _test(lambda: self.assertEqual(v.device, v._primary.device), v)
    _test(lambda: self.assertEqual(v.dtype, dtypes.float32), v)
    if not context.executing_eagerly():
      _test(lambda: self.assertIs(v.graph, v._primary.graph), v)
    if not context.executing_eagerly():
      _test(lambda: _assert(v.initial_value == 0), v)
    _test(lambda: self.assertIs(v.initializer, v._primary.initializer), v)
    _test(lambda: self.assertEqual(v.name, "Variable:0"), v)
    if not context.executing_eagerly():
      _test(lambda: self.assertIs(v.op, v._primary.op), v)
    _test(lambda: self.assertEqual(v.shape, tensor_shape.TensorShape(())), v)
    _test(lambda: self.assertEqual(v.synchronization, synchronization), v)
    _test(lambda: self.assertTrue(v.trainable), v)

    # tf.Variable methods.
    _test(lambda: check_ops.assert_equal_v2(v.assign(1.), 1.), v)
    _test(lambda: check_ops.assert_equal_v2(v.assign_add(1.), 2.), v)
    _test(lambda: check_ops.assert_equal_v2(v.assign_sub(1.), 1.), v)
    # TODO(b/148689177): Implement batch_scatter_update.
    # count_up_to() is skipped since it's deprecated.
    # eval() is skipped since it shouldn't be called in a tf.function.
    # experimental_ref() is skipped since it's deprecated.
    # from_proto() is skipped since it shouldn't be called in a tf.function.
    # TODO(b/148689177): Implement gather_nd.
    _test(
        lambda: check_ops.assert_equal_v2(v.get_shape(),
                                          tensor_shape.TensorShape(())), v)
    # initialized_value() is skipped since it shouldn't be called in a
    # tf.function.
    # load() is skipped since it shouldn't be called in a tf.function.
    _test(lambda: check_ops.assert_equal_v2(v.read_value(), 1.), v)
    # ref() is skipped since it shouldn't be called in a tf.function.
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_add(_make_index_slices(values=[1., 2.], indices=[0, 2])),
            [1., 0., 2.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_div(_make_index_slices(values=[4., 2.], indices=[0, 2])),
            [0.25, 0., 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_max(_make_index_slices(values=[1., 0.5], indices=[1, 2])),
            [0.25, 1., 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_min(_make_index_slices(values=[1., 0.5], indices=[0, 1])),
            [0.25, 0.5, 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_mul(_make_index_slices(values=[2., 0.5], indices=[0, 1])),
            [0.5, 0.25, 1.]), w)
    # TODO(b/148689177): Implement scatter_nd_*
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_sub(_make_index_slices(values=[2., 0.5], indices=[0, 1])),
            [-1.5, -0.25, 1.]), w)
    _test(
        lambda: check_ops.assert_equal_v2(
            w.scatter_update(
                _make_index_slices(values=[2., 0.5], indices=[0, 1])),
            [2., 0.5, 1.]), w)
    # set_shape() is skipped since ResourceVariable doesn't implement it.
    # to_proto() is skipped since it shouldn't be called in a tf.function.
    _test(lambda: check_ops.assert_equal_v2(v.value(), 1.), v)

    # DistributedVariable should be treated as a ResourceVariable, so it needs
    # to conform to the ResourceVariable interface as well.
    _test(lambda: self.assertIs(v.handle, v._primary.handle), v)

    # Convert to tensor.
    _test(lambda: check_ops.assert_equal_v2(ops.convert_to_tensor(v), 1.), v)

    # Control dependency.
    def _with_control_dep():
      with ops.control_dependencies([v.assign(1.)]):
        return array_ops.identity(1)

    _test(_with_control_dep, v)

    # Operator overloads.
    _test(lambda: check_ops.assert_equal_v2(v.assign(7.), 7.), v)
    _test(lambda: check_ops.assert_equal_v2(v + 1., 8.), v)
    _test(lambda: check_ops.assert_equal_v2(3 + v, 10.), v)
    _test(lambda: check_ops.assert_equal_v2(v + v, 14.), v)
    _test(lambda: check_ops.assert_equal_v2(v - 2., 5.), v)
    _test(lambda: check_ops.assert_equal_v2(v - v, 0.), v)
    _test(lambda: check_ops.assert_equal_v2(v * 2., 14.), v)
    _test(lambda: check_ops.assert_equal_v2(3 * v, 21.), v)
    _test(lambda: check_ops.assert_equal_v2(v * v, 49.), v)
    _test(
        lambda: check_ops.assert_equal_v2(
            math_ops.cast(v / 2., dtypes.float32), 3.5), v)
    _test(
        lambda: check_ops.assert_equal_v2(
            math_ops.cast(14. / v, dtypes.float32), 2.), v)
    _test(lambda: _assert(v < 12.), v)
    _test(lambda: _assert(v <= 12.), v)
    _test(lambda: _assert(not v > 12.), v)
    _test(lambda: _assert(not v >= 12.), v)
    _test(lambda: _assert(not 12. < v), v)
    _test(lambda: _assert(not 12. <= v), v)
    _test(lambda: _assert(12. > v), v)
    _test(lambda: _assert(12. >= v), v)
    _test(lambda: check_ops.assert_near_v2(pow(v, 3.), 343.), v)
    _test(lambda: check_ops.assert_near_v2(pow(2., v), 128.), v)
    _test(lambda: check_ops.assert_equal_v2(abs(v), 7.), v)

    # Operator overloads that only work for integers.
    if aggregation != variables_lib.VariableAggregation.MEAN:
      _test(lambda: check_ops.assert_equal_v2(y.assign(7), 7), y)
      _test(lambda: check_ops.assert_equal_v2(y // 2, 3), y)
      _test(lambda: check_ops.assert_equal_v2(15 // y, 2), y)
      _test(lambda: check_ops.assert_equal_v2(y % 2, 1), y)
      _test(lambda: check_ops.assert_equal_v2(16 % y, 2), y)
      _test(lambda: check_ops.assert_equal_v2(y & 3, 3), y)
      _test(lambda: check_ops.assert_equal_v2(3 & y, 3), y)
      _test(lambda: check_ops.assert_equal_v2(y | 8, 15), y)
      _test(lambda: check_ops.assert_equal_v2(16 | y, 23), y)
      _test(lambda: check_ops.assert_equal_v2(y ^ 3, 4), y)
      _test(lambda: check_ops.assert_equal_v2(11 ^ y, 12), y)
      _test(lambda: check_ops.assert_equal_v2(-y, -7), y)
      _test(lambda: check_ops.assert_equal_v2(~y, ~7), y)

    # Index.
    if isinstance(distribution.extended, tpu_strategy.TPUExtended):
      # TODO(b/161572567): slice assignment doesn't work for TPU.
      _test(lambda: check_ops.assert_equal_v2(w[0], 2.), w)
    else:
      _test(lambda: check_ops.assert_equal_v2(w[0].assign(1.), [1., 0.5, 1.]),
            w)
      _test(lambda: check_ops.assert_equal_v2(w[0], 1.), w)

    # pylint: enable=g-long-lambda

  def testUnsaveable(self, distribution, synchronization, aggregation, mode):
    if isinstance(distribution.extended,
                  parameter_server_strategy.ParameterServerStrategyExtended):
      self.skipTest("n/a: not applicable to AggregatingVariable")
    if (isinstance(distribution,
                   collective_all_reduce_strategy.CollectiveAllReduceStrategy)
        and mode == "graph"):
      self.skipTest("MWMS combinations tests do not work well in graph mode.")
    if not distribution.extended._use_merge_call():
      self.skipTest("Unsupported combination.")
    with distribution.scope():
      v = variables_lib.Variable([1., 1.],
                                 synchronization=synchronization,
                                 aggregation=aggregation)

    with self.cached_session():
      self.evaluate(variables_lib.global_variables_initializer())

    export_dir = self.get_temp_dir()

    def _assert_unsaveable(f):
      # Ignore it if it cannot be traced. Certain combinations are either not
      # supported yet or not allowed.
      try:
        f = def_function.function(f).get_concrete_function()
      except (NotImplementedError, ValueError):
        return
      with self.assertRaisesRegex(ValueError, "f_with_input_signature"):
        save.save(v, export_dir, signatures=f)

    _assert_unsaveable(lambda: v.assign(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.assign_add(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.assign_sub(ops.convert_to_tensor([1., 1.])))
    _assert_unsaveable(lambda: v.scatter_add(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_sub(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_mul(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_div(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_min(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_max(_make_index_slices([1.], [0])))
    _assert_unsaveable(lambda: v.scatter_update(_make_index_slices([1.], [0])))
    # Reading an ON_READ variable should be unsaveable if any of the
    # following holds:
    # 1) the aggregation is SUM;
    # 2) the strategy does not use merge_call;
    # 3) it is a CollectiveAllReduceStrategy and the aggregation is MEAN.
    if (synchronization == variables_lib.VariableSynchronization.ON_READ and
        (aggregation == variables_lib.VariableAggregation.SUM or
         (not distribution.extended._use_merge_call()) or
         (isinstance(distribution.extended,
                     collective_all_reduce_strategy.CollectiveAllReduceExtended)
          and aggregation == variables_lib.VariableAggregation.MEAN))):
      _assert_unsaveable(v.read_value)
      _assert_unsaveable(v.value)
      _assert_unsaveable(lambda: ops.convert_to_tensor(v))
    else:
      # Otherwise reading a variable should be saveable.

      @def_function.function
      def f():
        v.read_value()
        v.value()
        return ops.convert_to_tensor(v)

      with self.cached_session():
        save.save(v, export_dir, signatures=f.get_concrete_function())


@combinations.generate(
    combinations.combine(
        distribution=[
            strategy_combinations.mirrored_strategy_with_one_cpu,
            strategy_combinations.tpu_strategy,
        ],
        mode=["eager"]))
class PackedDistributedVariableTest(test.TestCase, parameterized.TestCase):
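  """Tests variables backed by PackedDistributedVariable in eager mode."""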

  def testPackedVariable(self, distribution):
    with distribution.scope():
      v0 = variables_lib.Variable(0.)
    self.assertIsNone(v0._packed_var)

    distribution._enable_packed_variable_in_eager_mode = True
    with distribution.scope():
      v1 = variables_lib.Variable(0)
      self.assertIsInstance(v1._packed_var, packed.PackedDistributedVariable)

    devices = v1._devices
    for i in range(1, len(devices)):
      with distribute_lib.ReplicaContext(distribution, i):
        v1.assign(i)
    val = v1._get()
    self.assertIsInstance(val, packed.PackedVarAndDevice)
    self.assertEqual(val.device, devices[0])
    self.assertEqual(self.evaluate(val.read_value()), 0)
    for i in range(0, len(devices)):
      with distribute_lib.ReplicaContext(distribution, i):
        val = v1._get()
        self.assertIsInstance(val, packed.PackedVarAndDevice)
        self.assertEqual(val.device, devices[i])
        self.assertEqual(self.evaluate(val.read_value()), i)

  def testIgnorePackedVariableInSaveContext(self, distribution):
    distribution._enable_packed_variable_in_eager_mode = True
    with distribution.scope():
      v = variables_lib.Variable(0)
      self.assertIsInstance(
          v._packed_variable, packed.PackedDistributedVariable)

    options = save_options.SaveOptions()
    with save_context.save_context(options):
      self.assertIsNone(v._packed_variable)


class MirroredVariableTest(test.TestCase, parameterized.TestCase):
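  """Tests basic properties of MirroredVariable."""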

  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  @test_util.run_in_graph_and_eager_modes(config=config)
  def testProperties(self):
    if context.num_gpus() < 1 and context.executing_eagerly():
      self.skipTest("A GPU is not available for this test in eager mode.")

    mirrored = _make_mirrored()
    v = mirrored.values[0]
    self.assertEqual(v.name, mirrored.name)
    self.assertEqual(v.dtype, mirrored.dtype)
    self.assertEqual(v.shape, mirrored.shape)

  @test_util.run_in_graph_and_eager_modes(config=config)
  def testVariableOnAnotherDevice(self):
    v = variable_scope.get_variable(
        name="v", initializer=[1.], use_resource=True)
    mirrored = values_lib.MirroredVariable(
        None, (v,), variable_scope.VariableAggregation.MEAN)

    self.assertEqual(v.name, mirrored.name)
    self.assertEqual(v.dtype, mirrored.dtype)
    self.assertEqual(v.shape, mirrored.shape)


class MirroredVariableSaveRestoreTest(test.TestCase, parameterized.TestCase):
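  """Tests saving and restoring MirroredVariables via saver_lib.Saver."""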

  def _assign_mirrored(self, v, new):
    for var, n in zip(v.values, new):
      self.evaluate(var.assign(n))

  def _save_return_saver(self, sess, var):
    saver = saver_lib.Saver(var_list=[var])
    test_dir = self.get_temp_dir()
    prefix = os.path.join(test_dir, "ckpt")
    return saver.save(sess, prefix), saver

  def _save(self, sess, var):
    save_path, _ = self._save_return_saver(sess, var)
    return save_path

  def _save_mirrored(self, distribution):
    """Save variables with mirroring, returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      mirrored = _make_mirrored(distribution)

      # Overwrite the initial values.
      self._assign_mirrored(mirrored, [3., 4.])

      # Saves the current value of v[0], 3.
      save_path = self._save(sess, mirrored)

      # Change the values between save and restore.
      self._assign_mirrored(mirrored, [5., 6.])
    return save_path

  def _save_normal(self):
    """Save variables without mirroring, returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      var = variable_scope.get_variable(
          name="v", initializer=1., use_resource=True)

      # Overwrite the initial value.
      self.evaluate(var.assign(3.))

      # Saves the current value of var, 3.
      save_path = self._save(sess, var)

      # Change the values between save and restore.
      self.evaluate(var.assign(5.))
    return save_path

  def _restore_normal(self, save_path):
    """Restore to variables without mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      var = variable_scope.get_variable(
          name="v", initializer=7., use_resource=True)

      # Overwrite the initial value.
      self.evaluate(var.assign(8.))

      # Restores the saved value of 3. to `var`.
      saver = saver_lib.Saver(var_list=[var])
      saver.restore(sess, save_path)
      self.assertEqual(3., self.evaluate(var))

  def _restore_mirrored(self, save_path, distribution):
    """Restore to variables with mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      mirrored = _make_mirrored(distribution)
      v = mirrored.values

      # Overwrite the initial values.
      self._assign_mirrored(mirrored, [7., 8.])

      # Restores the saved value of 3. to both variables.
      saver = saver_lib.Saver(var_list=[mirrored])
      saver.restore(sess, save_path)
      self.assertEqual([3., 3.], self.evaluate([v[0], v[1]]))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveAndRestoreMirroredOneGraph(self, distribution):
    with self.cached_session() as sess:
      mirrored = _make_mirrored(distribution)
      v = mirrored.values

      # Overwrite the initial values.
      self._assign_mirrored(mirrored, [3., 4.])

      # Saves the current value of v[0], 3.
      save_path, saver = self._save_return_saver(sess, mirrored)

      # Change the values between save and restore.
      self._assign_mirrored(mirrored, [5., 6.])

      # Restores the saved value of 3. to both variables.
      saver.restore(sess, save_path)
      self.assertEqual([3., 3.], self.evaluate([v[0], v[1]]))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveMirroredRestoreMirrored(self, distribution):
    if context.num_gpus() < 1 and context.executing_eagerly():
      # Graph mode can work without a GPU because the Placer "moves" the
      # variable to a CPU. In other words, if there is no GPU available but
      # the user requested to create a variable on GPU, the Placer will
      # ignore the request and assign the VarHandleOp to a CPU. This requires
      # soft_placement, which is on by default.
      self.skipTest("A GPU is not available for this test in eager mode.")

    save_path = self._save_mirrored(distribution)
    self._restore_mirrored(save_path, distribution)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveMirroredRestoreNormal(self, distribution):
    if context.num_gpus() < 1 and context.executing_eagerly():
      # Graph mode can work without a GPU because the Placer "moves" the
      # variable to a CPU. In other words, if there is no GPU available but
      # the user requested to create a variable on GPU, the Placer will
      # ignore the request and assign the VarHandleOp to a CPU. This requires
      # soft_placement, which is on by default.
      self.skipTest("A GPU is not available for this test in eager mode.")

    save_path = self._save_mirrored(distribution)
    self._restore_normal(save_path)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveNormalRestoreMirrored(self, distribution):
    if context.num_gpus() < 1 and context.executing_eagerly():
      # Graph mode can work without a GPU because the Placer "moves" the
      # variable to a CPU. In other words, if there is no GPU available but
      # the user requested to create a variable on GPU, the Placer will
      # ignore the request and assign the VarHandleOp to a CPU. This requires
      # soft_placement, which is on by default.
      self.skipTest("A GPU is not available for this test in eager mode.")

    save_path = self._save_normal()
    self._restore_mirrored(save_path, distribution)


_TPU_STRATEGIES = (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV1)


def _make_replica_local(method, strategy=None):
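  """Builds per-device variables wrapped in a SyncOnReadVariable.

  `method` is the VariableAggregation to use; TPU strategies get a
  TPUSyncOnReadVariable.
  """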
  if strategy is None:
    devices = ("/device:GPU:0", "/device:CPU:0")
  else:
    devices = strategy.extended.worker_devices

  v = []
  for d, n, init in zip(devices, ["v", "v/replica"], [1., 2.]):
    with ops.device(d):
      v.append(variable_scope.get_variable(
          name=n, initializer=init, use_resource=True))

  if (strategy is not None) and isinstance(strategy, _TPU_STRATEGIES):
    var_cls = tpu_values.TPUSyncOnReadVariable
  else:
    var_cls = values_lib.SyncOnReadVariable
  replica_local = var_cls(strategy, v, method)
  return v, replica_local


class SyncOnReadVariableTest(test.TestCase, parameterized.TestCase):
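  """Tests SyncOnReadVariable conversion, reads, and save/restore."""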

  def _assign_replica_local(self, v, new):
    for var, n in zip(v, new):
      with ops.device(var.device):
        self.evaluate(var.assign(n))

  def _save_return_saver(self, sess, var):
    saver = saver_lib.Saver(var_list=[var])
    test_dir = self.get_temp_dir()
    prefix = os.path.join(test_dir, "ckpt")
    return saver.save(sess, prefix), saver

  def _save(self, sess, var):
    save_path, _ = self._save_return_saver(sess, var)
    return save_path

  config = config_pb2.ConfigProto()
  config.allow_soft_placement = True

  @test_util.run_in_graph_and_eager_modes(config=config)
  def testProperties(self):
    if context.num_gpus() < 1 and context.executing_eagerly():
      self.skipTest("A GPU is not available for this test in eager mode.")
    v, replica_local = _make_replica_local(
        variable_scope.VariableAggregation.SUM)

    self.assertEqual(v[0].constraint, replica_local.constraint)
    self.assertEqual(v[0].name, replica_local.name)
    self.assertEqual(v[0].dtype, replica_local.dtype)
    self.assertEqual(v[0].shape, replica_local.shape)
    self.assertEqual(variable_scope.VariableAggregation.SUM,
                     replica_local.aggregation)

  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu
          ],
          mode=["eager"]))
  def testCanPassToDefFun(self, distribution):

    @def_function.function
    def add1(x):
      return x + 1.

    with distribution.scope():
      v = variables_lib.Variable(
          1.,
          aggregation=variables_lib.VariableAggregation.MEAN,
          synchronization=variables_lib.VariableSynchronization.ON_READ)

    self.assertEqual(2., self.evaluate(add1(v)))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testTensorConversion(self, distribution):
    with context.graph_mode():
      _, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.SUM, distribution)
      converted = ops.convert_to_tensor(replica_local, as_ref=False)
      self.assertIsInstance(converted, ops.Tensor)
      self.assertEqual(converted.dtype, replica_local.dtype)

      converted = ops.convert_to_tensor(replica_local, as_ref=True)
      # Resource variables are converted to tensors as well when as_ref is
      # True.
      self.assertIsInstance(converted, ops.Tensor)
      self.assertEqual(converted.dtype, replica_local.dtype)

1241  @combinations.generate(combinations.combine(
1242      distribution=[
1243          strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1244          strategy_combinations.mirrored_strategy_with_two_gpus_no_merge_call,
1245          strategy_combinations.tpu_strategy,
1246          strategy_combinations.tpu_strategy_packed_var,
1247      ], mode=["eager"]))
1248  def testValueInCrossReplicaContext(self, distribution):
1249    value_list, replica_local = _make_replica_local(
1250        variable_scope.VariableAggregation.ONLY_FIRST_REPLICA, distribution)
1251
1252    self.assertIsInstance(replica_local.value(), ops.Tensor)
1253    self.assertEqual(self.evaluate(replica_local.value()),
1254                     self.evaluate(value_list[0].value()))
1255
1256  @combinations.generate(
1257      combinations.combine(
1258          distribution=[
1259              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
1260              strategy_combinations.tpu_strategy_packed_var,
1261          ],
1262          mode=["eager"]))
1263  def testValueInDefaultReplicaContext(self, distribution):
1264    with distribution.scope():
1265      v1 = variables_lib.Variable(
1266          0.0,
1267          aggregation=variables_lib.VariableAggregation.SUM,
1268          synchronization=variables_lib.VariableSynchronization.ON_READ)
1269      v2 = variables_lib.Variable(
1270          0.0,
1271          aggregation=variables_lib.VariableAggregation.SUM,
1272          synchronization=variables_lib.VariableSynchronization.ON_READ)
1273
1274    @def_function.function
1275    def replica_fn():
1276      v1.assign_add(1.0)
1277      v2.assign_add(2.0)
1278
1279    distribution.run(replica_fn)
    sum_v = v1 + v2
    self.assertEqual(sum_v, 6.0)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveAndRestoreReplicaLocalSumOneGraph(self, distribution):
    with self.cached_session() as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.SUM, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [3., 4.])

      with distribution.scope():
        # Saves the current value of v[0] + v[1], 7.
        save_path, saver = self._save_return_saver(sess, replica_local)

        # Change the values between save and restore.
        self._assign_replica_local(v, [5., 6.])

        # Restores the saved value of 7., which gets divided equally
        # between the variables.
        saver.restore(sess, save_path)
        self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveAndRestoreReplicaLocalMeanOneGraph(self, distribution):
    if context.num_gpus() < 1 and context.executing_eagerly():
      self.skipTest("A GPU is not available for this test in eager mode.")

    with self.cached_session() as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.MEAN, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [3., 4.])

      with distribution.scope():
        # Saves the current value of (v[0] + v[1])/2, 3.5.
        save_path, saver = self._save_return_saver(sess, replica_local)

        # Change the values between save and restore.
        self._assign_replica_local(v, [5., 6.])

        # Restores the saved value of 3.5 to both variables.
        saver.restore(sess, save_path)
        self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))

  def _save_replica_local_mean(self, distribution):
    """Saves variables with mirroring and returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.MEAN, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [3., 4.])

      with distribution.scope():
        # Saves the current value of (v[0] + v[1])/2, 3.5.
        save_path = self._save(sess, replica_local)

        # Change the values between save and restore.
        self._assign_replica_local(v, [5., 6.])
    return save_path

  def _save_replica_local_sum(self, distribution):
    """Saves variables with mirroring and returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.SUM, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [1.5, 2.])

      with distribution.scope():
        # Saves the current value of v[0] + v[1], 3.5.
        save_path = self._save(sess, replica_local)

        # Change the values between save and restore.
        self._assign_replica_local(v, [5., 6.])
    return save_path

  def _save_normal(self):
    """Saves variables without mirroring and returns save_path."""
    with self.session(graph=ops.Graph()) as sess:
      var = variable_scope.get_variable(
          name="v", initializer=1., use_resource=True)

      # Overwrite the initial value.
      self.evaluate(var.assign(3.5))

      # Saves the current value of var, 3.5.
      save_path = self._save(sess, var)

      # Change the values between save and restore.
      self.evaluate(var.assign(5.))
    return save_path

  def _restore_normal(self, save_path):
    """Restore to variables without mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      var = variable_scope.get_variable(
          name="v", initializer=7., use_resource=True)

      # Overwrite the initial value.
      self.evaluate(var.assign(8.))

      # Restores the saved value of 3.5 to `var`.
      saver = saver_lib.Saver(var_list=[var])
      saver.restore(sess, save_path)
      self.assertEqual(3.5, self.evaluate(var))

  def _restore_replica_local_mean(self, save_path, distribution):
    """Restore to variables with mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.MEAN, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [7., 8.])

      with distribution.scope():
        # Restores the saved value of 3.5 to both variables.
        saver = saver_lib.Saver(var_list=[replica_local])
        saver.restore(sess, save_path)
        self.assertEqual([3.5, 3.5], self.evaluate([v[0], v[1]]))

  def _restore_replica_local_sum(self, save_path, distribution):
    """Restore to variables with mirroring in a fresh graph."""
    with self.session(graph=ops.Graph()) as sess:
      v, replica_local = _make_replica_local(
          variable_scope.VariableAggregation.SUM, distribution)

      # Overwrite the initial values.
      self._assign_replica_local(v, [7., 8.])

      with distribution.scope():
        # Restores the saved value of 3.5, which gets divided equally
        # between the variables.
        saver = saver_lib.Saver(var_list=[replica_local])
        saver.restore(sess, save_path)
        self.assertEqual([1.75, 1.75], self.evaluate([v[0], v[1]]))

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveReplicaLocalRestoreReplicaLocalMean(self, distribution):
    save_path = self._save_replica_local_mean(distribution)
    self._restore_replica_local_mean(save_path, distribution)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveReplicaLocalRestoreReplicaLocalSum(self, distribution):
    save_path = self._save_replica_local_sum(distribution)
    self._restore_replica_local_sum(save_path, distribution)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveReplicaLocalMeanRestoreNormal(self, distribution):
    save_path = self._save_replica_local_mean(distribution)
    self._restore_normal(save_path)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveReplicaLocalSumRestoreNormal(self, distribution):
    save_path = self._save_replica_local_sum(distribution)
    self._restore_normal(save_path)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveNormalRestoreReplicaLocalMean(self, distribution):
    save_path = self._save_normal()
    self._restore_replica_local_mean(save_path, distribution)

  @combinations.generate(mirrored_and_tpu_strategy_combinations())
  def testSaveNormalRestoreReplicaLocalSum(self, distribution):
    save_path = self._save_normal()
    self._restore_replica_local_sum(save_path, distribution)


class MirroredTest(test.TestCase):

  def testAddOp(self):
    if context.num_gpus() < 1:
      self.skipTest("A GPU is not available for this test.")
    mirrored_val = _make_mirrored_val(init_val=3.)

    self.assertEqual(self.evaluate(constant_op.constant(6.)),
                     self.evaluate(mirrored_val + mirrored_val))
    self.assertEqual(self.evaluate(constant_op.constant(4.)),
                     self.evaluate(mirrored_val + 1))
    self.assertEqual(self.evaluate(mirrored_val + 1),
                     self.evaluate(math_ops.add(mirrored_val, 1)))
    self.assertEqual(type(mirrored_val + 1),
                     type(math_ops.add(mirrored_val, 1)))


class PerReplicaTest(test.TestCase, parameterized.TestCase):

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testTypeSpec(self):
    vals = (constant_op.constant(1.),)
    per_replica = values_lib.PerReplica(vals)

    spec = per_replica._type_spec
    self.assertEqual(spec._value_specs,
                     (tensor_spec.TensorSpec([], dtypes.float32),))

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testTypeSpecRoundTrip(self):
    vals = (constant_op.constant(1.),)
    per_replica = values_lib.PerReplica(vals)

    spec = per_replica._type_spec
    tensor_list = spec._to_components(per_replica)
    reconstructed = spec._from_components(tensor_list)

    self.assertAllEqual(per_replica.values, reconstructed.values)

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testTypeSpecNest(self):
    vals = (constant_op.constant(1.), constant_op.constant([5., 6.0]),)
    per_replica = values_lib.PerReplica(vals)

    # Note: nest.map_structure exercises nest.flatten and
    # nest.pack_sequence_as.
    result = nest.map_structure(
        lambda t: t + 10, per_replica, expand_composites=True)

    self.assertLen(result.values, 2)
    self.assertAllEqual(result.values[0], 11.)
    self.assertAllEqual(result.values[1], [15., 16.0])

  @test_util.run_in_graph_and_eager_modes
  def testIsGraphTensor(self):
    per_replica = values_lib.PerReplica((constant_op.constant(1.),))
    for t in nest.flatten(per_replica, expand_composites=True):
      self.assertEqual(hasattr(t, "graph"), not context.executing_eagerly())

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testDoesNotTriggerFunctionTracing(self):
    traces = []

    @def_function.function
    def f(x):
      traces.append(None)  # Only happens on trace.
      return x

    per_replica = values_lib.PerReplica((constant_op.constant(1.),))

    # Trace once.
    f(per_replica)
    self.assertNotEmpty(traces)
    del traces[:]

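    # Rebuilding structurally identical PerReplica values from the same
    # TypeSpec should reuse the trace created above rather than triggering a
    # new one.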
    per_replica_spec = per_replica._type_spec
    for _ in range(5):
      vals = per_replica_spec._to_components(per_replica)
      vals = [v * 2 for v in vals]
      per_replica = per_replica_spec._from_components(vals)

      output = f(per_replica)
      self.assertIsInstance(output, values_lib.PerReplica)
      self.assertAllEqual(output._values, per_replica._values)
      self.assertEmpty(traces)  # Make sure we're not re-tracing `f`.

  @combinations.generate(combinations.combine(mode=["eager"]))
  def testFunctionCanReturnPerReplica(self):
    f = def_function.function(lambda x: x)
    x = values_lib.PerReplica((constant_op.constant(1.),))
    y = f(x)
    self.assertIsNot(x, y)
    nest.map_structure(self.assertAllEqual, x, y, expand_composites=True)
    self.assertEqual(x._type_spec, y._type_spec)

  @test_util.run_in_graph_and_eager_modes
  def testCondWithTensorValues(self):
    per_replica_1 = values_lib.PerReplica((constant_op.constant("a"),))
    per_replica_2 = values_lib.PerReplica((constant_op.constant(["b", "c"]),))
    condition = array_ops.placeholder_with_default(True, [])

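    # The placeholder defaults to True, so cond is expected to take the first
    # branch and return per_replica_1.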
    result = control_flow_ops.cond(
        condition, lambda: per_replica_1, lambda: per_replica_2)

    self.assertLen(result.values, 1)
    self.assertAllEqual(result.values[0], "a")

  @test_util.run_in_graph_and_eager_modes
  def testCondWithValuesConvertibleToTensor(self):
    per_replica_1 = values_lib.PerReplica(("a",))
    per_replica_2 = values_lib.PerReplica(("b",))
    condition = array_ops.placeholder_with_default(True, [])

    result = control_flow_ops.cond(
        condition, lambda: per_replica_1, lambda: per_replica_2)

    self.assertLen(result.values, 1)
    self.assertAllEqual(result.values[0], "a")

  @test_util.build_as_function_and_v1_graph
  def testCondWithValuesNotConvertibleToTensor(self):
    per_replica_1 = values_lib.PerReplica(({"a"},))
    per_replica_2 = values_lib.PerReplica(({"b", "c"},))
    condition = array_ops.placeholder(dtypes.bool, [])

    with self.assertRaisesRegex(TypeError, "Could not build a TypeSpec for"):
      control_flow_ops.cond(
          condition, lambda: per_replica_1, lambda: per_replica_2)


def _make_index_slices(values, indices, dense_shape=None):
  if dense_shape:
    dense_shape = array_ops.identity(dense_shape)
  return indexed_slices.IndexedSlices(
      array_ops.identity(values), array_ops.identity(indices), dense_shape)


if __name__ == "__main__":
  ds_test_util.main()