# Lint as: python2, python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Basic tests to improve the consistency of TensorFlow.

These tests double as examples of how to use the consistency test
infrastructure (`_generic_test` from `ConsistencyTestBase`) to check a function
against the raw (eager), tf.function, XLA-compiled tf.function and saved-model
run modes.
"""

import tensorflow as tf
from tensorflow.python.platform import test
from tensorflow.tools.consistency_integration_test.consistency_test_base import ConsistencyTestBase
from tensorflow.tools.consistency_integration_test.consistency_test_base import Example
from tensorflow.tools.consistency_integration_test.consistency_test_base import RunMode


class BasicTests(ConsistencyTestBase):
  """A few basic tests that are examples for other test writers."""

  def testSquare(self):
    """Tests the basic infrastructure with Python and tensor inputs."""

    def f(x):
      return x * x

    # Tests involving type promotions are added to:
    # //tensorflow/tools/consistency_integration_test/type_promotion_tests.py
    self._generic_test(f, [
        Example(arg=(3,), out=9., failure=[], bugs=[]),
        Example(
            arg=(tf.constant(3.),), out=tf.constant(9.), failure=[], bugs=[]),
    ])

  def testObjectInput(self):
    """Tests taking a Python object as input.

    Only saved-model mode (`RunMode.SAVED`) is listed as an expected failure;
    the object input should work in tf.function.
    """

    class A:

      def __init__(self):
        self.value = 3.0

    def f(x):
      return x.value

    self._generic_test(
        f, [Example(arg=(A(),), out=3.0, failure=[RunMode.SAVED], bugs=[])])

  def testObjectOutput(self):
    """Tests returning a Python object. Doesn't and shouldn't work."""

    class A:

      def __init__(self, x):
        self.value = x

    def f(x):
      return A(x)

    self._generic_test(f, [
        Example(
            arg=(3.,),
            out=3.0,
            # It should fail in all modes.
            failure=[RunMode.RAW, RunMode.XLA, RunMode.FUNCTION, RunMode.SAVED],
            bugs=[])
    ])

  def testNotEqualOutput(self):
    """Tests that an error is thrown if the outputs are not equal.

    This test case checks the consistency test infrastructure itself: the
    output of executing `f()` must match the ground truth provided as the `out`
    param of `_generic_test()`, otherwise every mode is reported as a failure.
    """
    mock_func = test.mock.MagicMock(name='method')
    mock_func.return_value = 0  # This differs from the `expected` value below.
    mock_func.__doc__ = 'Tested with a mock function.'

    failure_modes = [RunMode.RAW, RunMode.FUNCTION, RunMode.XLA, RunMode.SAVED]
    input_args = [3, 3.2, tf.constant(3.)]
    expected = 1  # Randomly picked value just for testing purposes.

    for input_arg in input_args:
      self._generic_test(mock_func, [
          Example(
              arg=(input_arg,), out=expected, failure=failure_modes, bugs=[])
      ])

  def testSkipModes(self):
    """Tests `skip_modes` option available with `_generic_test`."""

    class A:

      def __init__(self, x):
        self.value = x

    def f(x):
      return A(x)

    self._generic_test(
        f, [Example(arg=(3.,), out=3.0, failure=[], bugs=[])],
        # Skip every mode: returning a Python object fails in all of them
        # (see `testObjectOutput`), so `failure=[]` above is never checked.
        skip_modes=[RunMode.RAW, RunMode.XLA, RunMode.FUNCTION, RunMode.SAVED])

  def testTensorArrayBasic(self):
    """Tests `_generic_test` with a `tf.TensorArray` as input to tf.function."""

    def f(x):
      return x.stack()

    ta = tf.TensorArray(dtype=tf.int32, dynamic_size=True, size=0)
    ta = ta.write(0, tf.constant([1, 2, 3]))
    ta = ta.write(1, tf.constant([4, 5, 6]))

    self._generic_test(
        f,
        [
            Example(
                arg=(ta,),
                out=tf.constant([[1, 2, 3], [4, 5, 6]]),
                failure=[RunMode.SAVED],  # TODO(b/187250924): Investigate.
                bugs=['b/180921284'])
        ])

  def testFailureParamAsDict(self):
    """Tests passing in a `dict` for `failure` param to `_generic_test`.

    Each key is a `RunMode` expected to fail and each value is the error
    message expected for that mode.
    """

    def f(ta):
      return ta.stack()

    ta = tf.TensorArray(dtype=tf.float32, dynamic_size=True, size=0)
    ta = ta.write(0, tf.constant([1.0, 2.0]))
    ta = ta.write(1, tf.constant([3.0, 4.0]))

    out_t = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    input_signature = [tf.TensorArraySpec(element_shape=None,
                                          dtype=tf.float32,
                                          dynamic_size=True)]

    # The same error is expected from both tf.function and XLA modes.
    structure_error = ('If shallow structure is a sequence, input must also '
                       'be a sequence')

    self._generic_test(
        f,
        [
            Example(
                arg=(ta,),
                out=out_t,
                failure={
                    RunMode.FUNCTION: structure_error,
                    RunMode.XLA: structure_error,
                    RunMode.SAVED:
                        'Found zero restored functions for caller function',
                },
                bugs=['b/162452468'])
        ],
        input_signature=input_signature,
        skip_modes=[])


if __name__ == '__main__':
  test.main()