1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Base testing class for strategies that require multiple nodes.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import contextlib 22import copy 23import json 24import os 25import subprocess 26import sys 27import threading 28import unittest 29 30import six 31 32_portpicker_import_error = None 33try: 34 import portpicker # pylint: disable=g-import-not-at-top 35except (ImportError, ModuleNotFoundError) as _error: # pylint: disable=invalid-name 36 _portpicker_import_error = _error 37 portpicker = None 38 39# pylint: disable=g-import-not-at-top 40from tensorflow.core.protobuf import config_pb2 41from tensorflow.core.protobuf import rewriter_config_pb2 42from tensorflow.python.client import session 43from tensorflow.python.distribute import distribute_coordinator as dc 44from tensorflow.python.distribute import multi_process_runner 45from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver 46from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver 47from tensorflow.python.eager import context 48from tensorflow.python.eager import remote 49from tensorflow.python.framework import errors 50from tensorflow.python.framework import ops 51from tensorflow.python.framework import test_util 52from tensorflow.python.platform import test 53from tensorflow.python.platform import tf_logging as logging 54from tensorflow.python.training import coordinator 55from tensorflow.python.training import server_lib 56from tensorflow.python.util import deprecation 57from tensorflow.python.util import nest 58from tensorflow.python.util.compat import collections_abc 59from tensorflow.python.util.tf_export import tf_export 60 61 62original_run_std_server = dc._run_std_server # pylint: disable=protected-access 63 64ASSIGNED_PORTS = set() 65lock = threading.Lock() 66 67 68def pick_unused_port(): 69 """Returns an unused and unassigned local port.""" 70 if _portpicker_import_error: 71 raise _portpicker_import_error # pylint: disable=raising-bad-type 72 73 global ASSIGNED_PORTS 74 with lock: 75 while True: 76 try: 77 port = portpicker.pick_unused_port() 78 except portpicker.NoFreePortFoundError: 79 raise unittest.SkipTest('Flakes in portpicker library do not represent ' 80 'TensorFlow errors.') 81 if port > 10000 and port not in ASSIGNED_PORTS: 82 ASSIGNED_PORTS.add(port) 83 logging.info('Using local port %r', port) 84 return port 85 86 87def _create_cluster(num_workers, 88 num_ps, 89 has_chief=False, 90 has_eval=False, 91 protocol='grpc', 92 worker_config=None, 93 ps_config=None, 94 eval_config=None, 95 worker_name='worker', 96 ps_name='ps', 97 chief_name='chief'): 98 """Creates and starts local servers and returns the cluster_spec dict.""" 99 if _portpicker_import_error: 100 raise _portpicker_import_error # pylint: disable=raising-bad-type 101 worker_ports = [pick_unused_port() for _ in range(num_workers)] 102 ps_ports = [pick_unused_port() for _ in range(num_ps)] 103 104 cluster_dict = {} 105 if num_workers > 0: 106 cluster_dict[worker_name] = ['localhost:%s' % port for port in worker_ports] 107 if num_ps > 0: 108 cluster_dict[ps_name] = ['localhost:%s' % port for port in ps_ports] 109 if has_eval: 110 cluster_dict['evaluator'] = ['localhost:%s' % pick_unused_port()] 111 if has_chief: 112 cluster_dict[chief_name] = ['localhost:%s' % pick_unused_port()] 113 114 cs = server_lib.ClusterSpec(cluster_dict) 115 116 for i in range(num_workers): 117 server_lib.Server( 118 cs, 119 job_name=worker_name, 120 protocol=protocol, 121 task_index=i, 122 config=worker_config, 123 start=True) 124 125 for i in range(num_ps): 126 server_lib.Server( 127 cs, 128 job_name=ps_name, 129 protocol=protocol, 130 task_index=i, 131 config=ps_config, 132 start=True) 133 134 if has_chief: 135 server_lib.Server( 136 cs, 137 job_name=chief_name, 138 protocol=protocol, 139 task_index=0, 140 config=worker_config, 141 start=True) 142 143 if has_eval: 144 server_lib.Server( 145 cs, 146 job_name='evaluator', 147 protocol=protocol, 148 task_index=0, 149 config=eval_config, 150 start=True) 151 152 return cluster_dict 153 154 155def create_in_process_cluster(num_workers, 156 num_ps, 157 has_chief=False, 158 has_eval=False, 159 rpc_layer='grpc'): 160 """Create an in-process cluster that consists of only standard server.""" 161 # Leave some memory for cuda runtime. 162 gpu_mem_frac = 0.7 / (num_workers + int(has_chief) + int(has_eval)) 163 worker_config = config_pb2.ConfigProto() 164 worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac 165 166 # The cluster may hang if workers don't have enough inter_op threads. See 167 # b/172296720 for more details. 168 if worker_config.inter_op_parallelism_threads < num_workers + 1: 169 worker_config.inter_op_parallelism_threads = num_workers + 1 170 171 # Enable collective ops which has no impact on non-collective ops. 172 # TODO(yuefengz, tucker): removing this after we move the initialization of 173 # collective mgr to the session level. 174 if has_chief: 175 worker_config.experimental.collective_group_leader = ( 176 '/job:chief/replica:0/task:0') 177 else: 178 worker_config.experimental.collective_group_leader = ( 179 '/job:worker/replica:0/task:0') 180 181 ps_config = config_pb2.ConfigProto() 182 ps_config.device_count['GPU'] = 0 183 184 eval_config = config_pb2.ConfigProto() 185 eval_config.experimental.collective_group_leader = '' 186 187 # Create in-process servers. Once an in-process tensorflow server is created, 188 # there is no way to terminate it. So we create one cluster per test process. 189 # We could've started the server in another process, we could then kill that 190 # process to terminate the server. The reasons why we don't want multiple 191 # processes are 192 # 1) it is more difficult to manage these processes; 193 # 2) there is something global in CUDA such that if we initialize CUDA in the 194 # parent process, the child process cannot initialize it again and thus cannot 195 # use GPUs (https://stackoverflow.com/questions/22950047). 196 cluster = None 197 try: 198 cluster = _create_cluster( 199 num_workers, 200 num_ps=num_ps, 201 has_chief=has_chief, 202 has_eval=has_eval, 203 worker_config=worker_config, 204 ps_config=ps_config, 205 eval_config=eval_config, 206 protocol=rpc_layer) 207 except errors.UnknownError as e: 208 if 'Could not start gRPC server' in e.message: 209 raise unittest.SkipTest('Cannot start std servers.') 210 else: 211 raise 212 return cluster 213 214 215class MultiProcessCluster(object): 216 """A cluster of TensorFlow servers in separate processes. 217 218 This class is not thread-safe. 219 """ 220 221 def __init__(self, 222 cluster_resolver, 223 stream_output=False, 224 collective_leader=None): 225 self._cluster_resolver = cluster_resolver 226 self._cluster_spec = cluster_resolver.cluster_spec().as_dict() 227 self._rpc_layer = cluster_resolver.rpc_layer 228 self._stream_output = stream_output 229 self._start_events = {} 230 self._finish_events = {} 231 self._mpr_manager = multi_process_runner.manager() 232 233 def task_function(start_events, finish_events): 234 cluster_resolver = TFConfigClusterResolver() 235 cluster_spec = cluster_resolver.cluster_spec() 236 task_type = cluster_resolver.task_type 237 task_id = cluster_resolver.task_id 238 rpc_layer = cluster_resolver.rpc_layer 239 240 # TODO(yuefengz): support GPU clusters. 241 server_config = config_pb2.ConfigProto() 242 server_config.device_count['GPU'] = 0 243 244 if collective_leader: 245 server_config.experimental.collective_group_leader = collective_leader 246 server_config.experimental.collective_nccl = False 247 248 logging.info( 249 'Enabling collective ops with cluster_spec = %r, task_type = %r, ' 250 'task_id = %r, rpc_layer = %r, collective_leader = %s', 251 cluster_spec, task_type, task_id, rpc_layer, collective_leader) 252 else: 253 logging.info( 254 'Starting server with cluster_spec = %r, task_type = %r, ' 255 'task_id = %r, rpc_layer = %r', cluster_spec, task_type, task_id, 256 rpc_layer) 257 258 server_lib.Server( 259 cluster_spec, 260 job_name=task_type, 261 protocol=rpc_layer, 262 task_index=task_id, 263 config=server_config, 264 start=True) 265 266 start_event = start_events[task_type][task_id] 267 start_event.set() 268 269 finish_event = finish_events[task_type][task_id] 270 finish_event.wait() 271 272 os._exit(0) # pylint: disable=protected-access 273 274 self._task_function = task_function 275 self._mpr = None 276 277 def start(self): 278 """Starts one TensorFlow server for each task in the cluster_resolver. 279 280 It will wait until all the servers are up before returns. 281 """ 282 if self._mpr: 283 raise ValueError('The cluster has already been started.') 284 for task_type, task_addresses in self._cluster_spec.items(): 285 self._start_events[task_type] = [] 286 self._finish_events[task_type] = [] 287 for _ in task_addresses: 288 self._start_events[task_type].append(self._mpr_manager.Event()) 289 self._finish_events[task_type].append(self._mpr_manager.Event()) 290 291 self._mpr = multi_process_runner.MultiProcessRunner( 292 self._task_function, 293 self._cluster_spec, 294 args=(self._start_events, self._finish_events), 295 rpc_layer=self._rpc_layer, 296 stream_output=self._stream_output, 297 return_output=False, 298 use_dill_for_args=False) 299 self._mpr.start() 300 for task_type, task_addresses in self._cluster_spec.items(): 301 for i in range(len(task_addresses)): 302 self._start_events[task_type][i].wait() 303 304 def stop(self): 305 """Stops all the servers.""" 306 for task_type, task_addresses in self._cluster_spec.items(): 307 for i in range(len(task_addresses)): 308 self._finish_events[task_type][i].set() 309 try: 310 self._mpr.join() 311 except multi_process_runner.UnexpectedSubprocessExitError: 312 # TODO(yuefengz): investigate why processes exit with 255. 313 pass 314 self._mpr = None 315 self._start_events = {} 316 self._finish_events = {} 317 318 def kill_task(self, task_type, task_id): 319 """Kill a server given task_type and task_id. 320 321 Args: 322 task_type: the type of the task such as "worker". 323 task_id: the id the task such as 1. 324 """ 325 assert self._mpr 326 if (not self._start_events[task_type][task_id].is_set() or 327 self._finish_events[task_type][task_id].is_set()): 328 raise ValueError("The task %s:%d doesn't exist." % (task_type, task_id)) 329 330 self._finish_events[task_type][task_id].set() 331 self._mpr._processes[(task_type, task_id)].join() 332 333 def start_task(self, task_type, task_id): 334 """Starts a server given task_type and task_id. 335 336 Args: 337 task_type: the type of the task such as "worker". 338 task_id: the id the task such as 1. 339 340 Raises: 341 ValueError: if the server alreay exists. 342 """ 343 assert self._mpr 344 345 if (not self._start_events[task_type][task_id].is_set() or 346 not self._finish_events[task_type][task_id].is_set()): 347 raise ValueError( 348 'The task %s:%d is still alive. You cannot start another one.' % 349 (task_type, task_id)) 350 self._start_events[task_type][task_id] = self._mpr_manager.Event() 351 self._finish_events[task_type][task_id] = self._mpr_manager.Event() 352 self._mpr.start_single_process(task_type=task_type, task_id=task_id) 353 self._start_events[task_type][task_id].wait() 354 355 @property 356 def cluster_resolver(self): 357 return copy.deepcopy(self._cluster_resolver) 358 359 360def create_multi_process_cluster(num_workers, 361 num_ps, 362 has_chief=False, 363 has_eval=False, 364 rpc_layer='grpc', 365 stream_output=False, 366 collective_leader=None): 367 cluster_spec = create_cluster_spec( 368 has_chief=has_chief, 369 num_workers=num_workers, 370 num_ps=num_ps, 371 has_eval=has_eval) 372 373 cluster = MultiProcessCluster( 374 SimpleClusterResolver( 375 server_lib.ClusterSpec(cluster_spec), rpc_layer=rpc_layer), 376 stream_output=stream_output, 377 collective_leader=collective_leader) 378 cluster.start() 379 return cluster 380 381 382@tf_export( 383 '__internal__.distribute.multi_process_runner.create_cluster_spec', v1=[]) 384def create_cluster_spec(has_chief=False, 385 num_workers=1, 386 num_ps=0, 387 has_eval=False): 388 """Create a cluster spec with tasks with unused local ports. 389 390 This utility finds available ports at localhost, and returns a dict that 391 represents the cluster spec that utilizes those ports, according to the 392 arguments. The dict representing the cluster spec contains task types, and 393 their instances' addresses. Note that this is usually only for testing purpose 394 using multiple processes in the local machine, and should not be used for real 395 multi-worker TensorFlow programs, where the addresses need to point to the 396 processes at separate machines. 397 398 This util is useful when creating the `cluster_spec` arg for 399 `tf.__internal__.distribute.multi_process_runner.run`. 400 401 Args: 402 has_chief: Whether the generated cluster spec should contain "chief" task 403 type. 404 num_workers: Number of workers to use in the cluster spec. 405 num_ps: Number of parameter servers to use in the cluster spec. 406 has_eval: Whether this cluster spec has evaluator. 407 408 Returns: 409 A dict that represents the cluster spec using localhost ports for the tasks. 410 411 Example: 412 413 ```python 414 cluster_spec = 415 tf.__internal__.distribute.multi_process_runner.create_cluster_spec( 416 has_chief=True, num_workers=2, num_ps=2) 417 # An example of cluster_spec is 418 # {'chief': ['localhost:23381'], 419 # 'worker': ['localhost:19197', 'localhost:22903'], 420 # 'ps': ['localhost:16912', 'localhost:21535']} 421 422 cluster_spec = 423 tf.__internal__.distribute.multi_process_runner.create_cluster_spec( 424 has_chief=False, num_workers=0, num_ps=0, has_eval=True) 425 # An example of cluster_spec is 426 # {'evaluator': ['localhost:23381']} 427 ``` 428 """ 429 if _portpicker_import_error: 430 raise _portpicker_import_error # pylint: disable=raising-bad-type 431 432 cluster_spec = {} 433 if has_chief: 434 cluster_spec['chief'] = ['localhost:%s' % pick_unused_port()] 435 if num_workers: 436 cluster_spec['worker'] = [ 437 'localhost:%s' % pick_unused_port() for _ in range(num_workers) 438 ] 439 if num_ps: 440 cluster_spec['ps'] = [ 441 'localhost:%s' % pick_unused_port() for _ in range(num_ps) 442 ] 443 if has_eval: 444 cluster_spec['evaluator'] = ['localhost:%s' % pick_unused_port()] 445 return cluster_spec 446 447 448@contextlib.contextmanager 449def skip_if_grpc_server_cant_be_started(test_obj): 450 try: 451 yield 452 except errors.UnknownError as e: 453 if 'Could not start gRPC server' in e.message: 454 reason = 'Cannot start std servers.' 455 test_obj.test_skipped_reason = reason 456 test_obj.skipTest(reason) 457 else: 458 raise 459 460 461class MultiWorkerTestBase(test.TestCase): 462 """Base class for testing multi node strategy and dataset.""" 463 464 @classmethod 465 def setUpClass(cls, num_workers=2, num_ps=1): # pylint: disable=g-missing-super-call 466 """Create a local cluster with 2 workers.""" 467 cls._cluster_spec = create_in_process_cluster(num_workers=num_workers, 468 num_ps=num_ps) 469 cls._default_target = 'grpc://' + cls._cluster_spec['worker'][0] 470 471 def setUp(self): 472 # We only cache the session in one test because another test may have a 473 # different session config or master target. 474 self._thread_local = threading.local() 475 self._thread_local.cached_session = None 476 self._coord = coordinator.Coordinator() 477 478 @contextlib.contextmanager 479 def session(self, graph=None, config=None, target=None): 480 """Create a test session with master target set to the testing cluster. 481 482 Creates a test session that connects to the local testing cluster. 483 484 Args: 485 graph: Optional graph to use during the returned session. 486 config: An optional config_pb2.ConfigProto to use to configure the 487 session. 488 target: the target of session to connect to. 489 490 Yields: 491 A Session object that should be used as a context manager to surround 492 the graph building and execution code in a test case. 493 """ 494 config = self._create_config(config) 495 496 if target is None: 497 target = self._default_target 498 with session.Session(graph=graph, config=config, target=target) as sess: 499 yield sess 500 501 @contextlib.contextmanager 502 # TODO(b/117573461): Overwrite self.evaluate() to use this function. 503 def cached_session(self, graph=None, config=None, target=None): 504 """Create a test session with master target set to the testing cluster. 505 506 Creates a test session that connects to the local testing cluster. 507 The session is only created once per test and then reused. 508 509 Args: 510 graph: Optional graph to use during the returned session. 511 config: An optional config_pb2.ConfigProto to use to configure the 512 session. 513 target: the target of session to connect to. 514 515 Yields: 516 A Session object that should be used as a context manager to surround 517 the graph building and execution code in a test case. Note that the 518 session will live until the end of the test. 519 """ 520 config = self._create_config(config) 521 522 if target is None: 523 target = self._default_target 524 if getattr(self._thread_local, 'cached_session', None) is None: 525 self._thread_local.cached_session = session.Session( 526 graph=None, config=config, target=target) 527 sess = self._thread_local.cached_session 528 with sess.graph.as_default(), sess.as_default(): 529 yield sess 530 531 def _create_config(self, config): 532 if config is None: 533 config = config_pb2.ConfigProto(allow_soft_placement=True) 534 else: 535 config = copy.deepcopy(config) 536 # Don't perform optimizations for tests so we don't inadvertently run 537 # gpu ops on cpu 538 config.graph_options.optimizer_options.opt_level = -1 539 config.graph_options.rewrite_options.constant_folding = ( 540 rewriter_config_pb2.RewriterConfig.OFF) 541 542 return config 543 544 def _run_client(self, client_fn, task_type, task_id, num_gpus, eager_mode, 545 *args, **kwargs): 546 547 def wrapped_client_fn(): 548 with self._coord.stop_on_exception(): 549 client_fn(task_type, task_id, num_gpus, *args, **kwargs) 550 551 if eager_mode: 552 with context.eager_mode(): 553 wrapped_client_fn() 554 else: 555 with context.graph_mode(): 556 wrapped_client_fn() 557 558 def _run_between_graph_clients(self, client_fn, cluster_spec, num_gpus, *args, 559 **kwargs): 560 """Runs several clients for between-graph replication. 561 562 Args: 563 client_fn: a function that needs to accept `task_type`, `task_id`, 564 `num_gpus`. 565 cluster_spec: a dict specifying jobs in a cluster. 566 num_gpus: number of GPUs per worker. 567 *args: will be passed to `client_fn`. 568 **kwargs: will be passed to `client_fn`. 569 """ 570 threads = [] 571 for task_type in ['chief', 'worker']: 572 for task_id in range(len(cluster_spec.get(task_type, []))): 573 t = threading.Thread( 574 target=self._run_client, 575 args=(client_fn, task_type, task_id, num_gpus, 576 context.executing_eagerly()) + args, 577 kwargs=kwargs) 578 t.start() 579 threads.append(t) 580 self._coord.join(threads) 581 582 583class SingleWorkerTestBaseGraph(MultiWorkerTestBase): 584 """Base class for testing remote single worker strategy graph and dataset.""" 585 586 @classmethod 587 def setUpClass(cls): 588 super(SingleWorkerTestBaseGraph, cls).setUpClass(num_workers=1) 589 590 591class SingleWorkerTestBaseEager(test.TestCase): 592 """Base class for testing remote single worker strategy eager and dataset.""" 593 594 def setUp(self): 595 super(SingleWorkerTestBaseEager, self).setUp() 596 workers, _ = test_util.create_local_cluster(num_workers=1, num_ps=0) 597 remote.connect_to_remote_host(workers[0].target) 598 599 def cached_session(self): 600 return DummySession() 601 602 603class DummySession(object): 604 605 def __enter__(self): 606 return 607 608 def __exit__(self, exception_type, exception_value, traceback): 609 pass 610 611 612class MockOsEnv(collections_abc.Mapping): 613 """A class that allows per-thread TF_CONFIG.""" 614 615 def __init__(self, *args): 616 self._dict = dict() 617 self._thread_local = threading.local() 618 super(MockOsEnv, self).__init__(*args) 619 620 def get(self, key, default=None): 621 if not hasattr(self._thread_local, 'dict'): 622 self._thread_local.dict = dict() 623 if key == 'TF_CONFIG': 624 return dict.get(self._thread_local.dict, key, default) 625 else: 626 return dict.get(self._dict, key, default) 627 628 def __getitem__(self, key): 629 if not hasattr(self._thread_local, 'dict'): 630 self._thread_local.dict = dict() 631 if key == 'TF_CONFIG': 632 return dict.__getitem__(self._thread_local.dict, key) 633 else: 634 return dict.__getitem__(self._dict, key) 635 636 def __setitem__(self, key, val): 637 if not hasattr(self._thread_local, 'dict'): 638 self._thread_local.dict = dict() 639 if key == 'TF_CONFIG': 640 return dict.__setitem__(self._thread_local.dict, key, val) 641 else: 642 return dict.__setitem__(self._dict, key, val) 643 644 def __iter__(self): 645 if not hasattr(self._thread_local, 'dict'): 646 self._thread_local.dict = dict() 647 for x in self._thread_local.dict: 648 yield x 649 for x in self._dict: 650 yield x 651 652 def __len__(self): 653 if not hasattr(self._thread_local, 'dict'): 654 self._thread_local.dict = dict() 655 return self._thread_local.dict.__len__() + self._dict.__len__() 656 657 658class IndependentWorkerTestBase(test.TestCase): 659 """Testing infra for independent workers.""" 660 661 def _make_mock_run_std_server(self): 662 663 def _mock_run_std_server(*args, **kwargs): 664 """Returns the std server once all threads have started it.""" 665 with skip_if_grpc_server_cant_be_started(self): 666 ret = original_run_std_server(*args, **kwargs) 667 # Wait for all std servers to be brought up in order to reduce the chance 668 # of remote sessions taking local ports that have been assigned to std 669 # servers. Only call this barrier the first time this function is run for 670 # each thread. 671 if not getattr(self._thread_local, 'server_started', False): 672 self._barrier.wait() 673 self._thread_local.server_started = True 674 return ret 675 676 return _mock_run_std_server 677 678 def setUp(self): 679 self._mock_os_env = MockOsEnv() 680 self._mock_context = test.mock.patch.object(os, 'environ', 681 self._mock_os_env) 682 self._coord = coordinator.Coordinator() 683 super(IndependentWorkerTestBase, self).setUp() 684 self._mock_context.__enter__() 685 # threading local object to be shared by all threads 686 self._thread_local = threading.local() 687 688 def tearDown(self): 689 self._mock_context.__exit__(None, None, None) 690 super(IndependentWorkerTestBase, self).tearDown() 691 692 def _task_thread(self, task_fn, tf_config, executing_eagerly, *args, 693 **kwargs): 694 with self._coord.stop_on_exception(): 695 os.environ['TF_CONFIG'] = json.dumps(tf_config) 696 # Force the new thread simulating a worker to run in the same context 697 # mode as the parent thread does. 698 if executing_eagerly: 699 with context.eager_mode(): 700 task_fn(*args, **kwargs) 701 else: 702 with ops.Graph().as_default(), context.graph_mode(): 703 task_fn(*args, **kwargs) 704 705 def _run_task_in_thread(self, task_fn, cluster_spec, task_type, task_id, 706 *args, **kwargs): 707 """Run tasks in a thread. 708 709 If `tf_config` is provided, use it for the new thread; if not, construct one 710 from `cluster_spec`, `task_type`, and `task_id`, and provide it to the new 711 thread to be set as `TF_CONFIG` environment. 712 713 Args: 714 task_fn: The function to run in the new thread. 715 cluster_spec: The cluster spec. 716 task_type: The task type. 717 task_id: The task id. 718 *args: Additional positional arguments to provide to the thread's task_fn. 719 **kwargs: Additional keyword arguments to provide to the thread's task_fn. 720 If `tf_config` is provided, that dict will be used for the TF_CONFIG for 721 the new thread. 722 723 Returns: 724 The thread that has started. 725 """ 726 tf_config = kwargs.pop('tf_config', None) 727 if tf_config is None: 728 if task_type: 729 tf_config = { 730 'cluster': cluster_spec, 731 'task': { 732 'type': task_type, 733 'index': task_id 734 } 735 } 736 else: 737 tf_config = { 738 'cluster': cluster_spec, 739 } 740 t = threading.Thread( 741 target=self._task_thread, 742 args=(task_fn, tf_config, context.executing_eagerly()) + args, 743 kwargs=kwargs) 744 t.start() 745 return t 746 747 def run_multiple_tasks_in_threads(self, task_fn, cluster_spec, *args, 748 **kwargs): 749 # The task_fn should create std_server by itself. 750 threads = {} 751 for task_type in cluster_spec.keys(): 752 threads[task_type] = [] 753 for task_id in range(len(cluster_spec[task_type])): 754 t = self._run_task_in_thread(task_fn, cluster_spec, task_type, task_id, 755 *args, **kwargs) 756 threads[task_type].append(t) 757 return threads 758 759 def join_independent_workers(self, worker_threads): 760 with skip_if_grpc_server_cant_be_started(self): 761 self._coord.join(worker_threads) 762 763 764class MultiWorkerMultiProcessTest(test.TestCase): 765 """Testing infra for independent workers using multiple processes.""" 766 767 def _run_task_in_process(self, cmd_args, cluster_spec, task_type, task_id): 768 env = os.environ.copy() 769 env['TF_CONFIG'] = json.dumps({ 770 'cluster': cluster_spec, 771 'task': { 772 'type': task_type, 773 'index': task_id 774 } 775 }) 776 return subprocess.Popen( 777 cmd_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env) 778 779 @deprecation.deprecated( 780 None, '`run_multiple_tasks_in_processes` is deprecated; any new test ' 781 'requiring multiple processes should use `multi_process_runner` for ' 782 'better support of log printing, streaming, and more functionality.') 783 def run_multiple_tasks_in_processes(self, cmd_args, cluster_spec): 784 """Run `cmd_args` in a process for each task in `cluster_spec`.""" 785 processes = {} 786 for task_type in cluster_spec.keys(): 787 processes[task_type] = [] 788 for task_id in range(len(cluster_spec[task_type])): 789 p = self._run_task_in_process(cmd_args, cluster_spec, task_type, 790 task_id) 791 processes[task_type].append(p) 792 return processes 793 794 @deprecation.deprecated( 795 None, '`join_independent_workers` is deprecated; any new test ' 796 'requiring multiple processes should use `multi_process_runner` for ' 797 'better support of log printing, streaming, and more functionality.') 798 def join_independent_workers(self, worker_processes): 799 return_codes = [] 800 for p in nest.flatten(worker_processes): 801 try: 802 # Calling p.wait() will hang if we don't consume its output. 803 p.communicate() 804 except ValueError: 805 # The output of the process may have been consumed, in which case 806 # calling `p.communicate()` will raise a ValueError. 807 pass 808 finally: 809 return_codes.append(p.returncode) 810 for return_code in return_codes: 811 self.assertEqual(return_code, 0) 812 813 @deprecation.deprecated( 814 None, '`stream_stderr` is deprecated; any new test ' 815 'requiring multiple processes should use `multi_process_runner` for ' 816 'better support of log printing, streaming, and more functionality.') 817 def stream_stderr(self, processes, print_only_first=False): 818 """Consume stderr of all processes and print to stdout. 819 820 To reduce the amount of logging, caller can set print_only_first to True. 821 In that case, this function only prints stderr from the first process of 822 each type. 823 824 Args: 825 processes: A dictionary from process type string -> list of processes. 826 print_only_first: If true, only print output from first process of each 827 type. 828 """ 829 830 def _stream_stderr_single_process(process, type_string, index, 831 print_to_stdout): 832 """Consume a single process's stderr and optionally print to stdout.""" 833 while True: 834 output = process.stderr.readline() 835 if not output and process.poll() is not None: 836 break 837 if output and print_to_stdout: 838 print('{}{} {}'.format(type_string, index, output.strip())) 839 sys.stdout.flush() 840 841 stream_threads = [] 842 for process_type, process_list in six.iteritems(processes): 843 for i in range(len(process_list)): 844 print_to_stdout = (not print_only_first) or (i == 0) 845 thread = threading.Thread( 846 target=_stream_stderr_single_process, 847 args=(process_list[i], process_type, i, print_to_stdout)) 848 thread.start() 849 stream_threads.append(thread) 850 for thread in stream_threads: 851 thread.join() 852 853 854def get_tf_config_task(): 855 return json.loads(os.environ['TF_CONFIG'])['task'] 856 857 858def get_tf_config_cluster_spec(): 859 return json.loads(os.environ['TF_CONFIG'])['cluster'] 860 861 862def get_task_type(): 863 return get_tf_config_task()['type'] 864 865 866def get_task_index(): 867 return get_tf_config_task()['index'] 868 869 870def is_chief(): 871 return ('chief' not in get_tf_config_cluster_spec() 872 and get_task_type() == 'worker' 873 and get_task_index() == 0) 874