• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Library for multi-process testing."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20import multiprocessing
21import os
22import platform
23import sys
24import unittest
25from absl import app
26from absl import logging
27
28from tensorflow.python.eager import test
29
30
def is_oss():
  """Returns whether the test is run under OSS.

  The check is a heuristic on the program name: it is true when `sys.argv`
  is non-empty and its first entry contains the substring 'bazel'.
  """
  return bool(sys.argv) and 'bazel' in sys.argv[0]
34
35
def _is_enabled():
  """Returns whether multi-process tests can run in this environment.

  Returns:
    False when any `--tpu*` argument is present under OSS, when running on
    Python 3.8 on Linux, or when running on Windows; True otherwise.
  """
  # Note that flags may not be parsed at this point and simply importing the
  # flags module causes a variety of unusual errors.
  tpu_args = [arg for arg in sys.argv if arg.startswith('--tpu')]
  if tpu_args and is_oss():
    return False
  # `sys.version_info` is a five-field tuple (major, minor, micro,
  # releaselevel, serial), so comparing all of it against `(3, 8)` is never
  # true. Compare only the major/minor pair.
  if sys.version_info[:2] == (3, 8) and platform.system() == 'Linux':
    return False  # TODO(b/171242147)
  return sys.platform != 'win32'
45
46
class _AbslProcess:
  """Mixin that makes a process execute its `run` under `absl.app.run`."""

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Instance-level monkey-patch; pickle carries it over into the spawned
    # process. The original `run` is kept so the wrapper can delegate to it.
    original_run = self.run
    self._run_impl = original_run
    self.run = self._run_with_absl

  def _run_with_absl(self):
    """Runs the saved `run` implementation as the absl app's main function."""

    def _main(_):
      # absl passes the remaining argv; the wrapped `run` takes no arguments.
      return self._run_impl()

    app.run(_main)
58
59
# Pick the `Process` implementation once, at import time, based on whether
# multi-process tests are enabled for this environment.
if _is_enabled():

  class AbslForkServerProcess(_AbslProcess,
                              multiprocessing.context.ForkServerProcess):
    """An absl-compatible Forkserver process.

    Note: Forkserver is not available in windows.
    """

  class AbslForkServerContext(multiprocessing.context.ForkServerContext):
    """Forkserver context whose `Process` class is `AbslForkServerProcess`."""
    _name = 'absl_forkserver'
    Process = AbslForkServerProcess  # pylint: disable=invalid-name

  # NOTE: from here on the module-level name `multiprocessing` no longer refers
  # to the standard-library module but to this context instance. Later code in
  # this file that uses `multiprocessing.<attr>` goes through the context.
  multiprocessing = AbslForkServerContext()
  Process = multiprocessing.Process

else:

  # NOTE(review): `_is_enabled()` can also return False for reasons other than
  # Windows (e.g. `--tpu*` arguments under OSS); the skip message below still
  # mentions only Windows in that case.
  class Process(object):
    """A process that skips test (until windows is supported)."""

    def __init__(self, *args, **kwargs):
      # Arguments are intentionally discarded; constructing always skips.
      del args, kwargs
      raise unittest.SkipTest(
          'TODO(b/150264776): Windows is not supported in MultiProcessRunner.')
85
86
# Set to True by `test_main()`; reported by `initialized()`.
_test_main_called = False
88
89
def _set_spawn_exe_path():
  """Set the path to the executable for spawned processes.

  This utility searches for the binary the parent process is using, and sets
  the executable of multiprocessing's context accordingly. When `sys.argv[0]`
  ends with '.py', the binary path is guessed from the bazel output layout and
  the `TEST_TARGET` environment variable, and `sys.argv[0]` is overwritten
  with the guessed path.

  Raises:
    RuntimeError: If the binary path cannot be determined. This includes the
      case where `TEST_TARGET` is unset or empty.
  """
  # TODO(b/150264776): This does not work with Windows. Find a solution.
  if sys.argv[0].endswith('.py'):

    def guess_path(package_root):
      """Returns the guessed binary path under `package_root`, or None."""
      # If all we have is a python module path, we'll need to make a guess for
      # the actual executable path.
      script = sys.argv[0]
      if 'bazel-out' not in script or package_root not in script:
        return None
      # Without a test target there is nothing to guess from. Report "not
      # found" so the caller raises the documented RuntimeError rather than
      # leaking a KeyError (or, for an empty value, guessing the package
      # directory itself).
      test_target = os.environ.get('TEST_TARGET')
      if not test_target:
        return None
      # Guess the binary path under bazel. For target
      # //tensorflow/python/distribute:input_lib_test_multiworker_gpu, the
      # argv[0] is in the form of
      # /.../tensorflow/python/distribute/input_lib_test.py
      # and the binary is
      # /.../tensorflow/python/distribute/input_lib_test_multiworker_gpu
      package_root_base = script[:script.rfind(package_root)]
      # Drop the leading '//' and turn the first ':' into a path separator.
      binary = test_target[2:].replace(':', '/', 1)
      possible_path = os.path.join(package_root_base, package_root, binary)
      logging.info('Guessed test binary path: %s', possible_path)
      if os.access(possible_path, os.X_OK):
        return possible_path
      return None

    path = guess_path('org_tensorflow') or guess_path('org_keras')
    if path is None:
      logging.error(
          'Cannot determine binary path. sys.argv[0]=%s os.environ=%s',
          sys.argv[0], os.environ)
      raise RuntimeError('Cannot determine binary path')
    sys.argv[0] = path
  # Note that this sets the executable for *all* contexts.
  multiprocessing.get_context().set_executable(sys.argv[0])
130
131
def _if_spawn_run_and_exit():
  """If spawned process, run requested spawn task and exit. Else a no-op.

  The process counts as spawned when a `-c` argument after the program name is
  followed by a command starting with 'from multiprocessing.'. In that case
  `sys.argv` is truncated to the program name, the command is executed, and
  the process exits with status 0. Otherwise the function returns without
  touching `sys.argv`.
  """

  # `multiprocessing` module passes a script "from multiprocessing.x import y"
  # to subprocess, followed by a main function call. We use this to tell if
  # the process is spawned. Examples of x are "forkserver" or
  # "semaphore_tracker".
  args = sys.argv[1:]
  try:
    cmd = args[args.index('-c') + 1]
  except (ValueError, IndexError):
    # No `-c` at all, or `-c` is the last argument with no command after it.
    return
  if not cmd.startswith('from multiprocessing.'):
    return

  # As a subprocess, we disregard all other interpreter command line
  # arguments.
  sys.argv = sys.argv[0:1]

  # Run the specified command - this is expected to be one of:
  # 1. Spawn the process for semaphore tracker.
  # 2. Spawn the initial process for forkserver.
  # 3. Spawn any process as requested by the "spawn" method.
  # NOTE(review): this executes text taken from the command line; the prefix
  # check above is the only filter applied to it.
  exec(cmd)  # pylint: disable=exec-used
  sys.exit(0)  # Semaphore tracker doesn't explicitly sys.exit.
156
157
def test_main():
  """Entry point for test files; call it from the `__main__` guard.

  Marks the module as initialized, forces GPU memory growth through the
  `TF_FORCE_GPU_ALLOW_GROWTH` environment variable and, when multi-process
  tests are enabled, configures the spawn executable and serves the spawn
  request if this process is itself a spawned child. Finally hands control to
  `test.main()`.
  """
  global _test_main_called
  _test_main_called = True

  os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

  multi_process_supported = _is_enabled()
  if multi_process_supported:
    _set_spawn_exe_path()
    # Exits the process here when it was spawned by `multiprocessing`.
    _if_spawn_run_and_exit()

  # Only runs test.main() if not spawned process.
  test.main()
171
172
def initialized():
  """Returns whether `test_main()` has been called in this process."""
  return _test_main_called
176