# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for tpu_test_wrapper.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import importlib.util  # Python 3 only.
import os
import tempfile

from absl.testing import flagsaver

from tensorflow.python.platform import flags
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_test_wrapper
29
30
class TPUTestWrapperTest(test.TestCase):
  """Unit tests for the helpers exposed by `tpu_test_wrapper`.

  Tests that touch absl flags are wrapped in `flagsaver.flagsaver` so flag
  changes made inside a test are undone when that test finishes.
  """

  @flagsaver.flagsaver()
  def test_flags_undefined(self):
    """maybe_define_flags() defines every expected flag when none exist."""
    tpu_test_wrapper.maybe_define_flags()

    self.assertIn('tpu', flags.FLAGS)
    self.assertIn('zone', flags.FLAGS)
    self.assertIn('project', flags.FLAGS)
    self.assertIn('model_dir', flags.FLAGS)

  @flagsaver.flagsaver()
  def test_flags_already_defined_not_overridden(self):
    """A pre-existing `tpu` flag keeps its value after maybe_define_flags()."""
    flags.DEFINE_string('tpu', 'tpuname', 'helpstring')
    tpu_test_wrapper.maybe_define_flags()

    self.assertIn('tpu', flags.FLAGS)
    self.assertIn('zone', flags.FLAGS)
    self.assertIn('project', flags.FLAGS)
    self.assertIn('model_dir', flags.FLAGS)
    # The value supplied by the earlier definition must survive.
    self.assertEqual(flags.FLAGS.tpu, 'tpuname')

  @flagsaver.flagsaver(bazel_repo_root='tensorflow/python')
  def test_parent_path(self):
    """A runfiles file path is converted to its dotted parent package."""
    filepath = '/filesystem/path/tensorflow/python/tpu/example_test.runfiles/tensorflow/python/tpu/example_test'  # pylint: disable=line-too-long
    self.assertEqual(
        tpu_test_wrapper.calculate_parent_python_path(filepath),
        'tensorflow.python.tpu')

  @flagsaver.flagsaver(bazel_repo_root='tensorflow/python')
  def test_parent_path_raises(self):
    """A path that lacks the repo root raises ValueError with a fixed message."""
    filepath = '/bad/path'
    with self.assertRaisesWithLiteralMatch(
        ValueError,
        'Filepath "/bad/path" does not contain repo root "tensorflow/python"'):
      tpu_test_wrapper.calculate_parent_python_path(filepath)

  def test_is_test_class_positive(self):
    """_is_test_class() accepts a subclass of test.TestCase."""

    class A(test.TestCase):
      pass

    self.assertTrue(tpu_test_wrapper._is_test_class(A))

  def test_is_test_class_negative(self):
    """_is_test_class() rejects a plain class."""

    class A(object):
      pass

    self.assertFalse(tpu_test_wrapper._is_test_class(A))

  @flagsaver.flagsaver(
      wrapped_tpu_test_module_relative='.tpu_test_wrapper_test')
  def test_move_test_classes_into_scope(self):
    """Imported test classes appear on the wrapper with the expected prefix."""
    # Test the class importer by having the wrapper module import this test
    # into itself.
    with test.mock.patch.object(
        tpu_test_wrapper, 'calculate_parent_python_path') as mock_parent_path:
      # Report the wrapper's own package as the parent so the relative module
      # name given in the flag above resolves to this test module.
      mock_parent_path.return_value = (
          tpu_test_wrapper.__name__.rpartition('.')[0])

      module = tpu_test_wrapper.import_user_module()
      tpu_test_wrapper.move_test_classes_into_scope(module)

    self.assertEqual(
        tpu_test_wrapper.tpu_test_imported_TPUTestWrapperTest.__name__,
        self.__class__.__name__)

  @flagsaver.flagsaver(test_dir_base='gs://example-bucket/tempfiles')
  def test_set_random_test_dir(self):
    """model_dir becomes a strictly longer path under test_dir_base."""
    tpu_test_wrapper.maybe_define_flags()
    tpu_test_wrapper.set_random_test_dir()

    self.assertStartsWith(flags.FLAGS.model_dir,
                          'gs://example-bucket/tempfiles')
    self.assertGreater(
        len(flags.FLAGS.model_dir), len('gs://example-bucket/tempfiles'))

  @flagsaver.flagsaver(test_dir_base='gs://example-bucket/tempfiles')
  def test_set_random_test_dir_repeatable(self):
    """Two consecutive set_random_test_dir() calls give different dirs."""
    tpu_test_wrapper.maybe_define_flags()
    tpu_test_wrapper.set_random_test_dir()
    first = flags.FLAGS.model_dir
    tpu_test_wrapper.set_random_test_dir()
    second = flags.FLAGS.model_dir

    self.assertNotEqual(first, second)

  def test_run_user_main(self):
    """Only the body of the `if __name__ == '__main__'` block is executed."""
    test_module = _write_and_load_module("""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

VARS = 1

if 'unrelated_if' == 'should_be_ignored':
  VARS = 2

if __name__ == '__main__':
  VARS = 3

if 'extra_if_at_bottom' == 'should_be_ignored':
  VARS = 4
""")

    # Loading the module under a non-main name leaves the guard untouched.
    self.assertEqual(test_module.VARS, 1)
    tpu_test_wrapper.run_user_main(test_module)
    self.assertEqual(test_module.VARS, 3)

  def test_run_user_main_missing_if(self):
    """A module without a main guard makes run_user_main() raise."""
    test_module = _write_and_load_module("""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

VARS = 1
""")

    self.assertEqual(test_module.VARS, 1)
    with self.assertRaises(NotImplementedError):
      tpu_test_wrapper.run_user_main(test_module)

  def test_run_user_main_double_quotes(self):
    """The main guard is also found when written with double quotes."""
    test_module = _write_and_load_module("""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

VARS = 1

if "unrelated_if" == "should_be_ignored":
  VARS = 2

if __name__ == "__main__":
  VARS = 3

if "extra_if_at_bottom" == "should_be_ignored":
  VARS = 4
""")

    self.assertEqual(test_module.VARS, 1)
    tpu_test_wrapper.run_user_main(test_module)
    self.assertEqual(test_module.VARS, 3)

  def test_run_user_main_test(self):
    """run_user_main() reaches test.main() through the user's import alias."""
    test_module = _write_and_load_module("""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.platform import test as unique_name

class DummyTest(unique_name.TestCase):
  def test_fail(self):
    self.fail()

if __name__ == '__main__':
  unique_name.main()
""")

    # We're actually limited in what we can test here -- we can't call
    # test.main() without deleting this current test from locals(), or we'll
    # recurse infinitely. We settle for testing that the test imports and calls
    # the right test module.

    with test.mock.patch.object(test, 'main') as mock_main:
      tpu_test_wrapper.run_user_main(test_module)
    mock_main.assert_called_once()
200
201
def _write_and_load_module(source):
  """Writes `source` to a new file in the test temp dir and imports it.

  Every call gets its own uniquely named file. Reusing a single fixed path
  would let the import system's bytecode cache, which is validated against the
  source file's size and modification time, hand back stale code when two
  tests write equal-length sources in quick succession.

  Args:
    source: String holding the Python source of the module to create.

  Returns:
    The executed module object, loaded under the module name 'testmodule'.
  """
  # mkstemp creates the file atomically and returns an already-open descriptor.
  fd, module_path = tempfile.mkstemp(
      prefix='testmod_', suffix='.py', dir=test.get_temp_dir())
  # Wrapping the descriptor ensures it is closed once the source is written.
  with os.fdopen(fd, 'w') as f:
    f.write(source)
  spec = importlib.util.spec_from_file_location('testmodule', module_path)
  test_module = importlib.util.module_from_spec(spec)
  spec.loader.exec_module(test_module)
  return test_module
210
211
if __name__ == '__main__':
  # Hand control to the TensorFlow test runner when run as a script.
  test.main()