# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for generating test combinations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import OrderedDict

from absl.testing import parameterized

from tensorflow.python.eager import test
from tensorflow.python.framework import test_combinations as combinations


class TestingCombinationsTest(test.TestCase):
  """Tests for `combinations.combine`, `combinations.times` and `+`."""

  def test_combine(self):
    """`combine` yields one dict per element of the values' product."""
    self.assertEqual([{
        "a": 1,
        "b": 2
    }, {
        "a": 1,
        "b": 3
    }, {
        "a": 2,
        "b": 2
    }, {
        "a": 2,
        "b": 3
    }], combinations.combine(a=[1, 2], b=[2, 3]))

  def test_arguments_sorted(self):
    """Keys come out sorted by name regardless of keyword-argument order."""
    # OrderedDict == OrderedDict compares key order too, so these expected
    # values pin "aa" before "ab" even though `ab` is passed first.
    self.assertEqual([
        OrderedDict([("aa", 1), ("ab", 2)]),
        OrderedDict([("aa", 1), ("ab", 3)]),
        OrderedDict([("aa", 2), ("ab", 2)]),
        OrderedDict([("aa", 2), ("ab", 3)])
    ], combinations.combine(ab=[2, 3], aa=[1, 2]))

  def test_combine_single_parameter(self):
    """A non-list value is used as the single value for that key."""
    self.assertEqual([{
        "a": 1,
        "b": 2
    }, {
        "a": 2,
        "b": 2
    }], combinations.combine(a=[1, 2], b=2))

  def test_add(self):
    """`+` on two combination lists concatenates them in order."""
    self.assertEqual(
        [{
            "a": 1
        }, {
            "a": 2
        }, {
            "b": 2
        }, {
            "b": 3
        }],
        combinations.combine(a=[1, 2]) + combinations.combine(b=[2, 3]))

  def test_times(self):
    """`times` merges every left combination with every right one."""
    c1 = combinations.combine(mode=["graph"], loss=["callable", "tensor"])
    c2 = combinations.combine(mode=["eager"], loss=["callable"])
    c3 = combinations.combine(distribution=["d1", "d2"])
    c4 = combinations.times(c3, c1 + c2)
    self.assertEqual([
        OrderedDict([("distribution", "d1"), ("loss", "callable"),
                     ("mode", "graph")]),
        OrderedDict([("distribution", "d1"), ("loss", "tensor"),
                     ("mode", "graph")]),
        OrderedDict([("distribution", "d1"), ("loss", "callable"),
                     ("mode", "eager")]),
        OrderedDict([("distribution", "d2"), ("loss", "callable"),
                     ("mode", "graph")]),
        OrderedDict([("distribution", "d2"), ("loss", "tensor"),
                     ("mode", "graph")]),
        OrderedDict([("distribution", "d2"), ("loss", "callable"),
                     ("mode", "eager")])
    ], c4)

  def test_times_variable_arguments(self):
    """`times` accepts more than two lists and matches an equivalent combine."""
    c1 = combinations.combine(mode=["graph", "eager"])
    c2 = combinations.combine(optimizer=["adam", "gd"])
    c3 = combinations.combine(distribution=["d1", "d2"])
    c4 = combinations.times(c3, c1, c2)
    self.assertEqual([
        OrderedDict([("distribution", "d1"), ("mode", "graph"),
                     ("optimizer", "adam")]),
        OrderedDict([("distribution", "d1"), ("mode", "graph"),
                     ("optimizer", "gd")]),
        OrderedDict([("distribution", "d1"), ("mode", "eager"),
                     ("optimizer", "adam")]),
        OrderedDict([("distribution", "d1"), ("mode", "eager"),
                     ("optimizer", "gd")]),
        OrderedDict([("distribution", "d2"), ("mode", "graph"),
                     ("optimizer", "adam")]),
        OrderedDict([("distribution", "d2"), ("mode", "graph"),
                     ("optimizer", "gd")]),
        OrderedDict([("distribution", "d2"), ("mode", "eager"),
                     ("optimizer", "adam")]),
        OrderedDict([("distribution", "d2"), ("mode", "eager"),
                     ("optimizer", "gd")])
    ], c4)
    # The product of single-key lists must equal one combine() over all keys.
    self.assertEqual(
        combinations.combine(
            mode=["graph", "eager"],
            optimizer=["adam", "gd"],
            distribution=["d1", "d2"]), c4)

  def test_overlapping_keys(self):
    """`times` raises ValueError when the lists share a key."""
    c1 = combinations.combine(mode=["graph"], loss=["callable", "tensor"])
    c2 = combinations.combine(mode=["eager"], loss=["callable"])
    with self.assertRaisesRegex(ValueError, ".*Keys.+overlap.+"):
      _ = combinations.times(c1, c2)


@combinations.generate(combinations.combine(a=[1, 0], b=[2, 3], c=[1]))
class CombineTheTestSuite(parameterized.TestCase):
  """Applies `combinations.generate` to a whole TestCase class.

  Each `test_*` method receives `a`, `b` and `c` from the combinations above.
  Every combination sums to a value in [3, 5].
  """

  def test_add_things(self, a, b, c):
    self.assertLessEqual(3, a + b + c)
    self.assertLessEqual(a + b + c, 5)

  def test_add_things_one_more(self, a, b, c):
    self.assertLessEqual(3, a + b + c)
    self.assertLessEqual(a + b + c, 5)

  def not_a_test(self, a=0, b=0, c=0):
    """Not a test method; fails if it is ever invoked as one."""
    del a, b, c
    self.fail()

  def _test_but_private(self, a=0, b=0, c=0):
    """Private helper-like method; fails if it is ever invoked as a test."""
    del a, b, c
    self.fail()

  # Check that nothing funny happens to a non-callable that starts with "test".
  test_member = 0


if __name__ == "__main__":
  test.main()