• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Testing."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22
23# pylint: disable=g-bad-import-order
24from tensorflow.python.framework import test_util as _test_util
25from tensorflow.python.platform import googletest as _googletest
26
27# pylint: disable=unused-import
28from tensorflow.python.framework.test_util import assert_equal_graph_def
29from tensorflow.python.framework.test_util import create_local_cluster
30from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase
31from tensorflow.python.framework.test_util import gpu_device_name
32from tensorflow.python.framework.test_util import is_gpu_available
33
34from tensorflow.python.ops.gradient_checker import compute_gradient_error
35from tensorflow.python.ops.gradient_checker import compute_gradient
36# pylint: enable=unused-import,g-bad-import-order
37
38import functools
39
40import sys
41from tensorflow.python.util.tf_export import tf_export
# Choose the `mock` implementation: the standalone `mock` package when running
# under Python 2, otherwise the standard library's `unittest.mock`.
if sys.version_info.major == 2:
  import mock                # pylint: disable=g-import-not-at-top,unused-import
else:
  from unittest import mock  # pylint: disable=g-import-not-at-top,g-importing-member

# Export whichever `mock` module was selected above under the v1-only API name
# 'test.mock'. This must run after `mock` has been bound.
tf_export(v1=['test.mock'])(mock)

# Re-export googletest's Benchmark class from this module.
Benchmark = _googletest.Benchmark  # pylint: disable=invalid-name

# Re-export googletest's StubOutForTesting class from this module.
StubOutForTesting = _googletest.StubOutForTesting  # pylint: disable=invalid-name
54
55
@tf_export('test.main')
def main(argv=None):
  """Runs every unit test in the calling program.

  TensorFlow's stack-trace handler is installed first, then control is handed
  to googletest's test runner.

  Args:
    argv: Command-line arguments, forwarded unchanged (including `None`) to
      `googletest.main`.

  Returns:
    Whatever `googletest.main` returns.
  """
  # Install the handler before any test starts so it covers the entire run.
  _test_util.InstallStackTraceHandler()
  result = _googletest.main(argv)
  return result
61
62
@tf_export(v1=['test.get_temp_dir'])
def get_temp_dir():
  """Provides a temporary directory that tests are free to write into.

  Callers do not have to remove the directory once the test is finished.

  @compatibility(TF2)
  This function is removed in TF2. Please use `TestCase.get_temp_dir` instead
  in a test case.
  Outside of a unit test, obtain a temporary directory through Python's
  `tempfile` module.
  @end_compatibility

  Returns:
    The temporary directory reported by googletest's `GetTempDir`.
  """
  temp_dir = _googletest.GetTempDir()
  return temp_dir
80
81
@tf_export(v1=['test.test_src_dir_path'])
def test_src_dir_path(relative_path):
  """Turns a path relative to the TensorFlow root into an absolute test path.

  Args:
    relative_path: Path relative to the tensorflow root, for example
      "core/platform".

  Returns:
    The absolute path into the linked-in runfiles, as computed by
    googletest's `test_src_dir_path`.
  """
  absolute_path = _googletest.test_src_dir_path(relative_path)
  return absolute_path
94
95
@tf_export('test.is_built_with_cuda')
def is_built_with_cuda():
  """Reports whether this TensorFlow build includes CUDA (GPU) support.

  Meant to be called from tests derived from `tf.test.TestCase`, most often to
  skip test cases that can only run with CUDA (GPU):

  >>> class MyTest(tf.test.TestCase):
  ...
  ...   def test_add_on_gpu(self):
  ...     if not tf.test.is_built_with_cuda():
  ...       self.skipTest("test is only applicable on GPU")
  ...
  ...     with tf.device("GPU:0"):
  ...       self.assertEqual(tf.math.add(1.0, 2.0), 3.0)

  The official TensorFlow binary is built with CUDA.
  """
  cuda_enabled = _test_util.IsGoogleCudaEnabled()
  return cuda_enabled
115
116
@tf_export('test.is_built_with_rocm')
def is_built_with_rocm():
  """Reports whether this TensorFlow build includes ROCm (GPU) support.

  Meant to be called from tests derived from `tf.test.TestCase`, most often to
  skip test cases that can only run with ROCm (GPU):

  >>> class MyTest(tf.test.TestCase):
  ...
  ...   def test_add_on_gpu(self):
  ...     if not tf.test.is_built_with_rocm():
  ...       self.skipTest("test is only applicable on GPU")
  ...
  ...     with tf.device("GPU:0"):
  ...       self.assertEqual(tf.math.add(1.0, 2.0), 3.0)

  The official TensorFlow binary is NOT built with ROCm.
  """
  rocm_enabled = _test_util.IsBuiltWithROCm()
  return rocm_enabled
136
137
@tf_export('test.disable_with_predicate')
def disable_with_predicate(pred, skip_message):
  """Builds a decorator that skips a test method whenever `pred()` is truthy.

  `pred` is evaluated each time the decorated test is invoked, not when the
  decorator is applied.

  Args:
    pred: Zero-argument callable. The test is skipped if it returns a truthy
      value.
    skip_message: Reason handed to `self.skipTest` when the test is skipped.

  Returns:
    A decorator for test methods, meaning functions whose first positional
    argument is the test-case instance.
  """

  def decorator_disable_with_predicate(test_fn):

    @functools.wraps(test_fn)
    def wrapper_disable_with_predicate(self, *args, **kwargs):
      # Guard clause: run the real test whenever the predicate does not fire.
      if not pred():
        return test_fn(self, *args, **kwargs)
      self.skipTest(skip_message)

    return wrapper_disable_with_predicate

  return decorator_disable_with_predicate
154
155
@tf_export('test.is_built_with_gpu_support')
def is_built_with_gpu_support():
  """Reports whether this TensorFlow build has GPU support via CUDA or ROCm.

  Meant to be called from tests derived from `tf.test.TestCase`, most often to
  skip test cases that can only run on a GPU:

  >>> class MyTest(tf.test.TestCase):
  ...
  ...   def test_add_on_gpu(self):
  ...     if not tf.test.is_built_with_gpu_support():
  ...       self.skipTest("test is only applicable on GPU")
  ...
  ...     with tf.device("GPU:0"):
  ...       self.assertEqual(tf.math.add(1.0, 2.0), 3.0)

  The official TensorFlow binary is built with CUDA GPU support.
  """
  # The CUDA check comes first. The ROCm check is consulted only when CUDA
  # reports a falsy value, which mirrors `a or b` short-circuiting.
  cuda_result = is_built_with_cuda()
  if cuda_result:
    return cuda_result
  return is_built_with_rocm()
175
176
@tf_export('test.is_built_with_xla')
def is_built_with_xla():
  """Returns whether TensorFlow was built with XLA support.

  This method should only be used in tests written with `tf.test.TestCase`. A
  typical usage is to skip tests that should only run with XLA.

  >>> class MyTest(tf.test.TestCase):
  ...
  ...   def test_add_on_xla(self):
  ...     if not tf.test.is_built_with_xla():
  ...       self.skipTest("test is only applicable on XLA")
  ...
  ...     @tf.function(jit_compile=True)
  ...     def add(x, y):
  ...       return tf.math.add(x, y)
  ...
  ...     self.assertEqual(add(tf.ones(()), tf.ones(())), 2.0)

  TensorFlow official binary is built with XLA.

  Returns:
    The value reported by `_test_util.IsBuiltWithXLA()`.
  """
  # Docstring note: every line of the example above must keep the `...`
  # prefix. A bare blank line would end the doctest example early.
  return _test_util.IsBuiltWithXLA()
199