1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Testing. 17 18See the [Testing](https://tensorflow.org/api_guides/python/test) guide. 19 20Note: `tf.test.mock` is an alias to the python `mock` or `unittest.mock` 21depending on the python version. 22""" 23 24from __future__ import absolute_import 25from __future__ import division 26from __future__ import print_function 27 28 29# pylint: disable=g-bad-import-order 30from tensorflow.python.framework import test_util as _test_util 31from tensorflow.python.platform import googletest as _googletest 32 33# pylint: disable=unused-import 34from tensorflow.python.framework.test_util import assert_equal_graph_def 35from tensorflow.python.framework.test_util import create_local_cluster 36from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase 37from tensorflow.python.framework.test_util import gpu_device_name 38from tensorflow.python.framework.test_util import is_gpu_available 39 40from tensorflow.python.ops.gradient_checker import compute_gradient_error 41from tensorflow.python.ops.gradient_checker import compute_gradient 42# pylint: enable=unused-import,g-bad-import-order 43 44import sys 45from tensorflow.python.util.tf_export import tf_export 46if sys.version_info.major == 2: 47 import mock # pylint: disable=g-import-not-at-top,unused-import 48else: 49 from unittest import mock # pylint: disable=g-import-not-at-top,g-importing-member 50 51tf_export(v1=['test.mock'])(mock) 52 53# Import Benchmark class 54Benchmark = _googletest.Benchmark # pylint: disable=invalid-name 55 56# Import StubOutForTesting class 57StubOutForTesting = _googletest.StubOutForTesting # pylint: disable=invalid-name 58 59 60@tf_export('test.main') 61def main(argv=None): 62 """Runs all unit tests.""" 63 _test_util.InstallStackTraceHandler() 64 return _googletest.main(argv) 65 66 67@tf_export(v1=['test.get_temp_dir']) 68def get_temp_dir(): 69 """Returns a temporary directory for use during tests. 70 71 There is no need to delete the directory after the test. 72 73 Returns: 74 The temporary directory. 75 """ 76 return _googletest.GetTempDir() 77 78 79@tf_export(v1=['test.test_src_dir_path']) 80def test_src_dir_path(relative_path): 81 """Creates an absolute test srcdir path given a relative path. 82 83 Args: 84 relative_path: a path relative to tensorflow root. 85 e.g. "core/platform". 86 87 Returns: 88 An absolute path to the linked in runfiles. 89 """ 90 return _googletest.test_src_dir_path(relative_path) 91 92 93@tf_export('test.is_built_with_cuda') 94def is_built_with_cuda(): 95 """Returns whether TensorFlow was built with CUDA (GPU) support.""" 96 return _test_util.IsGoogleCudaEnabled() 97