# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for test utilities."""

from absl.testing import parameterized

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import test_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops


@combinations.generate(
    combinations.combine(
        strategy=[
            strategy_combinations.multi_worker_mirrored_2x1_cpu,
            strategy_combinations.multi_worker_mirrored_2x1_gpu,
            strategy_combinations.multi_worker_mirrored_2x2_gpu,
        ] + strategy_combinations.strategies_minus_tpu,
        mode=['eager', 'graph']))
class GatherTest(test.TestCase, parameterized.TestCase):
  """Checks `test_util.gather` on the per-replica results of `strategy.run`."""

  def testOne(self, strategy):
    # Every replica returns the scalar 1.0; the gathered value is compared
    # against a list holding one 1.0 per replica in sync.

    @def_function.function
    def replica_fn():
      return array_ops.ones((), dtypes.float32)

    gathered = test_util.gather(strategy, strategy.run(replica_fn))
    expected = [1.] * strategy.num_replicas_in_sync
    self.assertAllEqual(self.evaluate(gathered), expected)

  def testNest(self, strategy):
    # The replica function returns a nested dict/list; each leaf of the
    # gathered structure is checked against one value per replica in sync.

    @def_function.function
    def replica_fn():
      return {
          'foo': array_ops.ones((), dtypes.float32),
          'bar': [
              array_ops.zeros((), dtypes.float32),
              array_ops.ones((), dtypes.float32),
          ],
      }

    gathered = test_util.gather(strategy, strategy.run(replica_fn))
    num_replicas = strategy.num_replicas_in_sync
    # (gathered leaf, value every replica produced for that leaf)
    leaf_checks = (
        (gathered['foo'], 1.),
        (gathered['bar'][0], 0.),
        (gathered['bar'][1], 1.),
    )
    for leaf, fill_value in leaf_checks:
      self.assertAllEqual(self.evaluate(leaf), [fill_value] * num_replicas)


class LogicalDevicesTest(test.TestCase):
  """Checks `test_util.set_logical_devices_to_at_least`."""

  def testLogicalCPUs(self):
    # Reset the eager context first: logical device configuration generally
    # cannot be changed once the runtime has been initialized.
    context._reset_context()
    test_util.set_logical_devices_to_at_least('CPU', 3)
    first_cpu = config.list_physical_devices('CPU')[0]
    logical_configs = config.get_logical_device_configuration(first_cpu)
    self.assertLen(logical_configs, 3)


# NOTE: the misspelling "Sequentail" in the class name is intentionally kept so
# that any existing test filters referring to it keep matching.
class AssertSequentailExecutionTest(test.TestCase):
  """Checks `test_util.assert_sequential_execution` on a small graph."""

  def test1(self):
    # Graph built below (names are the op names):
    #   d = identity(a + c)
    #   e = identity(3.)   with a control dependency on b = a + 1
    #   f = identity(c + e)
    # `a` and `c` are both constants with no dependency path between them.

    @def_function.function
    def fn():
      op_a = array_ops.identity(1., name='a')
      op_b = op_a + 1
      op_c = array_ops.identity(2., name='c')
      op_d = array_ops.identity(op_a + op_c, name='d')
      with ops.control_dependencies([op_b]):
        op_e = array_ops.identity(3., name='e')
      op_f = array_ops.identity(op_c + op_e, name='f')
      return op_d, op_f

    graph = fn.get_concrete_function().graph
    order = test_util.topological_sort_operations(graph.get_operations())
    a, c, d, e, f = (
        graph.get_operation_by_name(op_name)
        for op_name in ('a', 'c', 'd', 'e', 'f'))

    # These sequences are expected to pass without raising.
    test_util.assert_sequential_execution(order, [a, d])
    test_util.assert_sequential_execution(order, [e, a, f])

    # Each of these contains both `a` and `c` and is expected to raise.
    for op_sequence in ([a, c], [f, a, c], [d, e, a, c]):
      with self.assertRaises(AssertionError):
        test_util.assert_sequential_execution(order, op_sequence)


if __name__ == '__main__':
  test_util.main()