• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the 'License');
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an 'AS IS' BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for test utilities."""
16
17from absl.testing import parameterized
18
19from tensorflow.python.distribute import combinations
20from tensorflow.python.distribute import strategy_combinations
21from tensorflow.python.distribute import test_util
22from tensorflow.python.eager import context
23from tensorflow.python.eager import def_function
24from tensorflow.python.eager import test
25from tensorflow.python.framework import config
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.ops import array_ops
29
30
@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
        ] + strategy_combinations.strategies_minus_tpu,
        mode=['eager', 'graph']))
class GatherTest(test.TestCase, parameterized.TestCase):
  """Tests for `test_util.gather` across strategies and execution modes.

  Each test runs a `tf.function` through `strategy.run`, passes the per-replica
  outputs to `test_util.gather`, and expects one entry per replica in sync.
  """

  def testOne(self, strategy):
    """A scalar produced on every replica gathers into one value per replica."""

    @def_function.function
    def replica_fn():
      return array_ops.ones((), dtypes.float32)

    gathered = test_util.gather(strategy, strategy.run(replica_fn))
    expected = [1.] * strategy.num_replicas_in_sync
    self.assertAllEqual(self.evaluate(gathered), expected)

  def testNest(self, strategy):
    """Gathering keeps the nested dict/list structure of the replica outputs."""

    @def_function.function
    def replica_fn():
      return {
          'foo': array_ops.ones((), dtypes.float32),
          'bar': [
              array_ops.zeros((), dtypes.float32),
              array_ops.ones((), dtypes.float32),
          ],
      }

    gathered = test_util.gather(strategy, strategy.run(replica_fn))
    replicas = strategy.num_replicas_in_sync
    # (gathered leaf, value each replica produced for that leaf), evaluated in
    # the same foo -> bar[0] -> bar[1] order as the structure above.
    leaf_checks = (
        (gathered['foo'], 1.),
        (gathered['bar'][0], 0.),
        (gathered['bar'][1], 1.),
    )
    for leaf, per_replica_value in leaf_checks:
      self.assertAllEqual(self.evaluate(leaf), [per_replica_value] * replicas)
71
72
class LogicalDevicesTest(test.TestCase):
  """Tests for `test_util.set_logical_devices_to_at_least`."""

  def testLogicalCPUs(self):
    """Requesting at least 3 CPUs yields 3 logical configs on the first CPU."""
    # Start from a fresh eager context. Presumably logical device configuration
    # must be applied before the runtime initializes devices — confirm.
    context._reset_context()
    test_util.set_logical_devices_to_at_least('CPU', 3)

    first_cpu = config.list_physical_devices('CPU')[0]
    logical_configs = config.get_logical_device_configuration(first_cpu)
    self.assertLen(logical_configs, 3)
80
81
class AssertSequentailExecutionTest(test.TestCase):
  """Tests for `test_util.assert_sequential_execution`.

  The class name misspells "Sequential"; it is left unchanged so anything that
  selects this test by name keeps working.
  """

  def test1(self):
    """Dependency chains pass the check; groups with unrelated ops fail it."""

    @def_function.function
    def build():
      # Graph built below:
      #   a -> b (= a + 1), and e has a control dependency on b;
      #   a and c both feed d;  c and e both feed f.
      # a and c are each created independently from a Python float, so
      # neither reaches the other.
      a_out = array_ops.identity(1., name='a')
      b_out = a_out + 1
      c_out = array_ops.identity(2., name='c')
      d_out = array_ops.identity(a_out + c_out, name='d')
      with ops.control_dependencies([b_out]):
        e_out = array_ops.identity(3., name='e')
      f_out = array_ops.identity(c_out + e_out, name='f')
      return d_out, f_out

    graph = build.get_concrete_function().graph
    order = test_util.topological_sort_operations(graph.get_operations())
    op_a, op_c, op_d, op_e, op_f = [
        graph.get_operation_by_name(name) for name in ('a', 'c', 'd', 'e', 'f')
    ]

    # Groups lying on a dependency chain (a -> d; a -> b -> e -> f): the check
    # is expected to pass.
    for chained in ([op_a, op_d], [op_e, op_a, op_f]):
      test_util.assert_sequential_execution(order, chained)

    # Each group contains both a and c, which have no path between them, so
    # no sequential order exists and the check is expected to raise.
    for unchained in ([op_a, op_c], [op_f, op_a, op_c],
                      [op_d, op_e, op_a, op_c]):
      with self.assertRaises(AssertionError):
        test_util.assert_sequential_execution(order, unchained)
112
113
if __name__ == '__main__':
  # Use the entry point from tensorflow.python.distribute.test_util rather
  # than test.main(); presumably it performs distribute-specific setup needed
  # by the multi-worker strategy combinations above — confirm.
  test_util.main()
116