• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Lint as: python2, python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests to improve the consistency of tensorflow.

Basic tests show how to use the consistency test base to test against eager,
tf.function, XLA-compiled function, and saved-model modes.
"""
21
22import tensorflow as tf
23from tensorflow.python.platform import test
24from tensorflow.tools.consistency_integration_test.consistency_test_base import ConsistencyTestBase
25from tensorflow.tools.consistency_integration_test.consistency_test_base import Example
26from tensorflow.tools.consistency_integration_test.consistency_test_base import RunMode
27
28
class BasicTests(ConsistencyTestBase):
  """A few basic tests that are examples for other test writers.

  Each test calls `self._generic_test(f, examples, ...)` with a function `f`
  and a list of `Example`s. An `Example` carries:
    arg: tuple of positional arguments passed to `f`.
    out: the expected (ground-truth) output of `f(*arg)`.
    failure: the run modes in which the check is expected to fail. This is
      either a list of `RunMode`s, or a dict mapping each failing `RunMode` to
      text of the error expected in that mode.
    bugs: list of related bug IDs (e.g. 'b/180921284').

  `_generic_test` also accepts `skip_modes` (run modes not to execute) and
  `input_signature` (forwarded for the function-based modes).
  """

  def testSquare(self):
    """Test basic testing infrastructure with a simple squaring function."""

    def f(x):
      return x * x

    # Tests involving type promotions are added to:
    # //tensorflow/tools/consistency_integration_test/type_promotion_tests.py
    self._generic_test(f, [
        Example(arg=(3,), out=9., failure=[], bugs=[]),
        Example(
            arg=(tf.constant(3.),), out=tf.constant(9.), failure=[], bugs=[]),
    ])

  def testObjectInput(self):
    """Test taking a Python object as input.

    Reading an attribute of a plain Python object is expected to work in every
    mode except `RunMode.SAVED` (SavedModel).
    """

    class A:

      def __init__(self):
        self.value = 3.0

    def f(x):
      return x.value

    self._generic_test(
        f, [Example(arg=(A(),), out=3.0, failure=[RunMode.SAVED], bugs=[])])

  def testObjectOutput(self):
    """Test returning a Python object. Doesn't and shouldn't work.

    Returning a non-Tensor Python object is expected to fail in every mode.
    """

    class A:

      def __init__(self, x):
        self.value = x

    def f(x):
      return A(x)

    self._generic_test(f, [
        Example(
            arg=(3.,),
            out=3.0,
            # It should fail in all modes.
            failure=[RunMode.RAW, RunMode.XLA, RunMode.FUNCTION, RunMode.SAVED],
            bugs=[])
    ])

  def testNotEqualOutput(self):
    """Tests that an error is thrown if the outputs are not equal.

    This test case is meant to test the consistency test infrastructure that the
    output of executing `f()` matches the groundtruth we provide as the `out`
    param in `_generic_test()`.
    """
    mock_func = test.mock.MagicMock(name='method')
    mock_func.return_value = 0  # This differs from the `expected` value below.
    # `_generic_test` is given a function with a docstring, so give the mock
    # one as well.
    mock_func.__doc__ = 'Tested with a mock function.'

    # The mock's output never matches `expected`, so every mode must fail.
    failure_modes = [RunMode.RAW, RunMode.FUNCTION, RunMode.XLA, RunMode.SAVED]
    input_args = [3, 3.2, tf.constant(3.)]
    expected = 1  # Arbitrary value, chosen only to differ from the mock output.

    for input_arg in input_args:
      self._generic_test(mock_func, [
          Example(
              arg=(input_arg,), out=expected, failure=failure_modes, bugs=[])
      ])

  def testSkipModes(self):
    """Tests `skip_modes` option available with `_generic_test`."""

    class A:

      def __init__(self, x):
        self.value = x

    def f(x):
      return A(x)

    self._generic_test(
        f, [Example(arg=(3.,), out=3.0, failure=[], bugs=[])],
        # Skip all tests as the test will fail in all modes.
        skip_modes=[RunMode.RAW, RunMode.XLA, RunMode.FUNCTION, RunMode.SAVED])

  def testTensorArrayBasic(self):
    """Tests `_generic_test` with a `tf.TensorArray` as input to tf.function."""

    def f(x):
      return x.stack()

    ta = tf.TensorArray(dtype=tf.int32, dynamic_size=True, size=0)
    ta = ta.write(0, tf.constant([1, 2, 3]))
    ta = ta.write(1, tf.constant([4, 5, 6]))

    self._generic_test(
        f,
        [
            Example(
                arg=(ta,),
                out=tf.constant([[1, 2, 3], [4, 5, 6]]),
                failure=[RunMode.SAVED],  # TODO(b/187250924): Investigate.
                bugs=['b/180921284'])
        ])

  def testFailureParamAsDict(self):
    """Tests passing in a `dict` for `failure` param to `_generic_test`.

    Each key is a run mode that is expected to fail. Each value is the text of
    the error expected in that mode.
    """

    def f(ta):
      return ta.stack()

    ta = tf.TensorArray(dtype=tf.float32, dynamic_size=True, size=0)
    ta = ta.write(0, tf.constant([1.0, 2.0]))
    ta = ta.write(1, tf.constant([3.0, 4.0]))

    out_t = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    input_signature = [tf.TensorArraySpec(element_shape=None,
                                          dtype=tf.float32,
                                          dynamic_size=True)]

    self._generic_test(
        f,
        [
            Example(
                arg=(ta,),
                out=out_t,
                failure={
                    RunMode.FUNCTION:
                        'If shallow structure is a sequence, input must also '
                        'be a sequence',
                    RunMode.XLA:
                        'If shallow structure is a sequence, input must also '
                        'be a sequence',
                    RunMode.SAVED:
                        'Found zero restored functions for caller function',
                },
                bugs=['b/162452468'])
        ],
        input_signature=input_signature,
        skip_modes=[])
177
178
if __name__ == '__main__':
  # When executed as a script, hand off to TensorFlow's platform test entry
  # point to run the test cases defined in this module.
  test.main()
181