# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library for multi-process testing."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import multiprocessing
import os
import platform
import sys
import unittest

from absl import app
from absl import logging

from tensorflow.python.eager import test


def is_oss():
  """Returns whether the test is run under OSS.

  This is a heuristic: it reports True when `sys.argv[0]` contains 'bazel'.
  """
  return len(sys.argv) >= 1 and 'bazel' in sys.argv[0]


def _is_enabled():
  """Returns whether multi-process testing is enabled in this environment.

  It is disabled when:
    * running under OSS (see `is_oss`) with any `--tpu*` argument,
    * running Python 3.8 on Linux (TODO(b/171242147)),
    * running on Windows (`sys.platform == 'win32'`).
  """
  # Note that flags may not be parsed at this point and simply importing the
  # flags module causes a variety of unusual errors, so inspect raw argv.
  has_tpu_args = any(arg.startswith('--tpu') for arg in sys.argv)
  if is_oss() and has_tpu_args:
    return False
  # `sys.version_info` has five fields (major, minor, micro, releaselevel,
  # serial); comparing the whole tuple to (3, 8) is never true, so compare
  # only (major, minor).
  if sys.version_info[:2] == (3, 8) and platform.system() == 'Linux':
    return False  # TODO(b/171242147)
  return sys.platform != 'win32'


class _AbslProcess:
  """A process that runs using absl.app.run."""

  def __init__(self, *args, **kwargs):
    super(_AbslProcess, self).__init__(*args, **kwargs)
    # Monkey-patch that is carried over into the spawned process by pickle:
    # keep the original `run` and replace it with a wrapper that invokes it
    # from inside `absl.app.run`.
    self._run_impl = getattr(self, 'run')
    self.run = self._run_with_absl

  def _run_with_absl(self):
    # `app.run` passes argv to its main function; it is not needed here.
    app.run(lambda _: self._run_impl())


if _is_enabled():

  class AbslForkServerProcess(_AbslProcess,
                              multiprocessing.context.ForkServerProcess):
    """An absl-compatible Forkserver process.

    Note: Forkserver is not available in windows.
    """

  class AbslForkServerContext(multiprocessing.context.ForkServerContext):
    _name = 'absl_forkserver'
    Process = AbslForkServerProcess  # pylint: disable=invalid-name

  # Intentionally rebinds the module-level name `multiprocessing` (shadowing
  # the stdlib module) to the absl forkserver context. `_set_spawn_exe_path`
  # below uses this binding; presumably importers of this library do too.
  multiprocessing = AbslForkServerContext()
  Process = multiprocessing.Process

else:

  class Process(object):
    """A process that skips the test when multi-process testing is disabled.

    `_is_enabled()` returns False on Windows, and also for OSS TPU runs and
    Python 3.8 on Linux; the skip message below only mentions Windows.
    """

    def __init__(self, *args, **kwargs):
      del args, kwargs
      raise unittest.SkipTest(
          'TODO(b/150264776): Windows is not supported in MultiProcessRunner.')


# Set to True by `test_main()`; reported by `initialized()`.
_test_main_called = False


def _set_spawn_exe_path():
  """Set the path to the executable for spawned processes.

  This utility searches for the binary the parent process is using, and sets
  the executable of multiprocessing's context accordingly. If `sys.argv[0]`
  is a `.py` file, the bazel test binary path is guessed from `sys.argv[0]`
  and the `TEST_TARGET` environment variable, and `sys.argv[0]` is replaced
  with it; otherwise `sys.argv[0]` is used as-is.

  Raises:
    RuntimeError: If the binary path cannot be determined.
  """
  # TODO(b/150264776): This does not work with Windows. Find a solution.
  if sys.argv[0].endswith('.py'):

    def guess_path(package_root):
      """Returns the guessed executable path under `package_root`, or None."""
      # If all we have is a python module path, we'll need to make a guess for
      # the actual executable path.
      if 'bazel-out' in sys.argv[0] and package_root in sys.argv[0]:
        # Guess the binary path under bazel. For target
        # //tensorflow/python/distribute:input_lib_test_multiworker_gpu, the
        # argv[0] is in the form of
        # /.../tensorflow/python/distribute/input_lib_test.py
        # and the binary is
        # /.../tensorflow/python/distribute/input_lib_test_multiworker_gpu
        test_target = os.environ.get('TEST_TARGET')
        if not test_target:
          # Without TEST_TARGET no guess is possible; let the caller raise
          # the documented RuntimeError instead of a bare KeyError.
          return None
        package_root_base = sys.argv[0][:sys.argv[0].rfind(package_root)]
        # Strip the leading '//' and turn the first ':' into a path separator.
        binary = test_target[2:].replace(':', '/', 1)
        possible_path = os.path.join(package_root_base, package_root, binary)
        logging.info('Guessed test binary path: %s', possible_path)
        if os.access(possible_path, os.X_OK):
          return possible_path
      return None

    path = guess_path('org_tensorflow')
    if not path:
      path = guess_path('org_keras')
    if path is None:
      logging.error(
          'Cannot determine binary path. sys.argv[0]=%s os.environ=%s',
          sys.argv[0], os.environ)
      raise RuntimeError('Cannot determine binary path')
    sys.argv[0] = path
  # Note that this sets the executable for *all* contexts.
  multiprocessing.get_context().set_executable(sys.argv[0])


def _if_spawn_run_and_exit():
  """If spawned process, run requested spawn task and exit. Else a no-op."""

  # `multiprocessing` module passes a script "from multiprocessing.x import y"
  # to subprocess, followed by a main function call. We use this to tell if
  # the process is spawned. Examples of x are "forkserver" or
  # "semaphore_tracker".
  try:
    # Look for '-c' only after argv[0], which is the executable itself.
    cmd_index = sys.argv.index('-c', 1) + 1
  except ValueError:
    return
  if cmd_index >= len(sys.argv):
    return  # A trailing '-c' with no command is not a multiprocessing spawn.
  cmd = sys.argv[cmd_index]
  if not cmd.startswith('from multiprocessing.'):
    return

  # As a subprocess, we disregard all other interpreter command line
  # arguments.
  sys.argv = sys.argv[0:1]

  # Run the specified command - this is expected to be one of:
  # 1. Spawn the process for semaphore tracker.
  # 2. Spawn the initial process for forkserver.
  # 3. Spawn any process as requested by the "spawn" method.
  # NOTE: `cmd` comes from this process's own command line and is only
  # executed when it starts with 'from multiprocessing.'.
  exec(cmd)  # pylint: disable=exec-used
  sys.exit(0)  # Semaphore tracker doesn't explicitly sys.exit.


def test_main():
  """Main function to be called within `__main__` of a test file.

  Marks the module as initialized and sets TF_FORCE_GPU_ALLOW_GROWTH=true.
  When multi-process testing is enabled, it configures the spawn executable
  and, if this process was spawned by multiprocessing, runs the requested
  command and exits. Otherwise it runs `test.main()`.
  """
  global _test_main_called
  _test_main_called = True

  os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

  if _is_enabled():
    _set_spawn_exe_path()
    _if_spawn_run_and_exit()

  # Only runs test.main() if not spawned process.
  test.main()


def initialized():
  """Returns whether the module is initialized (`test_main()` was called)."""
  return _test_main_called