• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Testing.
17
18See the @{$python/test} guide.
19
20Note: `tf.test.mock` is an alias to the python `mock` or `unittest.mock`
21depending on the python version.
22
23@@main
24@@TestCase
25@@test_src_dir_path
26@@assert_equal_graph_def
27@@get_temp_dir
28@@is_built_with_cuda
29@@is_gpu_available
30@@gpu_device_name
31@@compute_gradient
32@@compute_gradient_error
33@@create_local_cluster
34
35"""
36
37from __future__ import absolute_import
38from __future__ import division
39from __future__ import print_function
40
41
42# pylint: disable=g-bad-import-order
43from tensorflow.python.framework import test_util as _test_util
44from tensorflow.python.platform import googletest as _googletest
45from tensorflow.python.util.all_util import remove_undocumented
46
47# pylint: disable=unused-import
48from tensorflow.python.framework.test_util import assert_equal_graph_def
49from tensorflow.python.framework.test_util import create_local_cluster
50from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase
51from tensorflow.python.framework.test_util import gpu_device_name
52from tensorflow.python.framework.test_util import is_gpu_available
53
54from tensorflow.python.ops.gradient_checker import compute_gradient_error
55from tensorflow.python.ops.gradient_checker import compute_gradient
56# pylint: enable=unused-import,g-bad-import-order
57
58import sys
59from tensorflow.python.util.tf_export import tf_export
60if sys.version_info.major == 2:
61  import mock                # pylint: disable=g-import-not-at-top,unused-import
62else:
63  from unittest import mock  # pylint: disable=g-import-not-at-top
64
# Classes re-exported unchanged from googletest so tests can reach them
# through this module.
# pylint: disable=invalid-name
Benchmark = _googletest.Benchmark
StubOutForTesting = _googletest.StubOutForTesting
# pylint: enable=invalid-name
70
71
@tf_export('test.main')
def main(argv=None):
  """Runs all unit tests.

  First installs the stack-trace handler provided by `test_util`, then hands
  control to googletest's `main`.

  Args:
    argv: Optional argument list, forwarded unchanged to `googletest.main`.
      Defaults to None, which is also passed through as-is.

  Returns:
    Whatever `googletest.main(argv)` returns.
  """
  _test_util.InstallStackTraceHandler()
  exit_result = _googletest.main(argv)
  return exit_result
77
78
@tf_export('test.get_temp_dir')
def get_temp_dir():
  """Returns a temporary directory for use during tests.

  The directory is obtained from googletest's `GetTempDir()`. Tests do not
  need to delete it when they finish.

  Returns:
    The temporary directory, exactly as returned by `googletest.GetTempDir()`.
  """
  temp_dir = _googletest.GetTempDir()
  return temp_dir
89
90
@tf_export('test.test_src_dir_path')
def test_src_dir_path(relative_path):
  """Creates an absolute test srcdir path given a relative path.

  Thin wrapper that delegates to `googletest.test_src_dir_path`.

  Args:
    relative_path: a path relative to the tensorflow root,
      e.g. "core/platform".

  Returns:
    An absolute path to the linked-in runfiles.
  """
  resolve = _googletest.test_src_dir_path
  return resolve(relative_path)
103
104
@tf_export('test.is_built_with_cuda')
def is_built_with_cuda():
  """Returns whether TensorFlow was built with CUDA (GPU) support.

  Returns:
    The value reported by `test_util.IsGoogleCudaEnabled()`.
  """
  cuda_enabled = _test_util.IsGoogleCudaEnabled()
  return cuda_enabled
109
110
# Public names passed to remove_undocumented() as exceptions. Presumably these
# are kept even though the module docstring has no @@ entry for them —
# confirm against tensorflow.python.util.all_util.
_allowed_symbols = [
    'Benchmark',  # Re-exported from googletest; we piggy-back its docs.
    'mock',  # Python `mock` or `unittest.mock`, chosen by python version.
    'StubOutForTesting',  # Re-exported from googletest; we piggy-back its docs.
]

remove_undocumented(__name__, _allowed_symbols)
119