1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""Tests for tpu_test_wrapper.py.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import importlib.util # Python 3 only. 22import os 23 24from absl.testing import flagsaver 25 26from tensorflow.python.platform import flags 27from tensorflow.python.platform import test 28from tensorflow.python.tpu import tpu_test_wrapper 29 30 31class TPUTestWrapperTest(test.TestCase): 32 33 @flagsaver.flagsaver() 34 def test_flags_undefined(self): 35 tpu_test_wrapper.maybe_define_flags() 36 37 self.assertIn('tpu', flags.FLAGS) 38 self.assertIn('zone', flags.FLAGS) 39 self.assertIn('project', flags.FLAGS) 40 self.assertIn('model_dir', flags.FLAGS) 41 42 @flagsaver.flagsaver() 43 def test_flags_already_defined_not_overridden(self): 44 flags.DEFINE_string('tpu', 'tpuname', 'helpstring') 45 tpu_test_wrapper.maybe_define_flags() 46 47 self.assertIn('tpu', flags.FLAGS) 48 self.assertIn('zone', flags.FLAGS) 49 self.assertIn('project', flags.FLAGS) 50 self.assertIn('model_dir', flags.FLAGS) 51 self.assertEqual(flags.FLAGS.tpu, 'tpuname') 52 53 @flagsaver.flagsaver(bazel_repo_root='tensorflow/python') 54 def test_parent_path(self): 55 filepath = '/filesystem/path/tensorflow/python/tpu/example_test.runfiles/tensorflow/python/tpu/example_test' # pylint: disable=line-too-long 56 self.assertEqual( 57 tpu_test_wrapper.calculate_parent_python_path(filepath), 58 'tensorflow.python.tpu') 59 60 @flagsaver.flagsaver(bazel_repo_root='tensorflow/python') 61 def test_parent_path_raises(self): 62 filepath = '/bad/path' 63 with self.assertRaisesWithLiteralMatch( 64 ValueError, 65 'Filepath "/bad/path" does not contain repo root "tensorflow/python"'): 66 tpu_test_wrapper.calculate_parent_python_path(filepath) 67 68 def test_is_test_class_positive(self): 69 70 class A(test.TestCase): 71 pass 72 73 self.assertTrue(tpu_test_wrapper._is_test_class(A)) 74 75 def test_is_test_class_negative(self): 76 77 class A(object): 78 pass 79 80 self.assertFalse(tpu_test_wrapper._is_test_class(A)) 81 82 @flagsaver.flagsaver(wrapped_tpu_test_module_relative='.tpu_test_wrapper_test' 83 ) 84 def test_move_test_classes_into_scope(self): 85 # Test the class importer by having the wrapper module import this test 86 # into itself. 87 with test.mock.patch.object( 88 tpu_test_wrapper, 'calculate_parent_python_path') as mock_parent_path: 89 mock_parent_path.return_value = ( 90 tpu_test_wrapper.__name__.rpartition('.')[0]) 91 92 module = tpu_test_wrapper.import_user_module() 93 tpu_test_wrapper.move_test_classes_into_scope(module) 94 95 self.assertEqual( 96 tpu_test_wrapper.tpu_test_imported_TPUTestWrapperTest.__name__, 97 self.__class__.__name__) 98 99 @flagsaver.flagsaver(test_dir_base='gs://example-bucket/tempfiles') 100 def test_set_random_test_dir(self): 101 tpu_test_wrapper.maybe_define_flags() 102 tpu_test_wrapper.set_random_test_dir() 103 104 self.assertStartsWith(flags.FLAGS.model_dir, 105 'gs://example-bucket/tempfiles') 106 self.assertGreater( 107 len(flags.FLAGS.model_dir), len('gs://example-bucket/tempfiles')) 108 109 @flagsaver.flagsaver(test_dir_base='gs://example-bucket/tempfiles') 110 def test_set_random_test_dir_repeatable(self): 111 tpu_test_wrapper.maybe_define_flags() 112 tpu_test_wrapper.set_random_test_dir() 113 first = flags.FLAGS.model_dir 114 tpu_test_wrapper.set_random_test_dir() 115 second = flags.FLAGS.model_dir 116 117 self.assertNotEqual(first, second) 118 119 def test_run_user_main(self): 120 test_module = _write_and_load_module(""" 121from __future__ import absolute_import 122from __future__ import division 123from __future__ import print_function 124 125VARS = 1 126 127if 'unrelated_if' == 'should_be_ignored': 128 VARS = 2 129 130if __name__ == '__main__': 131 VARS = 3 132 133if 'extra_if_at_bottom' == 'should_be_ignored': 134 VARS = 4 135""") 136 137 self.assertEqual(test_module.VARS, 1) 138 tpu_test_wrapper.run_user_main(test_module) 139 self.assertEqual(test_module.VARS, 3) 140 141 def test_run_user_main_missing_if(self): 142 test_module = _write_and_load_module(""" 143from __future__ import absolute_import 144from __future__ import division 145from __future__ import print_function 146 147VARS = 1 148""") 149 150 self.assertEqual(test_module.VARS, 1) 151 with self.assertRaises(NotImplementedError): 152 tpu_test_wrapper.run_user_main(test_module) 153 154 def test_run_user_main_double_quotes(self): 155 test_module = _write_and_load_module(""" 156from __future__ import absolute_import 157from __future__ import division 158from __future__ import print_function 159 160VARS = 1 161 162if "unrelated_if" == "should_be_ignored": 163 VARS = 2 164 165if __name__ == "__main__": 166 VARS = 3 167 168if "extra_if_at_bottom" == "should_be_ignored": 169 VARS = 4 170""") 171 172 self.assertEqual(test_module.VARS, 1) 173 tpu_test_wrapper.run_user_main(test_module) 174 self.assertEqual(test_module.VARS, 3) 175 176 def test_run_user_main_test(self): 177 test_module = _write_and_load_module(""" 178from __future__ import absolute_import 179from __future__ import division 180from __future__ import print_function 181 182from tensorflow.python.platform import test as unique_name 183 184class DummyTest(unique_name.TestCase): 185 def test_fail(self): 186 self.fail() 187 188if __name__ == '__main__': 189 unique_name.main() 190""") 191 192 # We're actually limited in what we can test here -- we can't call 193 # test.main() without deleting this current test from locals(), or we'll 194 # recurse infinitely. We settle for testing that the test imports and calls 195 # the right test module. 196 197 with test.mock.patch.object(test, 'main') as mock_main: 198 tpu_test_wrapper.run_user_main(test_module) 199 mock_main.assert_called_once() 200 201 202def _write_and_load_module(source): 203 fp = os.path.join(test.get_temp_dir(), 'testmod.py') 204 with open(fp, 'w') as f: 205 f.write(source) 206 spec = importlib.util.spec_from_file_location('testmodule', fp) 207 test_module = importlib.util.module_from_spec(spec) 208 spec.loader.exec_module(test_module) 209 return test_module 210 211 212if __name__ == '__main__': 213 test.main() 214