# Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for test sharding protocol."""

import os
import subprocess

from absl.testing import _bazelize_command
from absl.testing import absltest
from absl.testing.tests import absltest_env


NUM_TEST_METHODS = 8  # Hard-coded, based on absltest_sharding_test_helper.py


class TestShardingTest(absltest.TestCase):
  """Integration tests: Runs a test binary with sharding.

  This is done by setting the sharding environment variables.
  """

  def setUp(self):
    super().setUp()
    self._test_name = 'absl/testing/tests/absltest_sharding_test_helper'
    # Path of the shard status file created by the most recent
    # _run_sharded() call that requested one; removed in tearDown().
    self._shard_file = None

  def tearDown(self):
    # Clean up this test's own state first, then delegate to the base class
    # (the reverse of the order used in setUp).
    if self._shard_file is not None:
      try:
        os.unlink(self._shard_file)
      except FileNotFoundError:
        pass  # Already gone; nothing to clean up.
    super().tearDown()

  @staticmethod
  def _extract_method_lines(output):
    """Returns the lines of `output` that start with 'class'.

    These lines are treated as markers for executed test methods (presumably
    printed once per method by absltest_sharding_test_helper.py); every other
    line of output is ignored.

    Args:
      output: str, the combined stdout/stderr of the helper binary.

    Returns:
      A list of str, in the order they appeared in `output`.
    """
    return [line for line in output.splitlines() if line.startswith('class')]

  def _run_sharded(self,
                   total_shards,
                   shard_index,
                   shard_file=None,
                   additional_env=None):
    """Runs the py_test binary in a subprocess.

    Args:
      total_shards: int, the total number of shards.
      shard_index: int, the shard index.
      shard_file: string, if not None, the path to the shard status file.
        Any pre-existing file at that path is removed first, and this method
        asserts that it exists again after the binary has run.
      additional_env: Additional environment variables to be set for the py_test
        binary. The sharding variables set by this method take precedence.

    Returns:
      (stdout, exit_code) tuple of (string, int). stderr is merged into stdout.
    """
    env = absltest_env.inherited_env()
    if additional_env:
      env.update(additional_env)
    # Applied after additional_env so the sharding configuration under test
    # cannot be overridden by it.
    env.update({
        'TEST_TOTAL_SHARDS': str(total_shards),
        'TEST_SHARD_INDEX': str(shard_index)
    })
    if shard_file:
      self._shard_file = shard_file
      env['TEST_SHARD_STATUS_FILE'] = shard_file
      # Remove any stale file so the existence check below proves the file
      # was (re)created during this run.
      try:
        os.unlink(shard_file)
      except FileNotFoundError:
        pass

    result = subprocess.run(
        args=[_bazelize_command.get_executable_path(self._test_name)],
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
        check=False)  # A non-zero exit code is an expected, inspected result.

    if shard_file:
      self.assertTrue(os.path.exists(shard_file))

    return (result.stdout, result.returncode)

  def _assert_sharding_correctness(self, total_shards):
    """Assert the primary correctness and performance of sharding.

    1. Completeness (all methods are run)
    2. Partition (each method run at most once)
    3. Balance (for performance)

    Args:
      total_shards: int, total number of shards.
    """
    methods_by_shard = []  # One list of method lines per shard.
    all_methods = []  # Method lines from every shard, concatenated.
    exit_codes = []  # One exit code per shard.

    for shard_index in range(total_shards):
      out, exit_code = self._run_sharded(total_shards, shard_index)
      methods = self._extract_method_lines(out)
      methods_by_shard.append(methods)
      all_methods.extend(methods)
      exit_codes.append(exit_code)

    # The helper presumably contains exactly one failing test method, so
    # exactly one shard is expected to exit non-zero.
    self.assertLen([code for code in exit_codes if code != 0], 1,
                   'Expected exactly one failure')

    # Test completeness and partition properties: together these require
    # exactly NUM_TEST_METHODS runs covering NUM_TEST_METHODS distinct methods.
    self.assertLen(all_methods, NUM_TEST_METHODS,
                   'Partition requirement not met')
    self.assertLen(set(all_methods), NUM_TEST_METHODS,
                   'Completeness requirement not met')

    # Test balance: no shard may run more than one method fewer than an even
    # split would assign it.
    for shard_index, methods in enumerate(methods_by_shard):
      self.assertGreaterEqual(len(methods),
                              (NUM_TEST_METHODS / total_shards) - 1,
                              'Shard %d of %d out of balance' %
                              (shard_index, total_shards))

  def test_shard_file(self):
    self._run_sharded(3, 1, os.path.join(
        absltest.TEST_TMPDIR.value, 'shard_file'))

  def test_zero_shards(self):
    out, exit_code = self._run_sharded(0, 0)
    self.assertEqual(1, exit_code)
    self.assertIn('Bad sharding values. index=0, total=0', out,
                  'Bad output: %s' % (out,))

  def test_with_four_shards(self):
    self._assert_sharding_correctness(4)

  def test_with_one_shard(self):
    self._assert_sharding_correctness(1)

  def test_with_ten_shards(self):
    self._assert_sharding_correctness(10)

  def test_sharding_with_randomization(self):
    # If we're both sharding *and* randomizing, we need to confirm that we
    # randomize within the shard; we use two seeds to confirm we're seeing the
    # same tests (sharding is consistent) in a different order.
    tests_seen = []
    for seed in ('7', '17'):
      out, exit_code = self._run_sharded(
          2, 0, additional_env={'TEST_RANDOMIZE_ORDERING_SEED': seed})
      self.assertEqual(0, exit_code)
      tests_seen.append(self._extract_method_lines(out))
    first_tests, second_tests = tests_seen  # pylint: disable=unbalanced-tuple-unpacking
    self.assertEqual(set(first_tests), set(second_tests))
    self.assertNotEqual(first_tests, second_tests)


if __name__ == '__main__':
  absltest.main()