• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Testing.
17
18See the [Testing](https://tensorflow.org/api_guides/python/test) guide.
19
20Note: `tf.test.mock` is an alias to the python `mock` or `unittest.mock`
21depending on the python version.
22"""
23
24from __future__ import absolute_import
25from __future__ import division
26from __future__ import print_function
27
28
29# pylint: disable=g-bad-import-order
30from tensorflow.python.framework import test_util as _test_util
31from tensorflow.python.platform import googletest as _googletest
32
33# pylint: disable=unused-import
34from tensorflow.python.framework.test_util import assert_equal_graph_def
35from tensorflow.python.framework.test_util import create_local_cluster
36from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase
37from tensorflow.python.framework.test_util import gpu_device_name
38from tensorflow.python.framework.test_util import is_gpu_available
39
40from tensorflow.python.ops.gradient_checker import compute_gradient_error
41from tensorflow.python.ops.gradient_checker import compute_gradient
42# pylint: enable=unused-import,g-bad-import-order
43
44import sys
45from tensorflow.python.util.tf_export import tf_export
# Choose the mock library by interpreter major version: any major version
# other than 2 uses the standard-library `unittest.mock`, while Python 2 falls
# back to the standalone `mock` package.
if sys.version_info.major != 2:
  from unittest import mock  # pylint: disable=g-import-not-at-top,g-importing-member
else:
  import mock                # pylint: disable=g-import-not-at-top,unused-import

# Publish whichever module was selected under the v1 API name `test.mock`.
tf_export(v1=['test.mock'])(mock)
52
# Expose googletest's Benchmark and StubOutForTesting classes as public names
# of this module.
Benchmark, StubOutForTesting = (  # pylint: disable=invalid-name
    _googletest.Benchmark, _googletest.StubOutForTesting)
58
59
@tf_export('test.main')
def main(argv=None):
  """Runs all unit tests.

  Installs test_util's stack-trace handler first, then hands control to
  googletest's `main`.

  Args:
    argv: Optional argument list, forwarded unchanged to googletest's `main`.

  Returns:
    Whatever googletest's `main` returns.
  """
  _test_util.InstallStackTraceHandler()
  exit_value = _googletest.main(argv)
  return exit_value
65
66
@tf_export(v1=['test.get_temp_dir'])
def get_temp_dir():
  """Gets a temporary directory that tests may use for scratch files.

  The directory does not have to be deleted once the test is done.

  Returns:
    The temporary directory, exactly as returned by googletest's
    `GetTempDir()`.
  """
  temp_dir = _googletest.GetTempDir()
  return temp_dir
77
78
@tf_export(v1=['test.test_src_dir_path'])
def test_src_dir_path(relative_path):
  """Builds an absolute test srcdir path from a path relative to the TF root.

  Args:
    relative_path: A path relative to the tensorflow root, for example
      "core/platform".

  Returns:
    An absolute path into the linked-in runfiles, as computed by googletest's
    `test_src_dir_path`.
  """
  absolute_path = _googletest.test_src_dir_path(relative_path)
  return absolute_path
91
92
@tf_export('test.is_built_with_cuda')
def is_built_with_cuda():
  """Returns whether TensorFlow was built with CUDA (GPU) support.

  The answer is taken directly from `test_util.IsGoogleCudaEnabled()`.
  """
  cuda_enabled = _test_util.IsGoogleCudaEnabled()
  return cuda_enabled
97