1# Lint as: python3 2# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Multi-process runner for testing purpose.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import collections 23import contextlib 24import json 25import os 26import signal 27import sys 28import threading 29import time 30import unittest 31import weakref 32 33from absl import logging 34import six 35from six.moves import queue as Queue 36 37from tensorflow.python import tf2 38from tensorflow.python.compat import v2_compat 39from tensorflow.python.distribute import multi_process_lib 40from tensorflow.python.eager import context 41from tensorflow.python.util.tf_export import tf_export 42 43multiprocessing = multi_process_lib.multiprocessing 44 45# pylint: disable=g-import-not-at-top 46try: 47 # `faulthandler` is not available in py2. 48 import faulthandler 49except ImportError: 50 faulthandler = None 51 52# TODO(b/150264776): Remove after resolving CI issue. 53try: 54 import dill 55except ImportError: 56 dill = None 57 58# TODO(b/150264776): Remove after resolving CI issue. 59try: 60 import tblib.pickling_support 61 # For pickling traceback objects. 62 tblib.pickling_support.install() 63except ImportError: 64 pass 65 66 67# _ProcessStatusInfo contains process status information. When is_successful 68# attribute is True, the subprocess has ended successfully, or if False, the 69# exception stack trace info is stored in exc_info to pass on to parent process 70# to be re-raised. 71_ProcessStatusInfo = collections.namedtuple( 72 '_ProcessStatusInfo', 73 ['task_type', 'task_id', 'is_successful', 'exc_info', 'return_value']) 74 75# Information returned from a successful MultiProcessRunner run. 76MultiProcessRunnerResult = collections.namedtuple('MultiProcessRunnerResult', 77 ['return_value', 'stdout']) 78 79TestEnvironment = collections.namedtuple('TestEnvironment', [ 80 'task_type', 'task_id', 'cluster_spec', 'rpc_layer', 'grpc_fail_fast', 81 'v2_enabled', 'executing_eagerly' 82]) 83 84# Resources for communication between worker processes and the main process. 85# 86# `process_status_queue` is used by `multi_process_runner` internally for 87# communication from subprocesses to the parent process for whether it's been 88# successful, and if not what the error stack trace is. 89# `parent_to_sub_queue` is used for communications from parent to subprocess. 90# Currently this is only used to terminate subprocesses. 91# TODO(rchao): Remove this once subprocess is terminated by SIGKILL. 92# `streaming_pipe_w` is to stream stdout and stderr from subprocesses to parent 93# process. 94# `barrier` is a barrier for the party of all subprocesses. 95Resources = collections.namedtuple('Resources', [ 96 'process_status_queue', 'parent_to_sub_queue', 'streaming_pipe_w', 'barrier' 97]) 98 99# Default time out sec is selected so that it's handled before the default 100# "medium" timeout of the test runs. 101_DEFAULT_TIMEOUT_SEC = 200 102 103# The timeout in seconds to wait to force kill a child process. When a child 104# process times out we first try to SIGTERM it so that it has a chance to dump 105# stacktraces. However dumping stacktrace can take a long time. 106_FORCE_KILL_WAIT_SEC = 30 107 108 109class MultiProcessRunner(object): 110 """A utility class to start multiple processes to simulate a cluster. 111 112 We need to use multiple processes to simulate a cluster in TF 2.0 tests 113 because TF 2.0 has some process-global data structures that have to be 114 separated by processes. We also need child processes to test out our fault 115 tolerance because shutting down a standard TensorFlow server within its 116 process is not supported. 117 118 Note: the main test program that uses this runner class must run main program 119 via `test_main` defined in this file. Using this runner in non-test binaries 120 is not supported yet. 121 122 This class is not thread-safe. Child processes will inherit TF2 behavior flag. 123 """ 124 125 def __init__(self, 126 fn, 127 cluster_spec, 128 rpc_layer=None, 129 max_run_time=None, 130 grpc_fail_fast=None, 131 stream_output=True, 132 return_output=False, 133 use_dill_for_args=True, 134 daemon=False, 135 dependence_on_chief=True, 136 auto_restart=False, 137 args=None, 138 kwargs=None): 139 """Instantiation of a `MultiProcessRunner`. 140 141 Args: 142 fn: Function to be run on child processes. This will be run on processes 143 for all task types. 144 cluster_spec: Dict for cluster spec. The utility function 145 `tf.__internal__.distribute.multi_process_runner.create_cluster_spec` 146 can be conveniently used to create such dict. The following is an 147 example of cluster with three workers and two ps's. 148 {"worker": ["worker0.example.com:2222", 149 "worker1.example.com:2222", 150 "worker2.example.com:2222"], 151 "ps": ["ps0.example.com:2222", 152 "ps1.example.com:2222"]} 153 rpc_layer: RPC layer to use. Default value is 'grpc'. 154 max_run_time: `None` or integer. If not `None`, child processes are forced 155 to exit at approximately this many seconds after this utility is called. 156 We achieve this through `signal.alarm()` api. Note that this is best 157 effort at Python level since Python signal handler does not get executed 158 when it runs lower level C/C++ code. So it can be delayed for 159 arbitrarily long time. If any of the child process is still running when 160 `max_run_time` is up, they will be force-terminated and an 161 `UnexpectedSubprocessExitError` may be raised. If `None`, child 162 processes are not forced to exit. 163 grpc_fail_fast: Whether GRPC connection between processes should fail 164 without retrying. Defaults to None, in which case the environment 165 variable is not explicitly set. 166 stream_output: True if the output/error from the subprocesses should be 167 streamed to be printed in parent process' log. Defaults to True. 168 return_output: If True, the output/error from the subprocesses should be 169 collected to be attached to the resulting namedtuple returned from 170 `join()`. The list of output can be retrieved via `stdout` attribute. 171 Defaults to False. 172 use_dill_for_args: Whether to use dill to pickle `args` and `kwargs`. dill 173 can pickle more objects, but doesn't work with types in 174 `multiprocessing` library like `Mutex`. 175 daemon: Whether to start processes as daemons. 176 dependence_on_chief: Whether to terminates the cluster if the chief exits. 177 If auto_restart is True, it only terminates the cluster if the chief 178 exits with a zero exit code. 179 auto_restart: Whether to automatically restart processes that exit with 180 non-zero exit code. 181 args: Positional arguments to be sent to `fn` run on subprocesses. 182 kwargs: Keyword arguments to be sent to `fn` run on subprocesses. 183 184 Raises: 185 RuntimeError: if `multi_process_runner.test_main()` is not called. 186 ValueError: if there are more than one chief in the `cluster_spec`. 187 """ 188 189 assert cluster_spec is not None 190 if 'chief' in cluster_spec and len(cluster_spec['chief']) > 1: 191 raise ValueError('If chief exists in the cluster, there must be at most ' 192 'one chief. Current `cluster_spec` has {} chiefs.' 193 .format(len(cluster_spec['chief']))) 194 if not multi_process_lib.initialized(): 195 raise NotInitializedError( 196 '`multi_process_runner` is not initialized. ' 197 'Please call `tf.__internal__.distribute.multi_process_runner.' 198 'test_main()` within `if __name__ == \'__main__\':` block ' 199 'in your python module to properly initialize ' 200 '`multi_process_runner`.') 201 if not callable(fn): 202 raise ValueError('fn is not a callable') 203 204 self._fn = fn 205 self._cluster_spec = cluster_spec 206 self._rpc_layer = rpc_layer or 'grpc' 207 self._max_run_time = max_run_time 208 self._grpc_fail_fast = grpc_fail_fast 209 self._stream_output = stream_output 210 # TODO(rchao): Revisit return_output argument to consider other solution. 211 self._return_output = return_output 212 self._dependence_on_chief = dependence_on_chief 213 self._use_dill_for_args = use_dill_for_args 214 self._daemon = daemon 215 self._auto_restart = auto_restart 216 self._args = args or () 217 self._kwargs = kwargs or {} 218 219 # Child processes should have the same v2 and eager behavior. 220 self._v2_enabled = tf2.enabled() 221 self._executing_eagerly = context.executing_eagerly() 222 223 self._joined = False 224 self._process_lock = threading.Lock() 225 # Guarded by self._process_lock. 226 self._processes = {} 227 # Record which processes are terminated. Due to a bug in Python<3.7, 228 # terminated processes return 255 exit code, which should cause an exception 229 # in join(). 230 # https://bugs.python.org/issue30589 231 # Guarded by self._process_lock. 232 self._terminated = set() 233 self._reading_threads = [] 234 235 self._manager = manager() 236 self._process_status_queue = self._manager.Queue() 237 self._parent_to_sub_queue = self._manager.Queue() 238 parties = sum(len(addresses) for addresses in self._cluster_spec.values()) 239 self._barrier = self._manager.Barrier(parties) 240 241 # We use a queue to collect outputs from worker processes since it's thread 242 # safe. 243 self._streaming_queue = self._manager.Queue() 244 245 self._watchdog_thread = None 246 247 def set_args(self, args=None, kwargs=None): 248 self._args = args or self._args 249 self._kwargs = kwargs or self._kwargs 250 251 def _continuously_readline_from_sub(self, pipe_r, task_type, task_id): 252 """Function to continuously read lines from subprocesses.""" 253 with os.fdopen(pipe_r.fileno(), 'r', closefd=False) as reader: 254 for line in reader: 255 task_string = '[{}-{}]:'.format(task_type, task_id) 256 formatted_line = '{} {}'.format(task_string.ljust(14), line) 257 if self._stream_output: 258 # TODO(rchao): Use a lock here to ensure the printed lines are not 259 # broken. 260 print(formatted_line, end='', flush=True) 261 if self._return_output: 262 self._streaming_queue.put(formatted_line) 263 264 def _start_subprocess_and_reading_thread(self, 265 task_type, 266 task_id, 267 cluster_spec=None, 268 fn=None, 269 args=None, 270 kwargs=None): 271 """Start a subprocess and a thread the reads lines from the subprocess.""" 272 273 if dill is None: 274 raise unittest.SkipTest( 275 'TODO(b/150264776): Resolve dependency issue in CI') 276 277 test_env = TestEnvironment( 278 task_type=task_type, 279 task_id=task_id, 280 cluster_spec=cluster_spec or self._cluster_spec, 281 rpc_layer=self._rpc_layer, 282 grpc_fail_fast=self._grpc_fail_fast, 283 v2_enabled=self._v2_enabled, 284 executing_eagerly=self._executing_eagerly, 285 ) 286 pipe_r, pipe_w = multiprocessing.Pipe(duplex=False) 287 resources = Resources( 288 process_status_queue=self._process_status_queue, 289 parent_to_sub_queue=self._parent_to_sub_queue, 290 streaming_pipe_w=pipe_w, 291 barrier=self._barrier, 292 ) 293 if fn is None: 294 fn, args, kwargs = self._fn, self._args, self._kwargs 295 # Always use dill to pickle fn so that we support more callable 296 # types, e.g. lambda. 297 fn = dill.dumps(fn, dill.HIGHEST_PROTOCOL) 298 if self._use_dill_for_args: 299 args = dill.dumps(args, dill.HIGHEST_PROTOCOL) 300 kwargs = dill.dumps(kwargs, dill.HIGHEST_PROTOCOL) 301 302 p = _Process( 303 test_env=test_env, 304 target=_ProcFunc(), 305 args=(resources, test_env, fn, args, kwargs, self._use_dill_for_args), 306 daemon=self._daemon) 307 p.start() 308 self._processes[(task_type, task_id)] = p 309 self._terminated.discard((task_type, task_id)) 310 311 # For each subprocess, we dedicate a thread continuously reading lines 312 # from them. 313 thread = threading.Thread( # pylint: disable=unexpected-keyword-arg 314 target=self._continuously_readline_from_sub, 315 args=(pipe_r, task_type, task_id)) 316 thread.start() 317 self._reading_threads.append(thread) 318 319 if self._watchdog_thread is None or not self._watchdog_thread.is_alive(): 320 self._watchdog_thread = threading.Thread(target=self._process_watchdog) 321 self._watchdog_thread.start() 322 323 def start(self): 324 """Starts processes, one for each task in `cluster_spec`. 325 326 Note that this is best effort by the applicable multiprocessing library, 327 and it may take up to seconds for a subprocess to be successfully started. 328 """ 329 with self._process_lock: 330 if self._processes: 331 raise ValueError('MultiProcessRunner already started.') 332 if self._joined: 333 raise ValueError('cannot start new processes after' 334 'MultiProcessRunner.join() is called') 335 336 for task_type, addresses in self._cluster_spec.items(): 337 for task_id, _ in enumerate(addresses): 338 self._start_subprocess_and_reading_thread(task_type, task_id) 339 340 # TODO(rchao): Remove the need of using SIGALRM if possible. At this time, 341 # without this the tests become very flaky. 342 if self._max_run_time is not None: 343 344 def handler(signum, frame): 345 del signum, frame 346 self.terminate_all() 347 348 signal.signal(signal.SIGALRM, handler) 349 signal.alarm(self._max_run_time) 350 351 def start_in_process_as(self, as_task_type, as_task_id): 352 """Start the processes, with the specified task run in main process. 353 354 This is similar to `start()` except that the task with task_type 355 `as_task_type` and task_id `as_task_id` is run in the main process. 356 This method is particularly useful when debugging tool such as `pdb` is 357 needed in some specific task. Note that since this method is blocking until 358 that specific task exits, additional actions would need a thread to be 359 called: 360 361 ```python 362 def fn(): 363 # user code to be run 364 import pdb; pdb.set_trace() 365 366 def follow_ups(): 367 time.sleep(5) 368 mpr.start_single_process( 369 task_type='evaluator', 370 task_id=0) 371 372 mpr = multi_process_runner.MultiProcessRunner( 373 fn, 374 multi_worker_test_base.create_cluster_spec( 375 has_chief=True, num_workers=1)) 376 threading.Thread(target=follow_ups).start() 377 mpr.start_in_process_as(as_task_type='chief', as_task_id=0) 378 mpr.join() 379 ``` 380 381 Note that if `return_output=True`, the logs/stdout by task 382 run by the main process is not available in result.stdout. 383 384 Args: 385 as_task_type: The task type to be run in the main process. 386 as_task_id: The task id to be run in the main process. 387 """ 388 if self._processes: 389 raise ValueError('MultiProcessRunner already started.') 390 with self._process_lock: 391 if self._joined: 392 raise ValueError('cannot start new processes after' 393 'MultiProcessRunner.join() is called') 394 for task_type, addresses in self._cluster_spec.items(): 395 for task_id, _ in enumerate(addresses): 396 if not (task_type == as_task_type and task_id == as_task_id): 397 self._start_subprocess_and_reading_thread(task_type, task_id) 398 399 _set_tf_config(as_task_type, as_task_id, self._cluster_spec, 400 self._rpc_layer) 401 self._fn(*self._args, **self._kwargs) 402 403 def start_single_process(self, 404 task_type, 405 task_id, 406 cluster_spec=None, 407 fn=None, 408 args=None, 409 kwargs=None): 410 """Starts a single process. 411 412 This starts a process in the cluster with the task type, task id, and the 413 process function (`fn`). If process function is `None`, the function 414 provided at `__init__` will be used. If `cluster_spec` is `None`, the 415 cluster spec provided at `__init__` will be used. 416 417 TODO(rchao): It is meant that all subprocesses will be updated with the new 418 cluster spec, but this has yet to be implemented. At this time only the 419 newly started subprocess picks up this updated cluster spec. 420 421 Args: 422 task_type: The task type. 423 task_id: The task id. 424 cluster_spec: The cluster spec to be used on the newly started 425 process. If `None`, the cluster spec provided at `__init__` will be 426 used. 427 fn: The process function to be run on the newly started 428 process. If specified, specify `args` and `kwargs` as well. If `None`, 429 the function provided at `__init__` will be used. 430 args: Optional positional arguments to be supplied in `fn`. 431 kwargs: Optional keyword arguments to be supplied in `fn`. 432 """ 433 with self._process_lock: 434 if self._joined: 435 raise ValueError('cannot start new processes after' 436 'MultiProcessRunner.join() is called') 437 self._start_subprocess_and_reading_thread( 438 task_type, 439 task_id, 440 cluster_spec=cluster_spec, 441 fn=fn, 442 args=args or (), 443 kwargs=kwargs or {}) 444 445 def _queue_to_list(self, queue_to_convert): 446 """Convert `queue.Queue` to `list`.""" 447 list_to_return = [] 448 # Calling `queue.empty()` is not reliable. 449 while True: 450 try: 451 list_to_return.append(queue_to_convert.get(block=False)) 452 except Queue.Empty: 453 break 454 return list_to_return 455 456 def _get_process_statuses(self): 457 # One worker may have multiple statuses. We only keep the last one. 458 statuses = {} 459 for status in self._queue_to_list(self._process_status_queue): 460 statuses[(status.task_type, status.task_id)] = status 461 return statuses 462 463 def get_process_id(self, task_type, task_id): 464 """Returns the subprocess id given the task type and task id.""" 465 with self._process_lock: 466 p = self._processes.get((task_type, task_id), None) 467 return p.pid if p else None 468 469 def get_process_exit_code(self, task_type, task_id): 470 """Returns the subprocess exit code given the task type and task id. 471 472 Args: 473 task_type: The task type. 474 task_id: The task id. 475 476 Returns: 477 The subprocess exit code; `None` if the subprocess has not exited yet. 478 479 Raises: 480 KeyError: If the corresponding subprocess is not found with `task_type` 481 and `task_id`. 482 """ 483 with self._process_lock: 484 p = self._processes[(task_type, task_id)] 485 return p.exitcode if p else None 486 487 def process_exists(self, task_type, task_id): 488 """Returns whether the subprocess still exists given the task type and id. 489 490 Args: 491 task_type: The task type. 492 task_id: The task id. 493 494 Returns: 495 Boolean; whether the subprocess still exists. If the subprocess has 496 exited, this returns False. 497 """ 498 return self.get_process_exit_code(task_type, task_id) is None 499 500 def _process_watchdog(self): 501 """Simulates a cluster management system. 502 503 - If auto_restart is True, it restarts processes that exit with a non-zero 504 exit code. Note that when join() times out it overrides auto_restart to 505 False. 506 - If dependence_on_chief is True, it terminates all processes once the chief 507 exits. If auto_restart is also True, it only terminates all processes if 508 the chief exit with a zero exit code, otherwise it restarts the chief. 509 510 This runs in self._watchdog_thread. 511 """ 512 while True: 513 time.sleep(1) 514 with self._process_lock: 515 chief = self._processes.get(('chief', 0), None) 516 # Terminate the cluster when _dependence_on_chief is True if either: 517 # - chief has exited with zero exit code. 518 # - chief has exited with non-zero exit code and self._auto_restart is 519 # False. 520 if chief and self._dependence_on_chief and chief.exitcode is not None: 521 if chief.exitcode == 0 or (not self._auto_restart): 522 for p in self._processes.values(): 523 # Give other processes a chance to exit on their own. 524 p.join(timeout=3) 525 self._terminate_all() 526 for p in self._processes.values(): 527 p.join() 528 return 529 530 # Auto restart failed processes if self._auto_restart is True. 531 if self._auto_restart: 532 has_failure = False 533 for (task_type, task_id), p in self._processes.items(): 534 if p.exitcode is not None and p.exitcode != 0: 535 has_failure = True 536 logging.info('Restarting failed %s-%d', task_type, task_id) 537 self._start_subprocess_and_reading_thread(task_type, task_id) 538 if has_failure: 539 continue 540 541 # Exit the thread if all processes have exited at this point. 542 if all(p.exitcode is not None for p in self._processes.values()): 543 return 544 545 def _reraise_if_subprocess_error(self, process_statuses): 546 for process_status in process_statuses.values(): 547 assert isinstance(process_status, _ProcessStatusInfo) 548 if not process_status.is_successful: 549 process_status.exc_info[1].mpr_result = self._get_mpr_result( 550 process_statuses) 551 six.reraise(*process_status.exc_info) 552 553 def join(self, timeout=_DEFAULT_TIMEOUT_SEC): 554 """Joins all the processes with timeout. 555 556 If any of the subprocesses does not exit approximately after `timeout` 557 seconds has passed after `join` call, this raises a 558 `SubprocessTimeoutError`. 559 560 Note: At timeout, it uses SIGTERM to terminate the subprocesses, in order to 561 log the stack traces of the subprocesses when they exit. However, this 562 results in timeout when the test runs with tsan (thread sanitizer); if tsan 563 is being run on the test targets that rely on timeout to assert information, 564 `MultiProcessRunner.terminate_all()` must be called after `join()`, before 565 the test exits, so the subprocesses are terminated with SIGKILL, and data 566 race is removed. 567 568 Args: 569 timeout: optional integer or `None`. If provided as an integer, and not 570 all processes report status within roughly `timeout` seconds, a 571 `SubprocessTimeoutError` exception will be raised. If `None`, `join` never 572 times out. 573 574 Returns: 575 A `MultiProcessRunnerResult` object, which has two attributes, 576 `return_value` and `stdout`. `return_value` always contains a list of 577 return values from the subprocesses, although the order is not meaningful. 578 If `return_output` argument is True at `__init__`, `stdout` is available 579 that contains a list of all messages from subprocesses' stdout and stderr. 580 581 Raises: 582 SubprocessTimeoutError: if not all processes report status approximately 583 within `timeout` seconds. When this is raised, a 584 `MultiProcessRunnerResult` object can be retrieved by 585 `SubprocessTimeoutError`'s mpr_result attribute, which has the same 586 structure as above 'Returns' section describes. 587 UnexpectedSubprocessExitError: If any of the subprocesses did not exit 588 properly (for example, they exit on SIGTERM or SIGKILL signal). When 589 this is raised, a `MultiProcessRunnerResult` object can be retrieved by 590 `UnexpectedSubprocessExitError`'s mpr_result attribute, which has the 591 same structure as above 'Returns' section describes. If `max_run_time` 592 is not `None`, it is expected that some subprocesses may be 593 force-killed when `max_run_time` is up, and this is raised in those 594 cases. 595 Exception: if there is an Exception propagated from any subprocess. When 596 this is raised, a `MultiProcessRunnerResult` object can be retrieved by 597 `UnexpectedSubprocessExitError`'s mpr_result attribute, which has the 598 same structure as above 'Returns' section describes. 599 """ 600 if timeout and not isinstance(timeout, int): 601 raise ValueError('`timeout` must be an integer or `None`.') 602 with self._process_lock: 603 if self._joined: 604 raise ValueError("MultiProcessRunner can't be joined twice.") 605 self._joined = True 606 607 self._watchdog_thread.join(timeout) 608 if self._watchdog_thread.is_alive(): 609 # Timeout. Force termination to dump worker processes stack trace. 610 with self._process_lock: 611 self._auto_restart = False 612 logging.error('Timeout when joining for child processes. Terminating...') 613 self.terminate_all(sig=signal.SIGTERM) 614 # Wait for the processes to terminate by themselves first, so they have a 615 # chance to dump stacktraces. After _FORCE_KILL_WAIT_SEC, we SIGKILL them. 616 self._watchdog_thread.join(_FORCE_KILL_WAIT_SEC) 617 if self._watchdog_thread.is_alive(): 618 logging.error('Timeout when waiting for child processes to ' 619 'print stacktrace. Sending SIGKILL...') 620 self.terminate_all() 621 self._watchdog_thread.join() 622 process_statuses = self._get_process_statuses() 623 self._reraise_if_subprocess_error(process_statuses) 624 raise SubprocessTimeoutError( 625 'One or more subprocesses timed out, where timeout was set to {}s. ' 626 'Please change the `timeout` argument for ' 627 '`MultiProcessRunner.join()` or `multi_process_runner.run()` ' 628 'if it should be adjusted.'.format(timeout), 629 self._get_mpr_result(process_statuses)) 630 631 for (task_type, task_id), p in self._processes.items(): 632 logging.info('%s-%d exit code: %s', task_type, task_id, p.exitcode) 633 634 process_statuses = self._get_process_statuses() 635 self._reraise_if_subprocess_error(process_statuses) 636 637 # Checking all the processes that are expected to exit properly. 638 for (task_type, task_id), p in self._processes.items(): 639 # Successfully exiting process has exit code 0. We ignore processes that 640 # are terminated. 641 assert p.exitcode is not None 642 if (p.exitcode > 0 and (task_type, task_id) not in self._terminated): 643 raise UnexpectedSubprocessExitError( 644 'Subprocess %s-%d exited with exit code %s. See logs for details.' 645 % (task_type, task_id, p.exitcode), 646 self._get_mpr_result(process_statuses)) 647 648 logging.info('Joining log reading threads.') 649 for thread in self._reading_threads: 650 thread.join() 651 logging.info('Joined log reading threads.') 652 653 # Clear the alarm. 654 signal.alarm(0) 655 656 return self._get_mpr_result(process_statuses) 657 658 def _get_mpr_result(self, process_statuses): 659 stdout = self._queue_to_list(self._streaming_queue) 660 return_values = [] 661 for process_status in process_statuses.values(): 662 if process_status.return_value is not None: 663 return_values.append(process_status.return_value) 664 return MultiProcessRunnerResult(stdout=stdout, return_value=return_values) 665 666 def terminate(self, task_type, task_id): 667 """Terminates the process with `task_type` and `task_id`. 668 669 If auto_retart=True, the terminated task will be restarted unless the chief 670 has already exited with zero exit code. 671 672 Args: 673 task_type: the task type. 674 task_id: the task id. 675 676 """ 677 with self._process_lock: 678 p = self._processes.get((task_type, task_id), None) 679 if p is None: 680 raise ValueError('{}-{} does not exist'.format(task_type, task_id)) 681 self._terminated.add((task_type, task_id)) 682 # TODO(crccw): change to use Process.terminate() as well. 683 self._parent_to_sub_queue.put('terminate {} {}'.format( 684 task_type, task_id)) 685 p.join() 686 687 def _terminate_all(self, sig=None): 688 """Terminates all subprocesses. 689 690 The caller is required to hold self._process_lock. 691 692 Args: 693 sig: the signal used to terminate the process. The default is SIGKILL. 694 """ 695 696 # Use SIGKILL as default. In systems where that's unavailable such as 697 # windows, use SIGTERM. 698 sig = sig or getattr(signal, 'SIGKILL', signal.SIGTERM) 699 for (task_type, task_id), p in self._processes.items(): 700 if p.exitcode is not None: 701 logging.info('%s-%d has already exited. Not terminating.', task_type, 702 task_id) 703 continue 704 try: 705 os.kill(p.pid, sig) 706 self._terminated.add((task_type, task_id)) 707 logging.info('%s-%d terminated with signal %r.', task_type, task_id, 708 sig) 709 except ProcessLookupError: 710 logging.info('Attempting to kill %s-%d but it does not exist.', 711 task_type, task_id) 712 713 def terminate_all(self, sig=None): 714 """Terminates all subprocesses.""" 715 with self._process_lock: 716 self._terminate_all(sig) 717 718 719class _Process(multi_process_lib.Process): 720 """A modified `multiprocessing.Process` that can set up environment variables.""" 721 722 # TODO(crccw): consider moving other logics in _ProcFunc to _Process. 723 724 def __init__(self, test_env, **kwargs): 725 super(_Process, self).__init__(**kwargs) 726 self._test_env = test_env 727 self._actual_run = getattr(self, 'run') 728 self.run = self._run_with_setenv 729 730 def _run_with_setenv(self): 731 # We need to set environment variables before doing anything because 732 # setenv() is not thread-safe. 733 test_env = self._test_env 734 if test_env.grpc_fail_fast is not None: 735 os.environ['GRPC_FAIL_FAST'] = str(test_env.grpc_fail_fast) 736 _set_tf_config(test_env.task_type, test_env.task_id, test_env.cluster_spec, 737 test_env.rpc_layer) 738 return self._actual_run() 739 740 741class _ProcFunc(object): 742 """Represents a callable to run in a subprocess.""" 743 744 @contextlib.contextmanager 745 def _runtime_mode(self, executing_eagerly): 746 if executing_eagerly: 747 with context.eager_mode(): 748 yield 749 else: 750 with context.graph_mode(): 751 yield 752 753 def _message_checking_func(self, task_type, task_id): 754 """A function that regularly checks messages from parent process.""" 755 # TODO(rchao): Remove this once parent uses SIGKILL to terminate subprocess. 756 while True: 757 try: 758 message = self._resources.parent_to_sub_queue.get(block=False) 759 760 # Currently the only possible message is termination. 761 if not message.startswith('terminate'): 762 raise ValueError('Unrecognized message: {}'.format(message)) 763 764 if message == 'terminate {} {}'.format(task_type, task_id): 765 break 766 else: 767 # If the message is not targeting this process, put it back to the 768 # queue. 769 self._resources.parent_to_sub_queue.put(message) 770 time.sleep(1) 771 except Queue.Empty: 772 time.sleep(0.1) 773 self._resources.process_status_queue.put( 774 _ProcessStatusInfo( 775 task_type=task_type, 776 task_id=task_id, 777 is_successful=True, 778 exc_info=None, 779 return_value=None)) 780 # `os._exit(1)` is used to more reliably terminate a subprocess. 781 os._exit(1) # pylint: disable=protected-access 782 783 def _close_streaming(self): 784 """Close stdout, stderr and streaming pipe. 785 786 We need to explicitly close them since Tensorflow may take a while to exit, 787 so that the reading threads in the main process can exit more quickly. 788 """ 789 sys.stdout.flush() 790 sys.stderr.flush() 791 sys.stdout.close() 792 sys.stderr.close() 793 self._resources.streaming_pipe_w.close() 794 795 def __call__(self, resources, test_env, fn, args, kwargs, use_dill_for_args): 796 """The wrapper function that actually gets run in child process(es).""" 797 798 global _barrier 799 800 self._resources = resources 801 _barrier = self._resources.barrier 802 fn = dill.loads(fn) 803 if use_dill_for_args: 804 args = dill.loads(args) 805 kwargs = dill.loads(kwargs) 806 807 if faulthandler is not None: 808 faulthandler.enable() 809 faulthandler.register(signal.SIGTERM, chain=True) 810 811 # All logging should go to stderr to be streamed to the main process. 812 logging.set_stderrthreshold(logging.DEBUG) 813 814 # Assign sys.stdout and sys.stderr as duplicates of `streaming_pipe_w` so 815 # print() and logging.*() write directly to `streaming_pipe_w`. 816 # Unfortunately since we cannot prepend task_type and task_id information to 817 # the streamed logs we will need a thread per subprocess to distinguish 818 # where the piece of message is from. 819 os.dup2(resources.streaming_pipe_w.fileno(), sys.stdout.fileno()) 820 os.dup2(resources.streaming_pipe_w.fileno(), sys.stderr.fileno()) 821 822 pid = os.getpid() 823 logging.info('Subprocess with PID %d (%s, %d) is now being started.', pid, 824 test_env.task_type, test_env.task_id) 825 826 # The thread will be dedicated to checking messages from the parent process. 827 threading.Thread( # pylint: disable=unexpected-keyword-arg 828 target=self._message_checking_func, 829 args=(test_env.task_type, test_env.task_id), 830 daemon=True).start() 831 832 if test_env.v2_enabled: 833 v2_compat.enable_v2_behavior() 834 835 with self._runtime_mode(test_env.executing_eagerly): 836 info = _run_contained(test_env.task_type, test_env.task_id, fn, args, 837 kwargs) 838 self._resources.process_status_queue.put(info) 839 840 # Re-raise the exception in addition to reporting it to the parent 841 # process, so that even if `--test_timeout` flag is set and the 842 # error doesn't make it to be shown in parent process before bazel's 843 # timeout, the log would still show what happens in this subprocess, 844 # instead of silently suppressing the error due to early bazel 845 # timeout. Raising an error in the subprocess produces stack trace in 846 # the log, but the program continues running. 847 if not info.is_successful: 848 six.reraise(*info.exc_info) 849 850 self._close_streaming() 851 852 # Exit with code 0 as it's considered successful exit at this point. 853 sys.exit(0) 854 855 856# Active MultiProcessPoolRunner. We need to shut them down when the program 857# exits, and this is by setting the `tearDownModule` of the module containing 858# `__main__`. Note this it set in both the parent process and the subprocesses. 859_active_pool_runners = weakref.WeakSet() 860 861 862def _shutdown_all_pool_runners(): 863 for pool in _active_pool_runners: 864 pool.shutdown() 865 866 867def is_oss(): 868 """Returns whether the test is run under OSS.""" 869 return len(sys.argv) >= 1 and 'bazel' in sys.argv[0] 870 871 872class MultiProcessPoolRunner(object): 873 """A utility class to start a process pool to simulate a cluster. 874 875 It's similar to MultiProcessRunner, but uses a pool of processes to avoid the 876 expensive initialization cost of Tensorflow. 877 """ 878 879 def __init__(self, cluster_spec, initializer=None): 880 """Creates a multi-process pool runner. 881 882 Args: 883 cluster_spec: Dict for cluster spec. The following is an example of 884 cluster with three workers. 885 {"worker": ["worker0.example.com:2222", 886 "worker1.example.com:2222", 887 "worker2.example.com:2222"]} 888 initializer: a callable to called at the startup of worker processes. 889 890 Raises: 891 RuntimeError: if `multi_process_runner.test_main()` is not called. 892 ValueError: if there are more than one chief in the `cluster_spec`. 893 """ 894 _active_pool_runners.add(self) 895 self._cluster_spec = cluster_spec 896 self._initializer = initializer 897 self._conn = {} 898 self._runner = None 899 900 def __del__(self): 901 self.shutdown() 902 903 def shutdown(self): 904 """Shuts down the worker pool.""" 905 for conn in self._conn.values(): 906 conn.close() 907 self._conn = {} 908 if self._runner is not None: 909 try: 910 self._runner.join() 911 except Exception as e: # pylint: disable=broad-except 912 logging.error( 913 'Ignoring exception when shutting down MultiProcessPoolRunner: %s', 914 e) 915 self._runner = None 916 917 def _start(self): 918 """Starts the worker pool.""" 919 # We need different arguments for different processes so we're passing a 920 # no-op fn here and use start_single_process instead. 921 922 if dill is None: 923 raise unittest.SkipTest( 924 'TODO(b/150264776): Resolve dependency issue in CI') 925 926 self._runner = MultiProcessRunner( 927 fn=lambda: None, 928 cluster_spec=self._cluster_spec, 929 use_dill_for_args=False) 930 if self._initializer: 931 initializer = dill.dumps(self._initializer, dill.HIGHEST_PROTOCOL) 932 else: 933 initializer = None 934 for task_type, addresses in self._cluster_spec.items(): 935 for task_id, _ in enumerate(addresses): 936 conn1, conn2 = multiprocessing.Pipe(duplex=True) 937 self._conn[(task_type, task_id)] = conn1 938 self._runner.start_single_process( 939 task_type, 940 task_id, 941 fn=_pool_runner_worker, 942 args=(task_type, task_id, initializer, conn2)) 943 944 def run(self, fn, args=None, kwargs=None): 945 """Runs `fn` with `args` and `kwargs` on all jobs. 946 947 Args: 948 fn: The function to be run. 949 args: Optional positional arguments to be supplied in `fn`. 950 kwargs: Optional keyword arguments to be supplied in `fn`. 951 952 Returns: 953 A list of return values. 954 """ 955 # TODO(b/150264776): skip in OSS until it's implemented. 956 multi_process_lib.Process() 957 if self._runner is None: 958 self._start() 959 960 fn = dill.dumps(fn, dill.HIGHEST_PROTOCOL) 961 for conn in self._conn.values(): 962 conn.send((fn, args or [], kwargs or {})) 963 964 process_statuses = [] 965 for (task_type, task_id), conn in self._conn.items(): 966 logging.info('Waiting for the result from %s-%d', task_type, task_id) 967 try: 968 process_statuses.append(conn.recv()) 969 except EOFError: 970 # This shouldn't happen due to exceptions in fn. This usually 971 # means bugs in the runner. 972 self.shutdown() 973 raise RuntimeError('Unexpected EOF. Worker process may have died. ' 974 'Please report a bug') 975 976 return_values = [] 977 for process_status in process_statuses: 978 assert isinstance(process_status, _ProcessStatusInfo) 979 if not process_status.is_successful: 980 six.reraise(*process_status.exc_info) 981 if process_status.return_value is not None: 982 return_values.append(process_status.return_value) 983 984 return return_values 985 986 987def _pool_runner_worker(task_type, task_id, initializer, conn): 988 """Function that runs on the workers in a pool. 989 990 It listens for callables to run and returns the result until `conn` is closed. 991 It captures the exceptions during executing the callable and return it through 992 `conn`. 993 994 Args: 995 task_type: the task type. 996 task_id: the task index. 997 initializer: a callable to execute during startup. 998 conn: a multiprocessing.Connection object to listen for tasks and send 999 results. 1000 """ 1001 if initializer: 1002 initializer = dill.loads(initializer) 1003 initializer() 1004 while True: 1005 try: 1006 fn, args, kwargs = conn.recv() 1007 except EOFError: 1008 break 1009 fn = dill.loads(fn) 1010 info = _run_contained(task_type, task_id, fn, args, kwargs) 1011 sys.stdout.flush() 1012 sys.stderr.flush() 1013 conn.send(info) 1014 1015 1016def _run_contained(task_type, task_id, fn, args, kwargs): 1017 """Runs `fn` with `args` and `kwargs`. 1018 1019 The function returns _ProcessStatusInfo which captures the return value and 1020 the exception. 1021 1022 Args: 1023 task_type: the task type. 1024 task_id: the task index. 1025 fn: the function to be run. 1026 args: optional positional arguments to be supplied in `fn`. 1027 kwargs: optional keyword arguments to be supplied in `fn`. 1028 1029 Returns: 1030 a _ProcessStatusInfo. 1031 1032 """ 1033 is_successful = False 1034 return_value = None 1035 exc_info = None 1036 try: 1037 return_value = fn(*args, **kwargs) 1038 is_successful = True 1039 return _ProcessStatusInfo( 1040 task_type=task_type, 1041 task_id=task_id, 1042 is_successful=is_successful, 1043 exc_info=exc_info, 1044 return_value=return_value) 1045 1046 # If `fn` ends up exiting with `sys.exit()`, the `SystemExit` is not 1047 # handled here. 1048 except Exception: # pylint: disable=broad-except 1049 exc_info = sys.exc_info() 1050 return _ProcessStatusInfo( 1051 task_type=task_type, 1052 task_id=task_id, 1053 is_successful=is_successful, 1054 exc_info=exc_info, 1055 return_value=return_value) 1056 1057 1058@tf_export('__internal__.distribute.multi_process_runner' 1059 '.SubprocessTimeoutError', 1060 v1=[]) 1061class SubprocessTimeoutError(RuntimeError): 1062 """An error that indicates there is at least one subprocess timing out. 1063 1064 When this is raised, a namedtuple object representing the multi-process run 1065 result can be retrieved by 1066 `tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError`'s 1067 `mpr_result` attribute. See 1068 `tf.__internal__.distribute.multi_process_runner.run` for more information. 1069 """ 1070 1071 def __init__(self, msg, mpr_result): 1072 super(SubprocessTimeoutError, self).__init__(msg) 1073 self.mpr_result = mpr_result 1074 1075 1076@tf_export('__internal__.distribute.multi_process_runner' 1077 '.UnexpectedSubprocessExitError', 1078 v1=[]) 1079class UnexpectedSubprocessExitError(RuntimeError): 1080 """An error indicating there is at least one subprocess with unexpected exit. 1081 1082 When this is raised, a namedtuple object representing the multi-process run 1083 result can be retrieved by 1084 `tf.__internal__.distribute.multi_process_runner 1085 .UnexpectedSubprocessExitError`'s 1086 `mpr_result` attribute. See 1087 `tf.__internal__.distribute.multi_process_runner.run` for more information. 1088 """ 1089 1090 def __init__(self, msg, mpr_result): 1091 super(UnexpectedSubprocessExitError, self).__init__(msg) 1092 self.mpr_result = mpr_result 1093 1094 1095@tf_export( 1096 '__internal__.distribute.multi_process_runner.NotInitializedError', v1=[]) 1097class NotInitializedError(RuntimeError): 1098 """An error indicating `multi_process_runner.run` is used without init. 1099 1100 When this is raised, user is supposed to call 1101 `tf.__internal__.distribute.multi_process_runner.test_main()` within 1102 `if __name__ == '__main__':` block to properly initialize 1103 `multi_process_runner.run`. 1104 """ 1105 pass 1106 1107 1108def _set_tf_config(task_type, task_id, cluster_spec, rpc_layer=None): 1109 """Set TF_CONFIG environment variable.""" 1110 tf_config_dict = { 1111 'cluster': cluster_spec, 1112 'task': { 1113 'type': task_type, 1114 'index': task_id, 1115 }, 1116 } 1117 if rpc_layer is not None: 1118 tf_config_dict['rpc_layer'] = rpc_layer 1119 os.environ['TF_CONFIG'] = json.dumps(tf_config_dict) 1120 1121 1122@tf_export('__internal__.distribute.multi_process_runner.run', v1=[]) 1123def run(fn, 1124 cluster_spec, 1125 rpc_layer=None, 1126 max_run_time=None, 1127 return_output=False, 1128 timeout=_DEFAULT_TIMEOUT_SEC, 1129 args=None, 1130 kwargs=None): 1131 """Run `fn` in multiple processes according to `cluster_spec`. 1132 1133 Given a callable `fn`, `tf.__internal__.distribute.multi_process_runner.run` 1134 launches multiple processes, each of which runs `fn`. These processes are 1135 referred to as "subprocesses" or "child processes". Each of those subprocesses 1136 will have their `TF_CONFIG` environment variable set, according to 1137 `cluster_spec` and their task types. The stdout of the subprocesses are 1138 streamed to the main process' and thus available in logs (if `stream_output` 1139 is True), with [type-id] prefix. 1140 1141 `tf.__internal__.distribute.multi_process_runner.run` will block until all 1142 subprocesses have successfully exited, and return a namedtuple object that 1143 represents the run result. This object has a `return_value` attribute, which 1144 is a list that contains subprocesses `fn`'s return values, for those 1145 subprocesses that successfully returned from `fn`. The order of `return_value` 1146 list is not meaningful. If an optional arg `return_output` (default to False) 1147 is set to True, the namedtuple object will have an additional attribute 1148 `stdout`, which is a list containing the stdout of the subprocesses. If any 1149 subprocess' `fn` ends up raising an error, that error will be reraised from 1150 `tf.__internal__.distribute.multi_process_runner.run`, and the aforementioned 1151 namedtuple object will be available through the exception's 1152 `mpr_result` attribute. 1153 1154 This utility is used for simulating running TensorFlow programs across 1155 multiple task types, and each of the task type may contain more than one task 1156 (except for "chief" where more than one task is prohibited). Test coverage of 1157 multi-worker training is the main application of this utility, where code 1158 written for multi-worker training can be realistically covered in unit tests. 1159 1160 Any test module that uses 1161 `tf.__internal__.distribute.multi_process_runner.run()` must call 1162 `tf.__internal__.distribute.multi_process_runner.test_main()` instead of 1163 regular `test.main()` inside `if __name__ == '__main__':` block for proper 1164 initialization. 1165 1166 Args: 1167 fn: Function to be run on child processes. This will be run on processes for 1168 all task types. 1169 cluster_spec: Dict for cluster spec. The utility function 1170 `tf.__internal__.distribute.multi_process_runner.create_cluster_spec` can 1171 be conveniently used to create such dict. The following is an example of 1172 cluster with three workers and two ps's. 1173 {"worker": ["worker0.example.com:2222", 1174 "worker1.example.com:2222", 1175 "worker2.example.com:2222"], 1176 "ps": ["ps0.example.com:2222", 1177 "ps1.example.com:2222"]} 1178 rpc_layer: RPC layer to use. Default value is 'grpc'. 1179 max_run_time: `None` or integer. If not `None`, child processes are forced 1180 to exit at approximately this many seconds after this utility is called. 1181 We achieve this through `signal.alarm()` api. Note that this is best 1182 effort at Python level since Python signal handler does not get executed 1183 when it runs lower level C/C++ code. So it can be delayed for arbitrarily 1184 long time. If any of the child process is still running when 1185 `max_run_time` is up, they will be force-terminated and an 1186 `tf.__internal__.distribute.multi_process_runner 1187 .UnexpectedSubprocessExitError` 1188 may be raised. If `None`, child processes are not forced to exit. 1189 return_output: If True, the output/error from the subprocesses should be 1190 collected to be attached to the resulting namedtuple returned from this 1191 utility. The list of output can be retrieved via `stdout` attribute. 1192 Defaults to False. 1193 timeout: optional integer or `None`. If provided as an integer, and not all 1194 processes report status within roughly `timeout` seconds, a 1195 `tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError` 1196 exception will be raised. If `None`, 1197 `tf.__internal__.distribute.multi_process_runner.run` never times out. 1198 Defaults to the constant `_DEFAULT_TIMEOUT_SEC` defined in 1199 `multi_process_runner` module. 1200 args: Positional arguments to be sent to `fn` run on subprocesses. 1201 kwargs: Keyword arguments to be sent to `fn` run on subprocesses. 1202 1203 Returns: 1204 A namedtuple object, which has two attributes, 1205 `return_value` and `stdout`. `return_value` always contains a list of 1206 returnvalues from the subprocesses, although the order is not meaningful. 1207 If `return_output` argument is True, `stdout` is available that contains a 1208 list of all messages from subprocesses' stdout and stderr, and the order 1209 is mostly chronological. 1210 1211 Raises: 1212 RuntimeError: if 1213 `tf.__internal__.distribute.multi_process_runner.test_main()` is 1214 not called in test's `if __name__ == '__main__':` block. 1215 ValueError: if there are more than one chief in the `cluster_spec`. 1216 tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError: if 1217 not all processes report status approximately 1218 within `timeout` seconds. When this is raised, a 1219 namedtuple object can be retrieved by 1220 `tf.__internal__.distribute.multi_process_runner.SubprocessTimeoutError`'s 1221 `mpr_result` attribute, which has the same 1222 structure as above 'Returns' section describes. 1223 tf.__internal__.distribute.multi_process_runner 1224 .UnexpectedSubprocessExitError: 1225 If any of the subprocesses did not exit 1226 properly (for example, they exit on SIGTERM or SIGKILL signal). When 1227 this is raised, a namedtuple object can be retrieved by 1228 `tf.__internal__.distribute.multi_process_runner 1229 .UnexpectedSubprocessExitError`'s 1230 `mpr_result` attribute, which has the 1231 same structure as above 'Returns' section describes. If `max_run_time` 1232 is not `None`, it is expected that some subprocesses may be 1233 force-killed when `max_run_time` is up, and this is raised in those 1234 cases. 1235 Exception: if there is an Exception propagated from any subprocess. When 1236 this is raised, a namedtuple object can be retrieved by 1237 `tf.__internal__.distribute.multi_process_runner 1238 .UnexpectedSubprocessExitError` 1239 `mpr_result` attribute, which has the 1240 same structure as above 'Returns' section describes. 1241 1242 Examples: 1243 1244 ```python 1245 class SimpleMultiProcessTest(tf.test.TestCase): 1246 1247 def test_simple_printing_and_return(self): 1248 1249 def fn(): 1250 resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver() 1251 1252 # This will print "[chief-0]: Task type: chief , task id: 0" 1253 # for chief, for example. 1254 logging.info('Task type: %s, task id: %d', 1255 resolver.task_type, resolver.task_id) 1256 1257 return resolver.task_type 1258 1259 result = tf.__internal__.distribute.multi_process_runner.run( 1260 fn=fn, 1261 cluster_spec=( 1262 tf.__internal__ 1263 .distribute.multi_process_runner.create_cluster_spec( 1264 has_chief=True, num_workers=2))) 1265 assert sorted(result.return_value) == ['chief', 'worker', 'worker'] 1266 1267 def test_error_from_fn(self): 1268 1269 def fn(): 1270 resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver() 1271 raise ValueError('Task type {}, task id {} is errors out'.format( 1272 resolver.task_type, resolver.task_id)) 1273 1274 with self.assertRaisesRegexp(ValueError, 1275 'Task type worker, task id 0 is errors out'): 1276 cluster_spec = ( 1277 tf.__internal__.distribute.multi_process_runner.create_cluster_spec( 1278 num_workers=1)) 1279 tf.__internal__.distribute.multi_process_runner.run( 1280 fn=fn, cluster_spec=cluster_spec) 1281 1282 1283 if __name__ == '__main__': 1284 tf.__internal__.distribute.multi_process_runner.test_main() 1285 ``` 1286 """ 1287 runner = MultiProcessRunner( 1288 fn, 1289 cluster_spec, 1290 rpc_layer, 1291 max_run_time=max_run_time, 1292 return_output=return_output, 1293 args=args, 1294 kwargs=kwargs) 1295 runner.start() 1296 return runner.join(timeout) 1297 1298 1299# This is set by MultiProcessRunner in worker processes. 1300_barrier = None 1301 1302 1303@tf_export('__internal__.distribute.multi_process_runner.get_barrier', v1=[]) 1304def get_barrier(): 1305 """Returns a `multiprocessing.Barrier` for `multi_process_runner.run`. 1306 1307 `tf.__internal__.distribute.multi_process_runner.get_barrier()` returns 1308 a `multiprocessing.Barrier` object which can be used within `fn` of 1309 `tf.__internal__.distribute.multi_process_runner` to wait with 1310 `barrier.wait()` call until all other tasks have also reached the 1311 `barrier.wait()` call, before they can proceed individually. 1312 1313 Note that all tasks (subprocesses) have to reach `barrier.wait()` call to 1314 proceed. Currently it is not supported to block on only a subset of tasks 1315 in the cluster. 1316 1317 Example: 1318 ```python 1319 1320 def fn(): 1321 some_work_to_be_done_by_all_tasks() 1322 1323 tf.__internal__.distribute.multi_process_runner.get_barrier().wait() 1324 1325 # The barrier guarantees that at this point, all tasks have finished 1326 # `some_work_to_be_done_by_all_tasks()` 1327 some_other_work_to_be_done_by_all_tasks() 1328 1329 result = tf.__internal__.distribute.multi_process_runner.run( 1330 fn=fn, 1331 cluster_spec=( 1332 tf.__internal__ 1333 .distribute.multi_process_runner.create_cluster_spec( 1334 num_workers=2))) 1335 ``` 1336 1337 1338 Returns: 1339 A `multiprocessing.Barrier` for `multi_process_runner.run`. 1340 """ 1341 if _barrier is None: 1342 raise ValueError( 1343 'barrier is not defined. It is likely because you are calling ' 1344 'get_barrier() in the main process. get_barrier() can only be called ' 1345 'in the subprocesses.' 1346 ) 1347 return _barrier 1348 1349 1350_manager = None 1351_manager_lock = threading.Lock() 1352 1353 1354def manager(): 1355 """Returns the multiprocessing manager object for concurrency tools. 1356 1357 The manager object is useful as it controls a server process that holds 1358 the python objects that can be shared across processes. This can be used 1359 for parent-subprocess communication: 1360 1361 ```python 1362 manager = multi_process_runner.manager() 1363 some_event_happening_in_subprocess = manager.Event() 1364 mpr = multi_process_runner.MultiProcessRunner(fn, cluster_spec, 1365 args=(some_event_happening_in_subprocess,)) 1366 mpr.start() 1367 some_event_happening_in_subprocess.wait() 1368 # Do something that only should after some event happens in subprocess. 1369 ``` 1370 1371 Note that the user of multi_process_runner should not create additional 1372 `multiprocessing.Manager()` objects; doing so can result in segfault in 1373 some cases. 1374 1375 This method should only be called after multi_process_runner.test_main() is 1376 called. 1377 """ 1378 global _manager 1379 with _manager_lock: 1380 if _manager is None: 1381 _manager = multiprocessing.Manager() 1382 return _manager 1383 1384 1385@tf_export('__internal__.distribute.multi_process_runner.test_main', v1=[]) 1386def test_main(): 1387 """Main function to be called within `__main__` of a test file. 1388 1389 Any test module that uses 1390 `tf.__internal__.distribute.multi_process_runner.run()` 1391 must call this instead of regular `test.main()` inside 1392 `if __name__ == '__main__':` block, or an error will be raised when 1393 `tf.__internal__.distribute.multi_process_runner.run()` is used. This method 1394 takes 1395 care of needed initialization for launching multiple subprocesses. 1396 1397 Example: 1398 ```python 1399 class MyTestClass(tf.test.TestCase): 1400 def testSomething(self): 1401 # Testing code making use of 1402 # `tf.__internal__.distribute.multi_process_runner.run()`. 1403 1404 if __name__ == '__main__': 1405 tf.__internal__.distribute.multi_process_runner.test_main() 1406 ``` 1407 """ 1408 # Inject tearDownModule() to shut down all pool runners. Active pool runners 1409 # will block the program from exiting. This is necessary for global pool 1410 # runners. We tried atexit in the past, and it doesn't work in some 1411 # deployment. 1412 old_tear_down_module = getattr(sys.modules['__main__'], 'tearDownModule', 1413 None) 1414 1415 def tear_down_module(): 1416 _shutdown_all_pool_runners() 1417 if old_tear_down_module is not None: 1418 old_tear_down_module() 1419 1420 setattr(sys.modules['__main__'], 'tearDownModule', tear_down_module) 1421 multi_process_lib.test_main() 1422