• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The Abseil Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15"""Tests for test sharding protocol."""
16
17import os
18import subprocess
19
20from absl.testing import _bazelize_command
21from absl.testing import absltest
22from absl.testing.tests import absltest_env
23
24
# Total number of test methods in absltest_sharding_test_helper.py. This value
# is hard-coded and must be kept in sync with that helper by hand.
NUM_TEST_METHODS = 8
26
27
class TestShardingTest(absltest.TestCase):
  """Integration tests: Runs a test binary with sharding.

  This is done by setting the sharding environment variables
  (TEST_TOTAL_SHARDS, TEST_SHARD_INDEX and, optionally,
  TEST_SHARD_STATUS_FILE) and inspecting the binary's combined output.
  """

  def setUp(self):
    super().setUp()
    self._test_name = 'absl/testing/tests/absltest_sharding_test_helper'
    # Path of the shard status file requested by the most recent
    # `_run_sharded` call, if any; removed again in tearDown.
    self._shard_file = None

  def tearDown(self):
    super().tearDown()
    if self._shard_file is not None:
      # EAFP instead of exists()+unlink(): avoids a check-then-delete race.
      try:
        os.unlink(self._shard_file)
      except FileNotFoundError:
        pass

  @staticmethod
  def _get_executed_methods(output):
    """Returns the lines of `output` that start with 'class'.

    These lines are treated as the record of test methods the helper binary
    ran (one line per method is assumed -- see
    absltest_sharding_test_helper.py).

    Args:
      output: string, the combined stdout/stderr of the helper binary.

    Returns:
      A list of strings, in the order they appear in `output`.
    """
    return [line for line in output.splitlines() if line.startswith('class')]

  def _run_sharded(self,
                   total_shards,
                   shard_index,
                   shard_file=None,
                   additional_env=None):
    """Runs the py_test binary in a subprocess.

    Args:
      total_shards: int, the total number of shards.
      shard_index: int, the shard index.
      shard_file: string, if not None, the path to the shard status file.
        Any pre-existing file at that path is removed first, and this method
        asserts the binary creates it.
      additional_env: Additional environment variables to be set for the py_test
        binary. The sharding variables set by this method take precedence.

    Returns:
      (output, exit_code) tuple of (string, int), where output is the binary's
      stdout with stderr merged into it.
    """
    env = absltest_env.inherited_env()
    if additional_env:
      env.update(additional_env)
    # Applied after additional_env so the sharding values always win.
    env['TEST_TOTAL_SHARDS'] = str(total_shards)
    env['TEST_SHARD_INDEX'] = str(shard_index)
    if shard_file:
      self._shard_file = shard_file
      env['TEST_SHARD_STATUS_FILE'] = shard_file
      # Remove any stale file so the existence check below is meaningful.
      try:
        os.unlink(shard_file)
      except FileNotFoundError:
        pass

    # subprocess.run manages the Popen lifetime (closing pipes and reaping
    # the child) even if communication with the child raises.
    completed = subprocess.run(
        [_bazelize_command.get_executable_path(self._test_name)],
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
        check=False)

    if shard_file:
      self.assertTrue(os.path.exists(shard_file))

    return (completed.stdout, completed.returncode)

  def _assert_sharding_correctness(self, total_shards):
    """Assert the primary correctness and performance of sharding.

    1. Completeness (all methods are run)
    2. Partition (each method run at most once)
    3. Balance (for performance)

    Args:
      total_shards: int, total number of shards.
    """
    methods_by_shard = []  # One list of method lines per shard.
    all_methods = []  # Method lines from every shard, concatenated.
    exit_codes = []  # One exit code per shard.

    for shard_index in range(total_shards):
      out, exit_code = self._run_sharded(total_shards, shard_index)
      methods = self._get_executed_methods(out)
      methods_by_shard.append(methods)
      all_methods.extend(methods)
      exit_codes.append(exit_code)

    # Exactly one shard is expected to exit non-zero (presumably the one that
    # runs the helper's failing method -- see absltest_sharding_test_helper.py).
    self.assertLen([code for code in exit_codes if code != 0], 1,
                   'Expected exactly one failure')

    # Test completeness and partition properties: NUM_TEST_METHODS lines in
    # total, all distinct, means every method ran exactly once.
    self.assertLen(all_methods, NUM_TEST_METHODS,
                   'Partition requirement not met')
    self.assertLen(set(all_methods), NUM_TEST_METHODS,
                   'Completeness requirement not met')

    # Test balance: no shard may run more than one method fewer than the
    # average number of methods per shard.
    min_methods_per_shard = (NUM_TEST_METHODS / total_shards) - 1
    for shard_index, methods in enumerate(methods_by_shard):
      self.assertGreaterEqual(len(methods), min_methods_per_shard,
                              'Shard %d of %d out of balance' %
                              (shard_index, len(methods_by_shard)))

  def test_shard_file(self):
    self._run_sharded(3, 1, os.path.join(
        absltest.TEST_TMPDIR.value, 'shard_file'))

  def test_zero_shards(self):
    out, exit_code = self._run_sharded(0, 0)
    self.assertEqual(1, exit_code)
    self.assertIn('Bad sharding values. index=0, total=0', out,
                  'Bad output: %s' % (out))

  def test_with_four_shards(self):
    self._assert_sharding_correctness(4)

  def test_with_one_shard(self):
    self._assert_sharding_correctness(1)

  def test_with_ten_shards(self):
    self._assert_sharding_correctness(10)

  def test_sharding_with_randomization(self):
    # If we're both sharding *and* randomizing, we need to confirm that we
    # randomize within the shard; we use two seeds to confirm we're seeing the
    # same tests (sharding is consistent) in a different order.
    tests_seen = []
    for seed in ('7', '17'):
      out, exit_code = self._run_sharded(
          2, 0, additional_env={'TEST_RANDOMIZE_ORDERING_SEED': seed})
      self.assertEqual(0, exit_code)
      tests_seen.append(self._get_executed_methods(out))
    first_tests, second_tests = tests_seen  # pylint: disable=unbalanced-tuple-unpacking
    self.assertEqual(set(first_tests), set(second_tests))
    self.assertNotEqual(first_tests, second_tests)
159
if __name__ == '__main__':
  # Entry point when this file is executed directly: run the tests defined
  # above through absltest.
  absltest.main()
162