# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base testing class for strategies that require multiple nodes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import copy
import json
import multiprocessing
import os
import subprocess
import sys
import threading
import unittest

import six

_portpicker_import_error = None
try:
  import portpicker  # pylint: disable=g-import-not-at-top
except (ImportError, ModuleNotFoundError) as _error:  # pylint: disable=invalid-name
  _portpicker_import_error = _error
  portpicker = None

# pylint: disable=g-import-not-at-top
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import remote
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator
from tensorflow.python.training import server_lib
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export


original_run_std_server = dc._run_std_server  # pylint: disable=protected-access

ASSIGNED_PORTS = set()
lock = threading.Lock()


def pick_unused_port():
  """Returns an unused and unassigned local port."""
  if _portpicker_import_error:
    raise _portpicker_import_error  # pylint: disable=raising-bad-type

  global ASSIGNED_PORTS
  with lock:
    while True:
      try:
        port = portpicker.pick_unused_port()
      except portpicker.NoFreePortFoundError:
        raise unittest.SkipTest('Flakes in portpicker library do not represent '
                                'TensorFlow errors.')
      if port > 10000 and port not in ASSIGNED_PORTS:
        ASSIGNED_PORTS.add(port)
        logging.info('Using local port %r', port)
        return port


def _create_cluster(num_workers,
                    num_ps,
                    has_chief=False,
                    has_eval=False,
                    protocol='grpc',
                    worker_config=None,
                    ps_config=None,
                    eval_config=None,
                    worker_name='worker',
                    ps_name='ps',
                    chief_name='chief'):
  """Creates and starts local servers and returns the cluster_spec dict."""
  if _portpicker_import_error:
    raise _portpicker_import_error  # pylint: disable=raising-bad-type
  worker_ports = [pick_unused_port() for _ in range(num_workers)]
  ps_ports = [pick_unused_port() for _ in range(num_ps)]

  cluster_dict = {}
  if num_workers > 0:
    cluster_dict[worker_name] = [
        'localhost:%s' % port for port in worker_ports
    ]
  if num_ps > 0:
    cluster_dict[ps_name] = ['localhost:%s' % port for port in ps_ports]
  if has_eval:
    cluster_dict['evaluator'] = ['localhost:%s' % pick_unused_port()]
  if has_chief:
    cluster_dict[chief_name] = ['localhost:%s' % pick_unused_port()]

  cs = server_lib.ClusterSpec(cluster_dict)

  for i in range(num_workers):
    server_lib.Server(
        cs,
        job_name=worker_name,
        protocol=protocol,
        task_index=i,
        config=worker_config,
        start=True)

  for i in range(num_ps):
    server_lib.Server(
        cs,
        job_name=ps_name,
        protocol=protocol,
        task_index=i,
        config=ps_config,
        start=True)

  if has_chief:
    server_lib.Server(
        cs,
        job_name=chief_name,
        protocol=protocol,
        task_index=0,
        config=worker_config,
        start=True)

  if has_eval:
    server_lib.Server(
        cs,
        job_name='evaluator',
        protocol=protocol,
        task_index=0,
        config=eval_config,
        start=True)

  return cluster_dict


def create_in_process_cluster(num_workers,
                              num_ps,
                              has_chief=False,
                              has_eval=False,
                              rpc_layer='grpc'):
  """Creates an in-process cluster that consists only of standard servers."""
  # Leave some memory for the CUDA runtime.
  gpu_mem_frac = 0.7 / (num_workers + int(has_chief) + int(has_eval))
  worker_config = config_pb2.ConfigProto()
  worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac

  # The cluster may hang if workers don't have enough inter_op threads. See
  # b/172296720 for more details.
  if multiprocessing.cpu_count() < 4:
    worker_config.inter_op_parallelism_threads = 4

  # Enable collective ops, which has no impact on non-collective ops.
  # TODO(yuefengz, tucker): remove this after we move the initialization of
  # collective mgr to the session level.
  if has_chief:
    worker_config.experimental.collective_group_leader = (
        '/job:chief/replica:0/task:0')
  else:
    worker_config.experimental.collective_group_leader = (
        '/job:worker/replica:0/task:0')

  ps_config = config_pb2.ConfigProto()
  ps_config.device_count['GPU'] = 0

  eval_config = config_pb2.ConfigProto()
  eval_config.experimental.collective_group_leader = ''

  # Create in-process servers. Once an in-process TensorFlow server is created,
  # there is no way to terminate it, so we create one cluster per test process.
  # We could have started each server in a separate process and killed that
  # process to terminate the server, but we avoid multiple processes because
  # 1) it is more difficult to manage these processes;
  # 2) there is something global in CUDA such that if we initialize CUDA in the
  # parent process, the child process cannot initialize it again and thus
  # cannot use GPUs (https://stackoverflow.com/questions/22950047).
  cluster = None
  try:
    cluster = _create_cluster(
        num_workers,
        num_ps=num_ps,
        has_chief=has_chief,
        has_eval=has_eval,
        worker_config=worker_config,
        ps_config=ps_config,
        eval_config=eval_config,
        protocol=rpc_layer)
  except errors.UnknownError as e:
    if 'Could not start gRPC server' in e.message:
      raise unittest.SkipTest('Cannot start std servers.')
    else:
      raise
  return cluster


class MultiProcessCluster(object):
  """A cluster of TensorFlow servers in separate processes.

  This class is not thread-safe.
  """

  def __init__(self, cluster_resolver):
    self._cluster_resolver = cluster_resolver
    self._cluster_spec = cluster_resolver.cluster_spec().as_dict()
    self._rpc_layer = cluster_resolver.rpc_layer
    self._start_events = {}
    self._finish_events = {}
    self._mpr_manager = multi_process_runner.manager()

    def task_function(start_events, finish_events):
      cluster_resolver = TFConfigClusterResolver()
      cluster_spec = cluster_resolver.cluster_spec()
      task_type = cluster_resolver.task_type
      task_id = cluster_resolver.task_id
      rpc_layer = cluster_resolver.rpc_layer

      logging.info(
          'Starting server with cluster_spec = %r, task_type = %r, '
          'task_id = %r, rpc_layer = %r', cluster_spec, task_type, task_id,
          rpc_layer)

      # TODO(yuefengz): support GPU clusters.
      server_config = config_pb2.ConfigProto()
      server_config.device_count['GPU'] = 0

      # Set the environment variable to prevent hanging upon job failure and
      # restart. Note that it defaults to 'use_caller' at Google, but defaults
      # to False in OSS.
      os.environ['GRPC_FAIL_FAST'] = 'use_caller'

      server_lib.Server(
          cluster_spec,
          job_name=task_type,
          protocol=rpc_layer,
          task_index=task_id,
          config=server_config,
          start=True)

      start_event = start_events[task_type][task_id]
      start_event.set()

      finish_event = finish_events[task_type][task_id]
      finish_event.wait()

      os._exit(0)  # pylint: disable=protected-access

    self._task_function = task_function
    self._mpr = None

  def start(self):
    """Starts one TensorFlow server for each task in the cluster_resolver.

    It will wait until all the servers are up before returning.
    """
    if self._mpr:
      raise ValueError('The cluster has already been started.')
    for task_type, task_addresses in self._cluster_spec.items():
      self._start_events[task_type] = []
      self._finish_events[task_type] = []
      for _ in task_addresses:
        self._start_events[task_type].append(self._mpr_manager.Event())
        self._finish_events[task_type].append(self._mpr_manager.Event())

    self._mpr = multi_process_runner.MultiProcessRunner(
        self._task_function,
        self._cluster_spec,
        args=(self._start_events, self._finish_events),
        rpc_layer=self._rpc_layer,
        stream_output=False,
        return_output=False,
        use_dill_for_args=False)
    self._mpr.start()
    for task_type, task_addresses in self._cluster_spec.items():
      for i in range(len(task_addresses)):
        self._start_events[task_type][i].wait()

  def stop(self):
    """Stops all the servers."""
    for task_type, task_addresses in self._cluster_spec.items():
      for i in range(len(task_addresses)):
        self._finish_events[task_type][i].set()
    try:
      self._mpr.join()
    except multi_process_runner.UnexpectedSubprocessExitError:
      # TODO(yuefengz): investigate why processes exit with 255.
      pass
    self._mpr = None
    self._start_events = {}
    self._finish_events = {}

  def kill_task(self, task_type, task_id):
    """Kills the server given by task_type and task_id.

    Args:
      task_type: the type of the task, such as "worker".
      task_id: the id of the task, such as 1.
    """
    assert self._mpr
    if (not self._start_events[task_type][task_id].is_set() or
        self._finish_events[task_type][task_id].is_set()):
      raise ValueError("The task %s:%d doesn't exist." % (task_type, task_id))

    self._finish_events[task_type][task_id].set()
    self._mpr._processes[(task_type, task_id)].join()

  def start_task(self, task_type, task_id):
    """Starts a server given task_type and task_id.

    Args:
      task_type: the type of the task, such as "worker".
      task_id: the id of the task, such as 1.

    Raises:
      ValueError: if the server already exists.
    """
    assert self._mpr

    if (not self._start_events[task_type][task_id].is_set() or
        not self._finish_events[task_type][task_id].is_set()):
      raise ValueError(
          'The task %s:%d is still alive. You cannot start another one.' %
          (task_type, task_id))
    self._start_events[task_type][task_id] = self._mpr_manager.Event()
    self._finish_events[task_type][task_id] = self._mpr_manager.Event()
    self._mpr.start_single_process(task_type=task_type, task_id=task_id)
    self._start_events[task_type][task_id].wait()

  @property
  def cluster_resolver(self):
    return copy.deepcopy(self._cluster_resolver)


def create_multi_process_cluster(num_workers,
                                 num_ps,
                                 has_chief=False,
                                 has_eval=False,
                                 rpc_layer='grpc'):
  """Creates a multi-process cluster of TensorFlow servers and starts it."""
  cluster_spec = create_cluster_spec(
      has_chief=has_chief,
      num_workers=num_workers,
      num_ps=num_ps,
      has_eval=has_eval)

  cluster = MultiProcessCluster(
      SimpleClusterResolver(
          server_lib.ClusterSpec(cluster_spec), rpc_layer=rpc_layer))
  cluster.start()
  return cluster


@tf_export(
    '__internal__.distribute.multi_process_runner.create_cluster_spec', v1=[])
def create_cluster_spec(has_chief=False,
                        num_workers=1,
                        num_ps=0,
                        has_eval=False):
  """Create a cluster spec with tasks with unused local ports.

  This utility finds available ports at localhost, and returns a dict that
  represents the cluster spec that utilizes those ports, according to the
  arguments. The dict representing the cluster spec contains task types and
  their instances' addresses. Note that this is usually only for testing
  purposes using multiple processes on the local machine, and should not be
  used for real multi-worker TensorFlow programs, where the addresses need to
  point to processes on separate machines.

  This util is useful when creating the `cluster_spec` arg for
  `tf.__internal__.distribute.multi_process_runner.run`.

  Args:
    has_chief: Whether the generated cluster spec should contain the "chief"
      task type.
    num_workers: Number of workers to use in the cluster spec.
    num_ps: Number of parameter servers to use in the cluster spec.
    has_eval: Whether this cluster spec has an evaluator.

  Returns:
    A dict that represents the cluster spec using localhost ports for the
    tasks.

  Example:

  ```python
  cluster_spec = (
      tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
          has_chief=True, num_workers=2, num_ps=2))
  # An example of cluster_spec is
  # {'chief': ['localhost:23381'],
  #  'worker': ['localhost:19197', 'localhost:22903'],
  #  'ps': ['localhost:16912', 'localhost:21535']}

  cluster_spec = (
      tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
          has_chief=False, num_workers=0, num_ps=0, has_eval=True))
  # An example of cluster_spec is
  # {'evaluator': ['localhost:23381']}
  ```
  """
  if _portpicker_import_error:
    raise _portpicker_import_error  # pylint: disable=raising-bad-type

  cluster_spec = {}
  if has_chief:
    cluster_spec['chief'] = ['localhost:%s' % pick_unused_port()]
  if num_workers:
    cluster_spec['worker'] = [
        'localhost:%s' % pick_unused_port() for _ in range(num_workers)
    ]
  if num_ps:
    cluster_spec['ps'] = [
        'localhost:%s' % pick_unused_port() for _ in range(num_ps)
    ]
  if has_eval:
    cluster_spec['evaluator'] = ['localhost:%s' % pick_unused_port()]
  return cluster_spec


@contextlib.contextmanager
def skip_if_grpc_server_cant_be_started(test_obj):
  try:
    yield
  except errors.UnknownError as e:
    if 'Could not start gRPC server' in e.message:
      reason = 'Cannot start std servers.'
      test_obj.test_skipped_reason = reason
      test_obj.skipTest(reason)
    else:
      raise


class MultiWorkerTestBase(test.TestCase):
  """Base class for testing multi node strategy and dataset."""

  @classmethod
  def setUpClass(cls, num_workers=2, num_ps=1):  # pylint: disable=g-missing-super-call
    """Creates a local cluster with `num_workers` workers and `num_ps` PS."""
    cls._cluster_spec = create_in_process_cluster(num_workers=num_workers,
                                                  num_ps=num_ps)
    cls._default_target = 'grpc://' + cls._cluster_spec['worker'][0]

  def setUp(self):
    # We only cache the session in one test because another test may have a
    # different session config or master target.
    self._thread_local = threading.local()
    self._thread_local.cached_session = None
    self._coord = coordinator.Coordinator()

  @contextlib.contextmanager
  def session(self, graph=None, config=None, target=None):
    """Create a test session with master target set to the testing cluster.

    Creates a test session that connects to the local testing cluster.

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      target: the target of the session to connect to.

    Yields:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case.
    """
    config = self._create_config(config)

    if target is None:
      target = self._default_target
    with session.Session(graph=graph, config=config, target=target) as sess:
      yield sess

  @contextlib.contextmanager
  # TODO(b/117573461): Overwrite self.evaluate() to use this function.
  def cached_session(self, graph=None, config=None, target=None):
    """Create a test session with master target set to the testing cluster.

    Creates a test session that connects to the local testing cluster.
    The session is only created once per test and then reused.

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      target: the target of the session to connect to.

    Yields:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case. Note that the
      session will live until the end of the test.
    """
    config = self._create_config(config)

    if target is None:
      target = self._default_target
    if getattr(self._thread_local, 'cached_session', None) is None:
      self._thread_local.cached_session = session.Session(
          graph=None, config=config, target=target)
    sess = self._thread_local.cached_session
    with sess.graph.as_default(), sess.as_default():
      yield sess

  def _create_config(self, config):
    if config is None:
      config = config_pb2.ConfigProto(allow_soft_placement=True)
    else:
      config = copy.deepcopy(config)
    # Don't perform optimizations for tests so we don't inadvertently run
    # GPU ops on CPU.
    config.graph_options.optimizer_options.opt_level = -1
    config.graph_options.rewrite_options.constant_folding = (
        rewriter_config_pb2.RewriterConfig.OFF)

    return config

  def _run_client(self, client_fn, task_type, task_id, num_gpus, eager_mode,
                  *args, **kwargs):

    def wrapped_client_fn():
      with self._coord.stop_on_exception():
        client_fn(task_type, task_id, num_gpus, *args, **kwargs)

    if eager_mode:
      with context.eager_mode():
        wrapped_client_fn()
    else:
      with context.graph_mode():
        wrapped_client_fn()

  def _run_between_graph_clients(self, client_fn, cluster_spec, num_gpus, *args,
                                 **kwargs):
    """Runs several clients for between-graph replication.

    Args:
      client_fn: a function that needs to accept `task_type`, `task_id`, and
        `num_gpus`.
      cluster_spec: a dict specifying jobs in a cluster.
      num_gpus: number of GPUs per worker.
      *args: will be passed to `client_fn`.
      **kwargs: will be passed to `client_fn`.
558 """ 559 threads = [] 560 for task_type in ['chief', 'worker']: 561 for task_id in range(len(cluster_spec.get(task_type, []))): 562 t = threading.Thread( 563 target=self._run_client, 564 args=(client_fn, task_type, task_id, num_gpus, 565 context.executing_eagerly()) + args, 566 kwargs=kwargs) 567 t.start() 568 threads.append(t) 569 self._coord.join(threads) 570 571 572class SingleWorkerTestBaseGraph(MultiWorkerTestBase): 573 """Base class for testing remote single worker strategy graph and dataset.""" 574 575 @classmethod 576 def setUpClass(cls): 577 super(SingleWorkerTestBaseGraph, cls).setUpClass(num_workers=1) 578 579 580class SingleWorkerTestBaseEager(test.TestCase): 581 """Base class for testing remote single worker strategy eager and dataset.""" 582 583 def setUp(self): 584 super(SingleWorkerTestBaseEager, self).setUp() 585 workers, _ = test_util.create_local_cluster(num_workers=1, num_ps=0) 586 remote.connect_to_remote_host(workers[0].target) 587 588 def cached_session(self): 589 return DummySession() 590 591 592class DummySession(object): 593 594 def __enter__(self): 595 return 596 597 def __exit__(self, exception_type, exception_value, traceback): 598 pass 599 600 601class MockOsEnv(collections_abc.Mapping): 602 """A class that allows per-thread TF_CONFIG.""" 603 604 def __init__(self, *args): 605 self._dict = dict() 606 self._thread_local = threading.local() 607 super(MockOsEnv, self).__init__(*args) 608 609 def get(self, key, default=None): 610 if not hasattr(self._thread_local, 'dict'): 611 self._thread_local.dict = dict() 612 if key == 'TF_CONFIG': 613 return dict.get(self._thread_local.dict, key, default) 614 else: 615 return dict.get(self._dict, key, default) 616 617 def __getitem__(self, key): 618 if not hasattr(self._thread_local, 'dict'): 619 self._thread_local.dict = dict() 620 if key == 'TF_CONFIG': 621 return dict.__getitem__(self._thread_local.dict, key) 622 else: 623 return dict.__getitem__(self._dict, key) 624 625 def __setitem__(self, key, val): 626 if not hasattr(self._thread_local, 'dict'): 627 self._thread_local.dict = dict() 628 if key == 'TF_CONFIG': 629 return dict.__setitem__(self._thread_local.dict, key, val) 630 else: 631 return dict.__setitem__(self._dict, key, val) 632 633 def __iter__(self): 634 if not hasattr(self._thread_local, 'dict'): 635 self._thread_local.dict = dict() 636 for x in self._thread_local.dict: 637 yield x 638 for x in self._dict: 639 yield x 640 641 def __len__(self): 642 if not hasattr(self._thread_local, 'dict'): 643 self._thread_local.dict = dict() 644 return self._thread_local.dict.__len__() + self._dict.__len__() 645 646 647class IndependentWorkerTestBase(test.TestCase): 648 """Testing infra for independent workers.""" 649 650 def _make_mock_run_std_server(self): 651 652 def _mock_run_std_server(*args, **kwargs): 653 """Returns the std server once all threads have started it.""" 654 with skip_if_grpc_server_cant_be_started(self): 655 ret = original_run_std_server(*args, **kwargs) 656 # Wait for all std servers to be brought up in order to reduce the chance 657 # of remote sessions taking local ports that have been assigned to std 658 # servers. Only call this barrier the first time this function is run for 659 # each thread. 
      if not getattr(self._thread_local, 'server_started', False):
        self._barrier.wait()
      self._thread_local.server_started = True
      return ret

    return _mock_run_std_server

  def setUp(self):
    self._mock_os_env = MockOsEnv()
    self._mock_context = test.mock.patch.object(os, 'environ',
                                                self._mock_os_env)
    self._coord = coordinator.Coordinator()
    super(IndependentWorkerTestBase, self).setUp()
    self._mock_context.__enter__()
    # Threading-local object to be shared by all threads.
    self._thread_local = threading.local()

  def tearDown(self):
    self._mock_context.__exit__(None, None, None)
    super(IndependentWorkerTestBase, self).tearDown()

  def _task_thread(self, task_fn, tf_config, executing_eagerly, *args,
                   **kwargs):
    with self._coord.stop_on_exception():
      os.environ['TF_CONFIG'] = json.dumps(tf_config)
      # Force the new thread simulating a worker to run in the same context
      # mode as the parent thread does.
      if executing_eagerly:
        with context.eager_mode():
          task_fn(*args, **kwargs)
      else:
        with ops.Graph().as_default(), context.graph_mode():
          task_fn(*args, **kwargs)

  def _run_task_in_thread(self, task_fn, cluster_spec, task_type, task_id,
                          *args, **kwargs):
    """Runs a task in a new thread.

    If `tf_config` is provided, use it for the new thread; if not, construct
    one from `cluster_spec`, `task_type`, and `task_id`, and provide it to the
    new thread to be set as the `TF_CONFIG` environment variable.

    Args:
      task_fn: The function to run in the new thread.
      cluster_spec: The cluster spec.
      task_type: The task type.
      task_id: The task id.
      *args: Additional positional arguments to provide to the thread's
        task_fn.
      **kwargs: Additional keyword arguments to provide to the thread's
        task_fn. If `tf_config` is provided, that dict will be used for the
        TF_CONFIG of the new thread.

    Returns:
      The thread that has started.
    """
    tf_config = kwargs.pop('tf_config', None)
    if tf_config is None:
      if task_type:
        tf_config = {
            'cluster': cluster_spec,
            'task': {
                'type': task_type,
                'index': task_id
            }
        }
      else:
        tf_config = {
            'cluster': cluster_spec,
        }
    t = threading.Thread(
        target=self._task_thread,
        args=(task_fn, tf_config, context.executing_eagerly()) + args,
        kwargs=kwargs)
    t.start()
    return t

  def run_multiple_tasks_in_threads(self, task_fn, cluster_spec, *args,
                                    **kwargs):
    # The task_fn should create a std server by itself.
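    # A minimal sketch of such a task_fn (hypothetical, for illustration only;
    # `my_task_fn` is an assumed name, while TFConfigClusterResolver and
    # server_lib.Server come from this module's imports):
    #
    #   def my_task_fn():
    #     resolver = TFConfigClusterResolver()  # reads the TF_CONFIG set above
    #     server_lib.Server(
    #         resolver.cluster_spec(),
    #         job_name=resolver.task_type,
    #         task_index=resolver.task_id,
    #         protocol='grpc',
    #         start=True)
    #     ...  # per-task test logic goes here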
    threads = {}
    for task_type in cluster_spec.keys():
      threads[task_type] = []
      for task_id in range(len(cluster_spec[task_type])):
        t = self._run_task_in_thread(task_fn, cluster_spec, task_type, task_id,
                                     *args, **kwargs)
        threads[task_type].append(t)
    return threads

  def join_independent_workers(self, worker_threads):
    with skip_if_grpc_server_cant_be_started(self):
      self._coord.join(worker_threads)


class MultiWorkerMultiProcessTest(test.TestCase):
  """Testing infra for independent workers using multiple processes."""

  def _run_task_in_process(self, cmd_args, cluster_spec, task_type, task_id):
    env = os.environ.copy()
    env['TF_CONFIG'] = json.dumps({
        'cluster': cluster_spec,
        'task': {
            'type': task_type,
            'index': task_id
        }
    })
    return subprocess.Popen(
        cmd_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)

  @deprecation.deprecated(
      None, '`run_multiple_tasks_in_processes` is deprecated; any new test '
      'requiring multiple processes should use `multi_process_runner` for '
      'better support of log printing, streaming, and more functionality.')
  def run_multiple_tasks_in_processes(self, cmd_args, cluster_spec):
    """Run `cmd_args` in a process for each task in `cluster_spec`."""
    processes = {}
    for task_type in cluster_spec.keys():
      processes[task_type] = []
      for task_id in range(len(cluster_spec[task_type])):
        p = self._run_task_in_process(cmd_args, cluster_spec, task_type,
                                      task_id)
        processes[task_type].append(p)
    return processes

  @deprecation.deprecated(
      None, '`join_independent_workers` is deprecated; any new test '
      'requiring multiple processes should use `multi_process_runner` for '
      'better support of log printing, streaming, and more functionality.')
  def join_independent_workers(self, worker_processes):
    return_codes = []
    for p in nest.flatten(worker_processes):
      try:
        # Calling p.wait() will hang if we don't consume its output.
        p.communicate()
      except ValueError:
        # The output of the process may have already been consumed, in which
        # case calling `p.communicate()` will raise a ValueError.
        pass
      finally:
        return_codes.append(p.returncode)
    for return_code in return_codes:
      self.assertEqual(return_code, 0)

  @deprecation.deprecated(
      None, '`stream_stderr` is deprecated; any new test '
      'requiring multiple processes should use `multi_process_runner` for '
      'better support of log printing, streaming, and more functionality.')
  def stream_stderr(self, processes, print_only_first=False):
    """Consume stderr of all processes and print to stdout.

    To reduce the amount of logging, the caller can set print_only_first to
    True. In that case, this function only prints stderr from the first
    process of each type.

    Args:
      processes: A dictionary from process type string -> list of processes.
      print_only_first: If true, only print output from the first process of
        each type.
817 """ 818 819 def _stream_stderr_single_process(process, type_string, index, 820 print_to_stdout): 821 """Consume a single process's stderr and optionally print to stdout.""" 822 while True: 823 output = process.stderr.readline() 824 if not output and process.poll() is not None: 825 break 826 if output and print_to_stdout: 827 print('{}{} {}'.format(type_string, index, output.strip())) 828 sys.stdout.flush() 829 830 stream_threads = [] 831 for process_type, process_list in six.iteritems(processes): 832 for i in range(len(process_list)): 833 print_to_stdout = (not print_only_first) or (i == 0) 834 thread = threading.Thread( 835 target=_stream_stderr_single_process, 836 args=(process_list[i], process_type, i, print_to_stdout)) 837 thread.start() 838 stream_threads.append(thread) 839 for thread in stream_threads: 840 thread.join() 841 842 843def get_tf_config_task(): 844 return json.loads(os.environ['TF_CONFIG'])['task'] 845 846 847def get_tf_config_cluster_spec(): 848 return json.loads(os.environ['TF_CONFIG'])['cluster'] 849 850 851def get_task_type(): 852 return get_tf_config_task()['type'] 853 854 855def get_task_index(): 856 return get_tf_config_task()['index'] 857 858 859def is_chief(): 860 return ('chief' not in get_tf_config_cluster_spec() 861 and get_task_type() == 'worker' 862 and get_task_index() == 0) 863