• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests generating test combinations."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from collections import OrderedDict
22
23from absl.testing import parameterized
24
25from tensorflow.python.framework import test_combinations as combinations
26from tensorflow.python.eager import test
27
28
class TestingCombinationsTest(test.TestCase):
  """Unit tests for building combinations with `combine`, `+` and `times`."""

  def test_combine(self):
    """Every value of `a` is paired with every value of `b`."""
    expected = [
        {"a": 1, "b": 2},
        {"a": 1, "b": 3},
        {"a": 2, "b": 2},
        {"a": 2, "b": 3},
    ]
    actual = combinations.combine(a=[1, 2], b=[2, 3])
    self.assertEqual(expected, actual)

  def test_arguments_sorted(self):
    """Arguments expand in sorted-name order, whatever the keyword order."""
    actual = combinations.combine(ab=[2, 3], aa=[1, 2])
    expected = [
        OrderedDict([("aa", 1), ("ab", 2)]),
        OrderedDict([("aa", 1), ("ab", 3)]),
        OrderedDict([("aa", 2), ("ab", 2)]),
        OrderedDict([("aa", 2), ("ab", 3)]),
    ]
    self.assertEqual(expected, actual)

  def test_combine_single_parameter(self):
    """A scalar argument value behaves like a one-element list."""
    actual = combinations.combine(a=[1, 2], b=2)
    self.assertEqual([{"a": 1, "b": 2}, {"a": 2, "b": 2}], actual)

  def test_add(self):
    """Adding two combination lists concatenates them in order."""
    first = combinations.combine(a=[1, 2])
    second = combinations.combine(b=[2, 3])
    self.assertEqual([{"a": 1}, {"a": 2}, {"b": 2}, {"b": 3}], first + second)

  def test_times(self):
    """`times` pairs each entry of the first list with each of the second."""
    graph_runs = combinations.combine(
        mode=["graph"], loss=["callable", "tensor"])
    eager_runs = combinations.combine(mode=["eager"], loss=["callable"])
    distributions = combinations.combine(distribution=["d1", "d2"])
    crossed = combinations.times(distributions, graph_runs + eager_runs)
    expected = [
        OrderedDict([("distribution", "d1"), ("loss", "callable"),
                     ("mode", "graph")]),
        OrderedDict([("distribution", "d1"), ("loss", "tensor"),
                     ("mode", "graph")]),
        OrderedDict([("distribution", "d1"), ("loss", "callable"),
                     ("mode", "eager")]),
        OrderedDict([("distribution", "d2"), ("loss", "callable"),
                     ("mode", "graph")]),
        OrderedDict([("distribution", "d2"), ("loss", "tensor"),
                     ("mode", "graph")]),
        OrderedDict([("distribution", "d2"), ("loss", "callable"),
                     ("mode", "eager")]),
    ]
    self.assertEqual(expected, crossed)

  def test_times_variable_arguments(self):
    """`times` accepts more than two lists and matches a flat `combine`."""
    modes = combinations.combine(mode=["graph", "eager"])
    optimizers = combinations.combine(optimizer=["adam", "gd"])
    distributions = combinations.combine(distribution=["d1", "d2"])
    crossed = combinations.times(distributions, modes, optimizers)
    expected = [
        OrderedDict([("distribution", "d1"), ("mode", "graph"),
                     ("optimizer", "adam")]),
        OrderedDict([("distribution", "d1"), ("mode", "graph"),
                     ("optimizer", "gd")]),
        OrderedDict([("distribution", "d1"), ("mode", "eager"),
                     ("optimizer", "adam")]),
        OrderedDict([("distribution", "d1"), ("mode", "eager"),
                     ("optimizer", "gd")]),
        OrderedDict([("distribution", "d2"), ("mode", "graph"),
                     ("optimizer", "adam")]),
        OrderedDict([("distribution", "d2"), ("mode", "graph"),
                     ("optimizer", "gd")]),
        OrderedDict([("distribution", "d2"), ("mode", "eager"),
                     ("optimizer", "adam")]),
        OrderedDict([("distribution", "d2"), ("mode", "eager"),
                     ("optimizer", "gd")]),
    ]
    self.assertEqual(expected, crossed)

    # Crossing single-key lists must agree with one combine() over all keys.
    flat = combinations.combine(
        mode=["graph", "eager"],
        optimizer=["adam", "gd"],
        distribution=["d1", "d2"])
    self.assertEqual(flat, crossed)

  def test_overlapping_keys(self):
    """`times` refuses to cross combinations that share a key."""
    graph_runs = combinations.combine(
        mode=["graph"], loss=["callable", "tensor"])
    eager_runs = combinations.combine(mode=["eager"], loss=["callable"])
    with self.assertRaisesRegex(ValueError, ".*Keys.+overlap.+"):
      combinations.times(graph_runs, eager_runs)
131
@combinations.generate(combinations.combine(a=[1, 0], b=[2, 3], c=[1]))
class CombineTheTestSuite(parameterized.TestCase):
  """Checks that a class-level `generate` parameterizes each test method.

  The combinations are a in {1, 0}, b in {2, 3} and c in {1}, so every
  generated case satisfies 3 <= a + b + c <= 5. The non-test members below
  exist to check that decorating the class does not turn them into tests.
  """

  def test_add_things(self, a, b, c):
    self.assertLessEqual(3, a + b + c)
    self.assertLessEqual(a + b + c, 5)

  def test_add_things_one_more(self, a, b, c):
    # Second test method: verifies every method is expanded, not just one.
    self.assertLessEqual(3, a + b + c)
    self.assertLessEqual(a + b + c, 5)

  def not_a_test(self, a=0, b=0, c=0):
    # Name lacks the "test" prefix, so it must never run; fail loudly if it
    # does.
    del a, b, c
    self.fail()

  def _test_but_private(self, a=0, b=0, c=0):
    # Private name starting with "_test"; must likewise never run.
    del a, b, c
    self.fail()

  # Check that nothing funny happens to a non-callable that starts with "test".
  test_member = 0
153
154
if __name__ == "__main__":
  # Run the tests in this module via the TensorFlow `test` runner imported
  # above.
  test.main()
157