# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Provides python test rules for Cloud TPU."""

def tpu_py_test(
        name,
        tags = None,
        disable_v2 = False,
        disable_v3 = False,
        disable_experimental = False,
        args = [],
        **kwargs):
    """Generates a py_test that runs a user test through the TPU test wrapper.

    The single file in "srcs" is treated as the user test. A copy of
    tpu_test_wrapper.py is added to "srcs" and installed as "main"; the user
    test's module name is handed to the wrapper through the
    --wrapped_tpu_test_module_relative flag.

    TODO(rsopher): actually generate v2 vs v3 tests.

    Args:
      name: Name of the generated py_test target.
      tags: BUILD tags to apply to the test. The tags "tpu", "no_pip",
        "no_gpu" and "nomac" are always prepended.
      disable_v2: Intended to disable TPU v2 tests. Currently not read by
        this macro (see TODO above).
      disable_v3: Intended to disable TPU v3 tests. Currently not read by
        this macro (see TODO above).
      disable_experimental: Unused.
      args: Arguments to apply to the test. They are placed after the
        wrapper flag added by this macro.
      **kwargs: Additional named arguments forwarded to py_test. "srcs" is
        required and must be a list of exactly one python file. "deps" is
        optional. "main" is always overwritten with the wrapper source, and
        "python_version" defaults to "PY3" when not supplied.
    """
    tags = [
        "tpu",
        "no_pip",
        "no_gpu",
        "nomac",
    ] + (tags or [])

    srcs = kwargs.get("srcs")
    if not srcs or len(srcs) > 1:
        fail('"srcs" should be a list of exactly one python file.')
    test_main = srcs[0]

    wrapper_src = _copy_test_source(
        "//tensorflow/python/tpu:tpu_test_wrapper.py",
    )

    kwargs.setdefault("python_version", "PY3")

    # Build new lists rather than appending in place: the caller's values may
    # be frozen (e.g. loaded from another .bzl file), and should not be
    # modified behind the caller's back in any case.
    kwargs["srcs"] = list(srcs) + [wrapper_src]

    # "deps" is optional for callers, so fall back to an empty list instead of
    # failing on a missing key. Using "+" (not append) also keeps select()
    # values usable here.
    deps = kwargs.get("deps")
    if deps == None:
        deps = []
    kwargs["deps"] = deps + ["//tensorflow/python:client_testlib"]

    kwargs["main"] = wrapper_src

    # Strip the file extension to get the module name; the leading "." makes
    # it a relative module reference for the wrapper.
    wrapped_module = test_main.rsplit(".", 1)[0]
    args = [
        "--wrapped_tpu_test_module_relative=.%s" % wrapped_module,
    ] + args

    native.py_test(
        name = name,
        tags = tags,
        args = args,
        **kwargs
    )

def _copy_test_source(src):
    """Creates a genrule copying src into the current directory.

    This silences a Bazel warning, and is necessary for relative import of the
    user test to work.

    Existing rules are checked first, so the genrule is declared at most once
    per package even when this is called several times. Note that this will
    fail weirdly if two source files have the same filename, as whichever one
    is copied in first will win and other tests will unexpectedly run the wrong
    file. We don't expect to see this case, since we're only copying the one
    test wrapper around.

    Args:
      src: The source file we would like to use.

    Returns:
      The path of a copy of this source file, inside the current package.
    """

    # Basename of the label: drop everything up to the last ":" and then
    # everything up to the last "/".
    name = src.rpartition(":")[-1].rpartition("/")[-1]

    # NOTE(review): the output path is prefixed with the package name even
    # though genrule outs are declared relative to the package -- confirm the
    # resulting nested path is what the test wrapper expects.
    new_main = "%s/%s" % (native.package_name(), name)
    new_name = "_gen_" + name

    # Only declare the copy rule the first time it is requested in a package.
    if not native.existing_rule(new_name):
        native.genrule(
            name = new_name,
            srcs = [src],
            outs = [new_main],
            cmd = "cp $< $@",
        )

    return new_main