1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""Wrapper for Python TPU tests. 16 17The py_tpu_test macro will actually use this file as its main, building and 18executing the user-provided test file as a py_binary instead. This lets us do 19important work behind the scenes, without complicating the tests themselves. 20 21The main responsibilities of this file are: 22 - Define standard set of model flags if test did not. This allows us to 23 safely set flags at the Bazel invocation level using --test_arg. 24 - Pick a random directory on GCS to use for each test case, and set it as the 25 default value of --model_dir. This is similar to how Bazel provides each 26 test with a fresh local directory in $TEST_TMPDIR. 27""" 28 29from __future__ import absolute_import 30from __future__ import division 31from __future__ import print_function 32 33import ast 34import importlib 35import os 36import sys 37import uuid 38 39from tensorflow.python.platform import flags 40from tensorflow.python.util import tf_inspect 41 42FLAGS = flags.FLAGS 43flags.DEFINE_string( 44 'wrapped_tpu_test_module_relative', None, 45 'The Python-style relative path to the user-given test. If test is in same ' 46 'directory as BUILD file as is common, then "test.py" would be ".test".') 47flags.DEFINE_string('test_dir_base', 48 os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR'), 49 'GCS path to root directory for temporary test files.') 50flags.DEFINE_string( 51 'bazel_repo_root', 'tensorflow/python', 52 'Substring of a bazel filepath beginning the python absolute import path.') 53 54# List of flags which all TPU tests should accept. 55REQUIRED_FLAGS = ['tpu', 'zone', 'project', 'model_dir'] 56 57 58def maybe_define_flags(): 59 """Defines any required flags that are missing.""" 60 for f in REQUIRED_FLAGS: 61 try: 62 flags.DEFINE_string(f, None, 'flag defined by test lib') 63 except flags.DuplicateFlagError: 64 pass 65 66 67def set_random_test_dir(): 68 """Pick a random GCS directory under --test_dir_base, set as --model_dir.""" 69 path = os.path.join(FLAGS.test_dir_base, uuid.uuid4().hex) 70 FLAGS.set_default('model_dir', path) 71 72 73def calculate_parent_python_path(test_filepath): 74 """Returns the absolute import path for the containing directory. 75 76 Args: 77 test_filepath: The filepath which Bazel invoked 78 (ex: /filesystem/path/tensorflow/tensorflow/python/tpu/tpu_test) 79 80 Returns: 81 Absolute import path of parent (ex: tensorflow.python.tpu). 82 83 Raises: 84 ValueError: if bazel_repo_root does not appear within test_filepath. 85 """ 86 # We find the last occurrence of bazel_repo_root, and drop everything before. 87 split_path = test_filepath.rsplit(FLAGS.bazel_repo_root, 1) 88 if len(split_path) < 2: 89 raise ValueError('Filepath "%s" does not contain repo root "%s"' % 90 (test_filepath, FLAGS.bazel_repo_root)) 91 path = FLAGS.bazel_repo_root + split_path[1] 92 93 # We drop the last portion of the path, which is the name of the test wrapper. 94 path = path.rsplit('/', 1)[0] 95 96 # We convert the directory separators into dots. 97 return path.replace('/', '.') 98 99 100def import_user_module(): 101 """Imports the flag-specified user test code. 102 103 This runs all top-level statements in the user module, specifically flag 104 definitions. 105 106 Returns: 107 The user test module. 108 """ 109 return importlib.import_module(FLAGS.wrapped_tpu_test_module_relative, 110 calculate_parent_python_path(sys.argv[0])) 111 112 113def _is_test_class(obj): 114 """Check if arbitrary object is a test class (not a test object!). 115 116 Args: 117 obj: An arbitrary object from within a module. 118 119 Returns: 120 True iff obj is a test class inheriting at some point from a module 121 named "TestCase". This is because we write tests using different underlying 122 test libraries. 123 """ 124 return (tf_inspect.isclass(obj) 125 and 'TestCase' in (p.__name__ for p in tf_inspect.getmro(obj))) 126 127 128module_variables = vars() 129 130 131def move_test_classes_into_scope(wrapped_test_module): 132 """Add all test classes defined in wrapped module to our module. 133 134 The test runner works by inspecting the main module for TestCase classes, so 135 by adding a module-level reference to the TestCase we cause it to execute the 136 wrapped TestCase. 137 138 Args: 139 wrapped_test_module: The user-provided test code to run. 140 """ 141 for name, obj in wrapped_test_module.__dict__.items(): 142 if _is_test_class(obj): 143 module_variables['tpu_test_imported_%s' % name] = obj 144 145 146def run_user_main(wrapped_test_module): 147 """Runs the "if __name__ == '__main__'" at the bottom of a module. 148 149 TensorFlow practice is to have a main if at the bottom of the module which 150 might call an API compat function before calling test.main(). 151 152 Since this is a statement, not a function, we can't cleanly reference it, but 153 we can inspect it from the user module and run it in the context of that 154 module so all imports and variables are available to it. 155 156 Args: 157 wrapped_test_module: The user-provided test code to run. 158 159 Raises: 160 NotImplementedError: If main block was not found in module. This should not 161 be caught, as it is likely an error on the user's part -- absltest is all 162 too happy to report a successful status (and zero tests executed) if a 163 user forgets to end a class with "test.main()". 164 """ 165 tree = ast.parse(tf_inspect.getsource(wrapped_test_module)) 166 167 # Get string representation of just the condition `__name == "__main__"`. 168 target = ast.dump(ast.parse('if __name__ == "__main__": pass').body[0].test) 169 170 # `tree.body` is a list of top-level statements in the module, like imports 171 # and class definitions. We search for our main block, starting from the end. 172 for expr in reversed(tree.body): 173 if isinstance(expr, ast.If) and ast.dump(expr.test) == target: 174 break 175 else: 176 raise NotImplementedError( 177 'Could not find `if __name__ == "main":` block in %s.' % 178 wrapped_test_module.__name__) 179 180 # expr is defined because we would have raised an error otherwise. 181 new_ast = ast.Module(body=expr.body, type_ignores=[]) # pylint:disable=undefined-loop-variable 182 exec( # pylint:disable=exec-used 183 compile(new_ast, '<ast>', 'exec'), 184 globals(), 185 wrapped_test_module.__dict__, 186 ) 187 188 189if __name__ == '__main__': 190 # Partially parse flags, since module to import is specified by flag. 191 unparsed = FLAGS(sys.argv, known_only=True) 192 user_module = import_user_module() 193 maybe_define_flags() 194 # Parse remaining flags. 195 FLAGS(unparsed) 196 set_random_test_dir() 197 198 move_test_classes_into_scope(user_module) 199 run_user_main(user_module) 200