# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Testing."""


# pylint: disable=g-bad-import-order
from tensorflow.python.framework import test_util as _test_util
from tensorflow.python.platform import googletest as _googletest

# pylint: disable=unused-import
from tensorflow.python.framework.test_util import assert_equal_graph_def
from tensorflow.python.framework.test_util import create_local_cluster
from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase
from tensorflow.python.framework.test_util import gpu_device_name
from tensorflow.python.framework.test_util import is_gpu_available

from tensorflow.python.ops.gradient_checker import compute_gradient_error
from tensorflow.python.ops.gradient_checker import compute_gradient
# pylint: enable=unused-import,g-bad-import-order

import functools

import sys
from tensorflow.python.util.tf_export import tf_export
if sys.version_info.major == 2:
  import mock  # pylint: disable=g-import-not-at-top,unused-import
else:
  from unittest import mock  # pylint: disable=g-import-not-at-top,g-importing-member

tf_export(v1=['test.mock'])(mock)

# Import Benchmark class
Benchmark = _googletest.Benchmark  # pylint: disable=invalid-name

# Import StubOutForTesting class
StubOutForTesting = _googletest.StubOutForTesting  # pylint: disable=invalid-name


@tf_export('test.main')
def main(argv=None):
  """Runs all unit tests."""
  _test_util.InstallStackTraceHandler()
  return _googletest.main(argv)


@tf_export(v1=['test.get_temp_dir'])
def get_temp_dir():
  """Returns a temporary directory for use during tests.

  There is no need to delete the directory after the test.

  @compatibility(TF2)
  This function is removed in TF2. Please use `TestCase.get_temp_dir` instead
  in a test case.
  Outside of a unit test, obtain a temporary directory through Python's
  `tempfile` module.
  @end_compatibility

  Returns:
    The temporary directory.
  """
  return _googletest.GetTempDir()


@tf_export(v1=['test.test_src_dir_path'])
def test_src_dir_path(relative_path):
  """Creates an absolute test srcdir path given a relative path.

  Args:
    relative_path: a path relative to the tensorflow root.
      e.g. "core/platform".

  Returns:
    An absolute path to the linked-in runfiles.
  """
  return _googletest.test_src_dir_path(relative_path)


@tf_export('test.is_built_with_cuda')
def is_built_with_cuda():
  """Returns whether TensorFlow was built with CUDA (GPU) support.

  This method should only be used in tests written with `tf.test.TestCase`. A
  typical usage is to skip tests that should only run with CUDA (GPU).

  >>> class MyTest(tf.test.TestCase):
  ...
  ...   def test_add_on_gpu(self):
  ...     if not tf.test.is_built_with_cuda():
  ...       self.skipTest("test is only applicable on GPU")
  ...
  ...     with tf.device("GPU:0"):
  ...       self.assertEqual(tf.math.add(1.0, 2.0), 3.0)

  The TensorFlow official binary is built with CUDA.
  """
  return _test_util.IsGoogleCudaEnabled()


@tf_export('test.is_built_with_rocm')
def is_built_with_rocm():
  """Returns whether TensorFlow was built with ROCm (GPU) support.

  This method should only be used in tests written with `tf.test.TestCase`. A
  typical usage is to skip tests that should only run with ROCm (GPU).

  >>> class MyTest(tf.test.TestCase):
  ...
  ...   def test_add_on_gpu(self):
  ...     if not tf.test.is_built_with_rocm():
  ...       self.skipTest("test is only applicable on GPU")
  ...
  ...     with tf.device("GPU:0"):
  ...       self.assertEqual(tf.math.add(1.0, 2.0), 3.0)

  The TensorFlow official binary is NOT built with ROCm.
  """
  return _test_util.IsBuiltWithROCm()


@tf_export('test.disable_with_predicate')
def disable_with_predicate(pred, skip_message):
  """Disables the test if `pred` returns true.

  Args:
    pred: A zero-argument callable evaluated when the test runs; the test is
      skipped when it returns a true value.
    skip_message: Message passed to `skipTest` when the test is skipped.

  Returns:
    A decorator for test methods of `tf.test.TestCase` subclasses.
  """

  def decorator_disable_with_predicate(func):

    @functools.wraps(func)
    def wrapper_disable_with_predicate(self, *args, **kwargs):
      if pred():
        self.skipTest(skip_message)
      else:
        return func(self, *args, **kwargs)

    return wrapper_disable_with_predicate

  return decorator_disable_with_predicate


@tf_export('test.is_built_with_gpu_support')
def is_built_with_gpu_support():
  """Returns whether TensorFlow was built with GPU (CUDA or ROCm) support.

  This method should only be used in tests written with `tf.test.TestCase`. A
  typical usage is to skip tests that should only run with GPU.

  >>> class MyTest(tf.test.TestCase):
  ...
  ...   def test_add_on_gpu(self):
  ...     if not tf.test.is_built_with_gpu_support():
  ...       self.skipTest("test is only applicable on GPU")
  ...
  ...     with tf.device("GPU:0"):
  ...       self.assertEqual(tf.math.add(1.0, 2.0), 3.0)

  The TensorFlow official binary is built with CUDA GPU support.
  """
  return is_built_with_cuda() or is_built_with_rocm()


@tf_export('test.is_built_with_xla')
def is_built_with_xla():
  """Returns whether TensorFlow was built with XLA support.

  This method should only be used in tests written with `tf.test.TestCase`. A
  typical usage is to skip tests that should only run with XLA.

  >>> class MyTest(tf.test.TestCase):
  ...
  ...   def test_add_on_xla(self):
  ...     if not tf.test.is_built_with_xla():
  ...       self.skipTest("test is only applicable on XLA")
  ...
  ...     @tf.function(jit_compile=True)
  ...     def add(x, y):
  ...       return tf.math.add(x, y)
  ...
  ...     self.assertEqual(add(tf.ones(()), tf.ones(())), 2.0)

  The TensorFlow official binary is built with XLA.
  """
  return _test_util.IsBuiltWithXLA()