• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Testing."""
17
18
19# pylint: disable=g-bad-import-order
20from tensorflow.python.framework import test_util as _test_util
21from tensorflow.python.platform import googletest as _googletest
22
23# pylint: disable=unused-import
24from tensorflow.python.framework.test_util import assert_equal_graph_def
25from tensorflow.python.framework.test_util import create_local_cluster
26from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase
27from tensorflow.python.framework.test_util import gpu_device_name
28from tensorflow.python.framework.test_util import is_gpu_available
29
30from tensorflow.python.ops.gradient_checker import compute_gradient_error
31from tensorflow.python.ops.gradient_checker import compute_gradient
32# pylint: enable=unused-import,g-bad-import-order
33
34import functools
35
36import sys
37from tensorflow.python.util.tf_export import tf_export
38if sys.version_info.major == 2:
39  import mock                # pylint: disable=g-import-not-at-top,unused-import
40else:
41  from unittest import mock  # pylint: disable=g-import-not-at-top,g-importing-member
42
# Expose the `mock` module imported above under the v1 API name 'test.mock'.
tf_export(v1=['test.mock'])(mock)

# Re-export googletest's Benchmark and StubOutForTesting classes from this
# module so test code can reach them without importing googletest directly.
Benchmark, StubOutForTesting = (  # pylint: disable=invalid-name
    _googletest.Benchmark,
    _googletest.StubOutForTesting,
)
50
51
@tf_export('test.main')
def main(argv=None):
  """Runs all unit tests.

  Installs the stack-trace handler from `test_util`, then hands control to
  googletest's `main`.

  Args:
    argv: Optional argument list, forwarded unchanged to `googletest.main`.

  Returns:
    Whatever `googletest.main(argv)` returns.
  """
  # Set up the stack-trace handler before handing control to the test runner.
  _test_util.InstallStackTraceHandler()
  result = _googletest.main(argv)
  return result
57
58
@tf_export(v1=['test.get_temp_dir'])
def get_temp_dir():
  """Returns a temporary directory for use during tests.

  Callers do not need to delete the directory once the test finishes.

  @compatibility(TF2)
  This function is removed in TF2. Please use `TestCase.get_temp_dir` instead
  in a test case.
  Outside of a unit test, obtain a temporary directory through Python's
  `tempfile` module.
  @end_compatibility

  Returns:
    The temporary directory reported by `googletest.GetTempDir()`.
  """
  temp_dir = _googletest.GetTempDir()
  return temp_dir
76
77
@tf_export(v1=['test.test_src_dir_path'])
def test_src_dir_path(relative_path):
  """Builds an absolute test srcdir path from a path relative to the root.

  Args:
    relative_path: A path relative to the TensorFlow root, for example
      "core/platform".

  Returns:
    The absolute path into the linked-in runfiles, as computed by
    `googletest.test_src_dir_path`.
  """
  absolute_path = _googletest.test_src_dir_path(relative_path)
  return absolute_path
90
91
@tf_export('test.is_built_with_cuda')
def is_built_with_cuda():
  """Returns whether TensorFlow was built with CUDA (GPU) support.

  Intended for use only inside tests written with `tf.test.TestCase`; the
  usual pattern is to skip tests that can only run with CUDA (GPU).

  >>> class MyTest(tf.test.TestCase):
  ...
  ...   def test_add_on_gpu(self):
  ...     if not tf.test.is_built_with_cuda():
  ...       self.skipTest("test is only applicable on GPU")
  ...
  ...     with tf.device("GPU:0"):
  ...       self.assertEqual(tf.math.add(1.0, 2.0), 3.0)

  The official TensorFlow binary is built with CUDA.

  Returns:
    The result of `test_util.IsGoogleCudaEnabled()`.
  """
  cuda_enabled = _test_util.IsGoogleCudaEnabled()
  return cuda_enabled
111
112
@tf_export('test.is_built_with_rocm')
def is_built_with_rocm():
  """Returns whether TensorFlow was built with ROCm (GPU) support.

  Intended for use only inside tests written with `tf.test.TestCase`; the
  usual pattern is to skip tests that can only run with ROCm (GPU).

  >>> class MyTest(tf.test.TestCase):
  ...
  ...   def test_add_on_gpu(self):
  ...     if not tf.test.is_built_with_rocm():
  ...       self.skipTest("test is only applicable on GPU")
  ...
  ...     with tf.device("GPU:0"):
  ...       self.assertEqual(tf.math.add(1.0, 2.0), 3.0)

  The official TensorFlow binary is NOT built with ROCm.

  Returns:
    The result of `test_util.IsBuiltWithROCm()`.
  """
  rocm_enabled = _test_util.IsBuiltWithROCm()
  return rocm_enabled
132
133
@tf_export('test.disable_with_predicate')
def disable_with_predicate(pred, skip_message):
  """Disables the test if pred is true.

  Builds a decorator for test methods. On every call of the decorated
  method, `pred()` is evaluated. If the result is truthy, the test is skipped
  via `self.skipTest(skip_message)`; otherwise the original method runs and
  its return value is passed through.

  Args:
    pred: A zero-argument callable. It is evaluated each time the decorated
      test is invoked, not once at decoration time.
    skip_message: Message handed to `self.skipTest` when the test is skipped.

  Returns:
    A decorator for methods whose first argument is `self`.
  """

  def _decorate(test_fn):

    @functools.wraps(test_fn)
    def _maybe_skip(self, *args, **kwargs):
      # Guard clause: run the real test unless the predicate asks to skip.
      if not pred():
        return test_fn(self, *args, **kwargs)
      self.skipTest(skip_message)

    return _maybe_skip

  return _decorate
150
151
@tf_export('test.is_built_with_gpu_support')
def is_built_with_gpu_support():
  """Returns whether TensorFlow was built with GPU (CUDA or ROCm) support.

  Intended for use only inside tests written with `tf.test.TestCase`; the
  usual pattern is to skip tests that can only run with a GPU.

  >>> class MyTest(tf.test.TestCase):
  ...
  ...   def test_add_on_gpu(self):
  ...     if not tf.test.is_built_with_gpu_support():
  ...       self.skipTest("test is only applicable on GPU")
  ...
  ...     with tf.device("GPU:0"):
  ...       self.assertEqual(tf.math.add(1.0, 2.0), 3.0)

  The official TensorFlow binary is built with CUDA GPU support.

  Returns:
    The CUDA result if it is truthy; otherwise the ROCm result.
  """
  # ROCm is only queried when the CUDA check comes back falsy.
  has_cuda = is_built_with_cuda()
  if has_cuda:
    return has_cuda
  return is_built_with_rocm()
171
172
@tf_export('test.is_built_with_xla')
def is_built_with_xla():
  """Returns whether TensorFlow was built with XLA support.

  This method should only be used in tests written with `tf.test.TestCase`. A
  typical usage is to skip tests that should only run with XLA.

  >>> class MyTest(tf.test.TestCase):
  ...
  ...   def test_add_on_xla(self):
  ...     if not tf.test.is_built_with_xla():
  ...       self.skipTest("test is only applicable on XLA")
  ...
  ...     @tf.function(jit_compile=True)
  ...     def add(x, y):
  ...       return tf.math.add(x, y)
  ...
  ...     self.assertEqual(add(tf.ones(()), tf.ones(())), 2.0)

  TensorFlow official binary is built with XLA.

  Returns:
    The result of `test_util.IsBuiltWithXLA()`.
  """
  # Note: the blank separator line in the example above must carry a `...`
  # prompt; a bare blank line would end the doctest example early and drop
  # the jit-compiled `add` part from the executable snippet.
  return _test_util.IsBuiltWithXLA()
195