1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Testing.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22 23# pylint: disable=g-bad-import-order 24from tensorflow.python.framework import test_util as _test_util 25from tensorflow.python.platform import googletest as _googletest 26 27# pylint: disable=unused-import 28from tensorflow.python.framework.test_util import assert_equal_graph_def 29from tensorflow.python.framework.test_util import create_local_cluster 30from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase 31from tensorflow.python.framework.test_util import gpu_device_name 32from tensorflow.python.framework.test_util import is_gpu_available 33 34from tensorflow.python.ops.gradient_checker import compute_gradient_error 35from tensorflow.python.ops.gradient_checker import compute_gradient 36# pylint: enable=unused-import,g-bad-import-order 37 38import functools 39 40import sys 41from tensorflow.python.util.tf_export import tf_export 42if sys.version_info.major == 2: 43 import mock # pylint: disable=g-import-not-at-top,unused-import 44else: 45 from unittest import mock # pylint: disable=g-import-not-at-top,g-importing-member 46 47tf_export(v1=['test.mock'])(mock) 48 49# Import Benchmark class 50Benchmark = _googletest.Benchmark # pylint: disable=invalid-name 51 52# Import StubOutForTesting class 53StubOutForTesting = _googletest.StubOutForTesting # pylint: disable=invalid-name 54 55 56@tf_export('test.main') 57def main(argv=None): 58 """Runs all unit tests.""" 59 _test_util.InstallStackTraceHandler() 60 return _googletest.main(argv) 61 62 63@tf_export(v1=['test.get_temp_dir']) 64def get_temp_dir(): 65 """Returns a temporary directory for use during tests. 66 67 There is no need to delete the directory after the test. 68 69 @compatibility(TF2) 70 This function is removed in TF2. Please use `TestCase.get_temp_dir` instead 71 in a test case. 72 Outside of a unit test, obtain a temporary directory through Python's 73 `tempfile` module. 74 @end_compatibility 75 76 Returns: 77 The temporary directory. 78 """ 79 return _googletest.GetTempDir() 80 81 82@tf_export(v1=['test.test_src_dir_path']) 83def test_src_dir_path(relative_path): 84 """Creates an absolute test srcdir path given a relative path. 85 86 Args: 87 relative_path: a path relative to tensorflow root. 88 e.g. "core/platform". 89 90 Returns: 91 An absolute path to the linked in runfiles. 92 """ 93 return _googletest.test_src_dir_path(relative_path) 94 95 96@tf_export('test.is_built_with_cuda') 97def is_built_with_cuda(): 98 """Returns whether TensorFlow was built with CUDA (GPU) support. 99 100 This method should only be used in tests written with `tf.test.TestCase`. A 101 typical usage is to skip tests that should only run with CUDA (GPU). 102 103 >>> class MyTest(tf.test.TestCase): 104 ... 105 ... def test_add_on_gpu(self): 106 ... if not tf.test.is_built_with_cuda(): 107 ... self.skipTest("test is only applicable on GPU") 108 ... 109 ... with tf.device("GPU:0"): 110 ... self.assertEqual(tf.math.add(1.0, 2.0), 3.0) 111 112 TensorFlow official binary is built with CUDA. 113 """ 114 return _test_util.IsGoogleCudaEnabled() 115 116 117@tf_export('test.is_built_with_rocm') 118def is_built_with_rocm(): 119 """Returns whether TensorFlow was built with ROCm (GPU) support. 120 121 This method should only be used in tests written with `tf.test.TestCase`. A 122 typical usage is to skip tests that should only run with ROCm (GPU). 123 124 >>> class MyTest(tf.test.TestCase): 125 ... 126 ... def test_add_on_gpu(self): 127 ... if not tf.test.is_built_with_rocm(): 128 ... self.skipTest("test is only applicable on GPU") 129 ... 130 ... with tf.device("GPU:0"): 131 ... self.assertEqual(tf.math.add(1.0, 2.0), 3.0) 132 133 TensorFlow official binary is NOT built with ROCm. 134 """ 135 return _test_util.IsBuiltWithROCm() 136 137 138@tf_export('test.disable_with_predicate') 139def disable_with_predicate(pred, skip_message): 140 """Disables the test if pred is true.""" 141 142 def decorator_disable_with_predicate(func): 143 144 @functools.wraps(func) 145 def wrapper_disable_with_predicate(self, *args, **kwargs): 146 if pred(): 147 self.skipTest(skip_message) 148 else: 149 return func(self, *args, **kwargs) 150 151 return wrapper_disable_with_predicate 152 153 return decorator_disable_with_predicate 154 155 156@tf_export('test.is_built_with_gpu_support') 157def is_built_with_gpu_support(): 158 """Returns whether TensorFlow was built with GPU (CUDA or ROCm) support. 159 160 This method should only be used in tests written with `tf.test.TestCase`. A 161 typical usage is to skip tests that should only run with GPU. 162 163 >>> class MyTest(tf.test.TestCase): 164 ... 165 ... def test_add_on_gpu(self): 166 ... if not tf.test.is_built_with_gpu_support(): 167 ... self.skipTest("test is only applicable on GPU") 168 ... 169 ... with tf.device("GPU:0"): 170 ... self.assertEqual(tf.math.add(1.0, 2.0), 3.0) 171 172 TensorFlow official binary is built with CUDA GPU support. 173 """ 174 return is_built_with_cuda() or is_built_with_rocm() 175 176 177@tf_export('test.is_built_with_xla') 178def is_built_with_xla(): 179 """Returns whether TensorFlow was built with XLA support. 180 181 This method should only be used in tests written with `tf.test.TestCase`. A 182 typical usage is to skip tests that should only run with XLA. 183 184 >>> class MyTest(tf.test.TestCase): 185 ... 186 ... def test_add_on_xla(self): 187 ... if not tf.test.is_built_with_xla(): 188 ... self.skipTest("test is only applicable on XLA") 189 190 ... @tf.function(jit_compile=True) 191 ... def add(x, y): 192 ... return tf.math.add(x, y) 193 ... 194 ... self.assertEqual(add(tf.ones(()), tf.ones(())), 2.0) 195 196 TensorFlow official binary is built with XLA. 197 """ 198 return _test_util.IsBuiltWithXLA() 199