• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Provides python test rules for Cloud TPU."""

def tpu_py_test(
        name,
        tags = None,
        disable_v2 = False,
        disable_v3 = False,
        disable_experimental = False,
        args = None,
        **kwargs):
    """Generates identical unit test variants for various Cloud TPU versions.

    TODO(rsopher): actually generate v2 vs v3 tests.

    Currently this emits a single `py_test` named `name`. The user's test file
    is not run directly: a copy of `tpu_test_wrapper.py` becomes the test's
    `main`, and it is told which module to run via the
    `--wrapped_tpu_test_module_relative` flag.

    Args:
        name: Name of the generated test. It is used as-is; it is not (yet)
            prefixed by accelerator versions.
        tags: BUILD tags to apply to tests, in addition to the defaults
            "tpu", "no_pip", "no_gpu" and "nomac".
        disable_v2: If true, don't generate TPU v2 tests. Currently unused.
        disable_v3: If true, don't generate TPU v3 tests. Currently unused.
        disable_experimental: Unused.
        args: Extra arguments to apply to tests. They are placed after the
            wrapper's `--wrapped_tpu_test_module_relative` flag.
        **kwargs: Additional named arguments forwarded to `native.py_test`.
            `srcs` is required and must hold exactly one Python file.
            `deps` is optional.
    """
    tags = [
        "tpu",
        "no_pip",
        "no_gpu",
        "nomac",
    ] + (tags or [])

    srcs = kwargs.get("srcs")
    if not srcs or len(srcs) != 1:
        fail('"srcs" should be a list of exactly one python file.')
    test_main = srcs[0]

    wrapper_src = _copy_test_source(
        "//tensorflow/python/tpu:tpu_test_wrapper.py",
    )

    kwargs["python_version"] = kwargs.get("python_version", "PY3")

    # Build new lists with `+` instead of `.append()`. The caller's lists may
    # be shared with other targets (they would accumulate duplicate entries)
    # or frozen (append would fail). Also, `deps` may be omitted entirely.
    kwargs["srcs"] = srcs + [wrapper_src]
    kwargs["deps"] = (kwargs.get("deps") or []) + [
        "//tensorflow/python:client_testlib",
    ]
    kwargs["main"] = wrapper_src

    # The wrapper imports the user test relative to its own package. The
    # module name is the source file name with its extension stripped.
    # NOTE(review): assumes `test_main` is a plain file name in this package
    # (e.g. "foo_test.py"), not a label or a path with "/" - confirm.
    args = [
        "--wrapped_tpu_test_module_relative=.%s" % test_main.rsplit(".", 1)[0],
    ] + (args or [])

    native.py_test(
        name = name,
        tags = tags,
        args = args,
        **kwargs
    )
72
def _copy_test_source(src):
    """Creates a genrule copying src into the current package.

    This silences a Bazel warning, and is necessary for relative import of the
    user test to work.

    The genrule is created at most once per package. If a rule with the
    generated name already exists (e.g. from an earlier call in the same BUILD
    file), it is reused. Note that this will fail weirdly if two source files
    have the same basename, as whichever one is copied in first will win and
    other tests will unexpectedly run the wrong file. We don't expect to see
    this case, since we're only copying the one test wrapper around.

    Args:
        src: Label of the source file we would like to use.

    Returns:
        The output path of the copy, as used in the genrule's `outs`:
        "<package_name>/<basename of src>", or just the basename when called
        from the root package.
    """

    # Basename of the label: strip the package part and any subdirectories.
    basename = src.rpartition(":")[-1].rpartition("/")[-1]

    package = native.package_name()

    # In the root package, package_name() is "". Formatting "%s/%s" would then
    # produce "/<basename>", an absolute path that is not a valid output.
    # NOTE(review): genrule outs are package-relative, so in a non-root package
    # this nests the package path beneath the package directory. Presumably
    # tpu_test_wrapper.py depends on that layout - confirm before changing it.
    new_main = "%s/%s" % (package, basename) if package else basename
    new_name = "_gen_" + basename

    if not native.existing_rule(new_name):
        native.genrule(
            name = new_name,
            srcs = [src],
            outs = [new_main],
            cmd = "cp $< $@",
        )

    return new_main
106