# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base testing class for strategies that require multiple nodes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import copy
import json
import os
import subprocess
import sys
import threading
import unittest

import six

_portpicker_import_error = None
try:
  import portpicker  # pylint: disable=g-import-not-at-top
except (ImportError, ModuleNotFoundError) as _error:  # pylint: disable=invalid-name
  _portpicker_import_error = _error
  portpicker = None

# pylint: disable=g-import-not-at-top
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import remote
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator
from tensorflow.python.training import server_lib
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export


original_run_std_server = dc._run_std_server  # pylint: disable=protected-access

ASSIGNED_PORTS = set()
lock = threading.Lock()


def pick_unused_port():
  """Returns an unused and unassigned local port."""
  if _portpicker_import_error:
    raise _portpicker_import_error  # pylint: disable=raising-bad-type

  global ASSIGNED_PORTS
  with lock:
    while True:
      try:
        port = portpicker.pick_unused_port()
      except portpicker.NoFreePortFoundError:
        raise unittest.SkipTest('Flakes in portpicker library do not represent '
                                'TensorFlow errors.')
      if port > 10000 and port not in ASSIGNED_PORTS:
        ASSIGNED_PORTS.add(port)
        logging.info('Using local port %r', port)
        return port


def _create_cluster(num_workers,
                    num_ps,
                    has_chief=False,
                    has_eval=False,
                    protocol='grpc',
                    worker_config=None,
                    ps_config=None,
                    eval_config=None,
                    worker_name='worker',
                    ps_name='ps',
                    chief_name='chief'):
  """Creates and starts local servers and returns the cluster_spec dict."""
  if _portpicker_import_error:
    raise _portpicker_import_error  # pylint: disable=raising-bad-type
  worker_ports = [pick_unused_port() for _ in range(num_workers)]
  ps_ports = [pick_unused_port() for _ in range(num_ps)]

  cluster_dict = {}
  if num_workers > 0:
    cluster_dict[worker_name] = ['localhost:%s' % port for port in worker_ports]
  if num_ps > 0:
    cluster_dict[ps_name] = ['localhost:%s' % port for port in ps_ports]
  if has_eval:
    cluster_dict['evaluator'] = ['localhost:%s' % pick_unused_port()]
  if has_chief:
    cluster_dict[chief_name] = ['localhost:%s' % pick_unused_port()]

  cs = server_lib.ClusterSpec(cluster_dict)

  for i in range(num_workers):
    server_lib.Server(
        cs,
        job_name=worker_name,
        protocol=protocol,
        task_index=i,
        config=worker_config,
        start=True)

  for i in range(num_ps):
    server_lib.Server(
        cs,
        job_name=ps_name,
        protocol=protocol,
        task_index=i,
        config=ps_config,
        start=True)

  if has_chief:
    server_lib.Server(
        cs,
        job_name=chief_name,
        protocol=protocol,
        task_index=0,
        config=worker_config,
        start=True)

  if has_eval:
    server_lib.Server(
        cs,
        job_name='evaluator',
        protocol=protocol,
        task_index=0,
        config=eval_config,
        start=True)

  return cluster_dict


def create_in_process_cluster(num_workers,
                              num_ps,
                              has_chief=False,
                              has_eval=False,
                              rpc_layer='grpc'):
  """Creates an in-process cluster that consists only of standard servers."""
  # Leave some memory for the CUDA runtime.
  gpu_mem_frac = 0.7 / (num_workers + int(has_chief) + int(has_eval))
  worker_config = config_pb2.ConfigProto()
  worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac

  # The cluster may hang if workers don't have enough inter_op threads. See
  # b/172296720 for more details.
  if worker_config.inter_op_parallelism_threads < num_workers + 1:
    worker_config.inter_op_parallelism_threads = num_workers + 1

  # Enable collective ops; this has no impact on non-collective ops.
  # TODO(yuefengz, tucker): remove this after we move the initialization of
  # collective mgr to the session level.
  if has_chief:
    worker_config.experimental.collective_group_leader = (
        '/job:chief/replica:0/task:0')
  else:
    worker_config.experimental.collective_group_leader = (
        '/job:worker/replica:0/task:0')

  ps_config = config_pb2.ConfigProto()
  ps_config.device_count['GPU'] = 0

  eval_config = config_pb2.ConfigProto()
  eval_config.experimental.collective_group_leader = ''

  # Create in-process servers. Once an in-process TensorFlow server is created,
  # there is no way to terminate it, so we create one cluster per test process.
  # We could have started the servers in separate processes and killed those
  # processes to terminate them. The reasons we don't want multiple processes
  # are:
  # 1) it is more difficult to manage these processes;
  # 2) there is something global in CUDA such that if we initialize CUDA in the
  #    parent process, a child process cannot initialize it again and thus
  #    cannot use GPUs (https://stackoverflow.com/questions/22950047).
  cluster = None
  try:
    cluster = _create_cluster(
        num_workers,
        num_ps=num_ps,
        has_chief=has_chief,
        has_eval=has_eval,
        worker_config=worker_config,
        ps_config=ps_config,
        eval_config=eval_config,
        protocol=rpc_layer)
  except errors.UnknownError as e:
    if 'Could not start gRPC server' in e.message:
      raise unittest.SkipTest('Cannot start std servers.')
    else:
      raise
  return cluster


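# Example (illustrative sketch only; the test class below is hypothetical): a
# graph-mode test can create the in-process cluster once per test process and
# point sessions at the first worker:
#
#   class MyDistributedTest(test.TestCase):
#
#     @classmethod
#     def setUpClass(cls):
#       cls._cluster_spec = create_in_process_cluster(num_workers=2, num_ps=1)
#       # cls._cluster_spec is a dict such as
#       # {'worker': ['localhost:12345', 'localhost:12346'],
#       #  'ps': ['localhost:12347']}
#       cls._target = 'grpc://' + cls._cluster_spec['worker'][0]

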
class MultiProcessCluster(object):
  """A cluster of TensorFlow servers in separate processes.

  This class is not thread-safe.
  """

  def __init__(self,
               cluster_resolver,
               stream_output=False,
               collective_leader=None):
    self._cluster_resolver = cluster_resolver
    self._cluster_spec = cluster_resolver.cluster_spec().as_dict()
    self._rpc_layer = cluster_resolver.rpc_layer
    self._stream_output = stream_output
    self._start_events = {}
    self._finish_events = {}
    self._mpr_manager = multi_process_runner.manager()

    def task_function(start_events, finish_events):
      cluster_resolver = TFConfigClusterResolver()
      cluster_spec = cluster_resolver.cluster_spec()
      task_type = cluster_resolver.task_type
      task_id = cluster_resolver.task_id
      rpc_layer = cluster_resolver.rpc_layer

      # TODO(yuefengz): support GPU clusters.
      server_config = config_pb2.ConfigProto()
      server_config.device_count['GPU'] = 0

      if collective_leader:
        server_config.experimental.collective_group_leader = collective_leader
        server_config.experimental.collective_nccl = False

        logging.info(
            'Enabling collective ops with cluster_spec = %r, task_type = %r, '
            'task_id = %r, rpc_layer = %r, collective_leader = %s',
            cluster_spec, task_type, task_id, rpc_layer, collective_leader)
      else:
        logging.info(
            'Starting server with cluster_spec = %r, task_type = %r, '
            'task_id = %r, rpc_layer = %r', cluster_spec, task_type, task_id,
            rpc_layer)

      server_lib.Server(
          cluster_spec,
          job_name=task_type,
          protocol=rpc_layer,
          task_index=task_id,
          config=server_config,
          start=True)

      start_event = start_events[task_type][task_id]
      start_event.set()

      finish_event = finish_events[task_type][task_id]
      finish_event.wait()

      os._exit(0)  # pylint: disable=protected-access

    self._task_function = task_function
    self._mpr = None

  def start(self):
    """Starts one TensorFlow server for each task in the cluster_resolver.

    It will wait until all the servers are up before returning.
    """
    if self._mpr:
      raise ValueError('The cluster has already been started.')
    for task_type, task_addresses in self._cluster_spec.items():
      self._start_events[task_type] = []
      self._finish_events[task_type] = []
      for _ in task_addresses:
        self._start_events[task_type].append(self._mpr_manager.Event())
        self._finish_events[task_type].append(self._mpr_manager.Event())

    self._mpr = multi_process_runner.MultiProcessRunner(
        self._task_function,
        self._cluster_spec,
        args=(self._start_events, self._finish_events),
        rpc_layer=self._rpc_layer,
        stream_output=self._stream_output,
        return_output=False,
        use_dill_for_args=False)
    self._mpr.start()
    for task_type, task_addresses in self._cluster_spec.items():
      for i in range(len(task_addresses)):
        self._start_events[task_type][i].wait()

  def stop(self):
    """Stops all the servers."""
    for task_type, task_addresses in self._cluster_spec.items():
      for i in range(len(task_addresses)):
        self._finish_events[task_type][i].set()
    try:
      self._mpr.join()
    except multi_process_runner.UnexpectedSubprocessExitError:
      # TODO(yuefengz): investigate why processes exit with 255.
      pass
    self._mpr = None
    self._start_events = {}
    self._finish_events = {}

  def kill_task(self, task_type, task_id):
    """Kills a server given task_type and task_id.

    Args:
      task_type: the type of the task, such as "worker".
      task_id: the id of the task, such as 1.
    """
    assert self._mpr
    if (not self._start_events[task_type][task_id].is_set() or
        self._finish_events[task_type][task_id].is_set()):
      raise ValueError("The task %s:%d doesn't exist." % (task_type, task_id))

    self._finish_events[task_type][task_id].set()
    self._mpr._processes[(task_type, task_id)].join()

  def start_task(self, task_type, task_id):
    """Starts a server given task_type and task_id.

    Args:
      task_type: the type of the task, such as "worker".
      task_id: the id of the task, such as 1.

    Raises:
      ValueError: if the server already exists.
    """
    assert self._mpr

    if (not self._start_events[task_type][task_id].is_set() or
        not self._finish_events[task_type][task_id].is_set()):
      raise ValueError(
          'The task %s:%d is still alive. You cannot start another one.' %
          (task_type, task_id))
    self._start_events[task_type][task_id] = self._mpr_manager.Event()
    self._finish_events[task_type][task_id] = self._mpr_manager.Event()
    self._mpr.start_single_process(task_type=task_type, task_id=task_id)
    self._start_events[task_type][task_id].wait()

  @property
  def cluster_resolver(self):
    return copy.deepcopy(self._cluster_resolver)


def create_multi_process_cluster(num_workers,
                                 num_ps,
                                 has_chief=False,
                                 has_eval=False,
                                 rpc_layer='grpc',
                                 stream_output=False,
                                 collective_leader=None):
  cluster_spec = create_cluster_spec(
      has_chief=has_chief,
      num_workers=num_workers,
      num_ps=num_ps,
      has_eval=has_eval)

  cluster = MultiProcessCluster(
      SimpleClusterResolver(
          server_lib.ClusterSpec(cluster_spec), rpc_layer=rpc_layer),
      stream_output=stream_output,
      collective_leader=collective_leader)
  cluster.start()
  return cluster


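# Example (illustrative sketch only; the variable names are hypothetical): a
# fault-tolerance test can bring up a multi-process cluster, kill one worker,
# and later restart it:
#
#   cluster = create_multi_process_cluster(num_workers=2, num_ps=1)
#   resolver = cluster.cluster_resolver
#   cluster.kill_task('worker', 0)     # simulate a worker failure
#   cluster.start_task('worker', 0)    # bring the worker back up
#   cluster.stop()

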
@tf_export(
    '__internal__.distribute.multi_process_runner.create_cluster_spec', v1=[])
def create_cluster_spec(has_chief=False,
                        num_workers=1,
                        num_ps=0,
                        has_eval=False):
  """Creates a cluster spec with tasks on unused local ports.

  This utility finds available ports on localhost and returns a dict that
  represents a cluster spec using those ports, according to the arguments. The
  dict contains the task types and their instances' addresses. Note that this
  is usually only for testing purposes using multiple processes on the local
  machine, and should not be used for real multi-worker TensorFlow programs,
  where the addresses need to point to processes on separate machines.

  This util is useful when creating the `cluster_spec` arg for
  `tf.__internal__.distribute.multi_process_runner.run`.

  Args:
    has_chief: Whether the generated cluster spec should contain a "chief" task
      type.
    num_workers: Number of workers to use in the cluster spec.
    num_ps: Number of parameter servers to use in the cluster spec.
    has_eval: Whether this cluster spec has an evaluator.

  Returns:
    A dict that represents the cluster spec using localhost ports for the tasks.

  Example:

  ```python
  cluster_spec = (
      tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
          has_chief=True, num_workers=2, num_ps=2))
  # An example of cluster_spec is
  # {'chief': ['localhost:23381'],
  #  'worker': ['localhost:19197', 'localhost:22903'],
  #  'ps': ['localhost:16912', 'localhost:21535']}

  cluster_spec = (
      tf.__internal__.distribute.multi_process_runner.create_cluster_spec(
          has_chief=False, num_workers=0, num_ps=0, has_eval=True))
  # An example of cluster_spec is
  # {'evaluator': ['localhost:23381']}
  ```
  """
  if _portpicker_import_error:
    raise _portpicker_import_error  # pylint: disable=raising-bad-type

  cluster_spec = {}
  if has_chief:
    cluster_spec['chief'] = ['localhost:%s' % pick_unused_port()]
  if num_workers:
    cluster_spec['worker'] = [
        'localhost:%s' % pick_unused_port() for _ in range(num_workers)
    ]
  if num_ps:
    cluster_spec['ps'] = [
        'localhost:%s' % pick_unused_port() for _ in range(num_ps)
    ]
  if has_eval:
    cluster_spec['evaluator'] = ['localhost:%s' % pick_unused_port()]
  return cluster_spec


@contextlib.contextmanager
def skip_if_grpc_server_cant_be_started(test_obj):
  try:
    yield
  except errors.UnknownError as e:
    if 'Could not start gRPC server' in e.message:
      reason = 'Cannot start std servers.'
      test_obj.test_skipped_reason = reason
      test_obj.skipTest(reason)
    else:
      raise


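# Example (illustrative sketch only; `self` stands for a hypothetical test
# case): wrap code that may start a gRPC server so that environment flakes
# become test skips rather than failures:
#
#   with skip_if_grpc_server_cant_be_started(self):
#     cluster_spec = create_in_process_cluster(num_workers=2, num_ps=0)

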
class MultiWorkerTestBase(test.TestCase):
  """Base class for testing multi node strategy and dataset."""

  @classmethod
  def setUpClass(cls, num_workers=2, num_ps=1):  # pylint: disable=g-missing-super-call
    """Creates a local cluster with `num_workers` workers and `num_ps` ps."""
    cls._cluster_spec = create_in_process_cluster(num_workers=num_workers,
                                                  num_ps=num_ps)
    cls._default_target = 'grpc://' + cls._cluster_spec['worker'][0]

  def setUp(self):
    # We only cache the session within each test because another test may use a
    # different session config or master target.
    self._thread_local = threading.local()
    self._thread_local.cached_session = None
    self._coord = coordinator.Coordinator()

  @contextlib.contextmanager
  def session(self, graph=None, config=None, target=None):
    """Create a test session with master target set to the testing cluster.

    Creates a test session that connects to the local testing cluster.

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      target: the target of the session to connect to.

    Yields:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case.
    """
    config = self._create_config(config)

    if target is None:
      target = self._default_target
    with session.Session(graph=graph, config=config, target=target) as sess:
      yield sess

  @contextlib.contextmanager
  # TODO(b/117573461): Overwrite self.evaluate() to use this function.
  def cached_session(self, graph=None, config=None, target=None):
    """Create a test session with master target set to the testing cluster.

    Creates a test session that connects to the local testing cluster.
    The session is only created once per test and then reused.

    Args:
      graph: Optional graph to use during the returned session.
      config: An optional config_pb2.ConfigProto to use to configure the
        session.
      target: the target of the session to connect to.

    Yields:
      A Session object that should be used as a context manager to surround
      the graph building and execution code in a test case. Note that the
      session will live until the end of the test.
    """
    config = self._create_config(config)

    if target is None:
      target = self._default_target
    if getattr(self._thread_local, 'cached_session', None) is None:
      self._thread_local.cached_session = session.Session(
          graph=None, config=config, target=target)
    sess = self._thread_local.cached_session
    with sess.graph.as_default(), sess.as_default():
      yield sess

  def _create_config(self, config):
    if config is None:
      config = config_pb2.ConfigProto(allow_soft_placement=True)
    else:
      config = copy.deepcopy(config)
    # Don't perform optimizations for tests so we don't inadvertently run
    # GPU ops on CPU.
    config.graph_options.optimizer_options.opt_level = -1
    config.graph_options.rewrite_options.constant_folding = (
        rewriter_config_pb2.RewriterConfig.OFF)

    return config

  def _run_client(self, client_fn, task_type, task_id, num_gpus, eager_mode,
                  *args, **kwargs):

    def wrapped_client_fn():
      with self._coord.stop_on_exception():
        client_fn(task_type, task_id, num_gpus, *args, **kwargs)

    if eager_mode:
      with context.eager_mode():
        wrapped_client_fn()
    else:
      with context.graph_mode():
        wrapped_client_fn()

  def _run_between_graph_clients(self, client_fn, cluster_spec, num_gpus, *args,
                                 **kwargs):
    """Runs several clients for between-graph replication.

    Args:
      client_fn: a function that needs to accept `task_type`, `task_id`,
        `num_gpus`.
      cluster_spec: a dict specifying jobs in a cluster.
      num_gpus: number of GPUs per worker.
      *args: will be passed to `client_fn`.
      **kwargs: will be passed to `client_fn`.
    """
    threads = []
    for task_type in ['chief', 'worker']:
      for task_id in range(len(cluster_spec.get(task_type, []))):
        t = threading.Thread(
            target=self._run_client,
            args=(client_fn, task_type, task_id, num_gpus,
                  context.executing_eagerly()) + args,
            kwargs=kwargs)
        t.start()
        threads.append(t)
    self._coord.join(threads)


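# Example (illustrative sketch only; the test class and `client_fn` below are
# hypothetical): a subclass gets an in-process cluster from `setUpClass` and
# can run one client per chief/worker task for between-graph replication:
#
#   class MyBetweenGraphTest(MultiWorkerTestBase):
#
#     def test_runs_on_every_worker(self):
#       def client_fn(task_type, task_id, num_gpus):
#         del task_type, task_id, num_gpus
#         with self.cached_session() as sess:
#           self.assertIsNotNone(sess)
#       self._run_between_graph_clients(client_fn, self._cluster_spec,
#                                       num_gpus=0)

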
class SingleWorkerTestBaseGraph(MultiWorkerTestBase):
  """Base class for testing remote single worker strategy graph and dataset."""

  @classmethod
  def setUpClass(cls):
    super(SingleWorkerTestBaseGraph, cls).setUpClass(num_workers=1)


class SingleWorkerTestBaseEager(test.TestCase):
  """Base class for testing remote single worker strategy eager and dataset."""

  def setUp(self):
    super(SingleWorkerTestBaseEager, self).setUp()
    workers, _ = test_util.create_local_cluster(num_workers=1, num_ps=0)
    remote.connect_to_remote_host(workers[0].target)

  def cached_session(self):
    return DummySession()


class DummySession(object):

  def __enter__(self):
    return

  def __exit__(self, exception_type, exception_value, traceback):
    pass


class MockOsEnv(collections_abc.Mapping):
  """A class that allows per-thread TF_CONFIG."""

  def __init__(self, *args):
    self._dict = dict()
    self._thread_local = threading.local()
    super(MockOsEnv, self).__init__(*args)

  def get(self, key, default=None):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    if key == 'TF_CONFIG':
      return dict.get(self._thread_local.dict, key, default)
    else:
      return dict.get(self._dict, key, default)

  def __getitem__(self, key):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    if key == 'TF_CONFIG':
      return dict.__getitem__(self._thread_local.dict, key)
    else:
      return dict.__getitem__(self._dict, key)

  def __setitem__(self, key, val):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    if key == 'TF_CONFIG':
      return dict.__setitem__(self._thread_local.dict, key, val)
    else:
      return dict.__setitem__(self._dict, key, val)

  def __iter__(self):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    for x in self._thread_local.dict:
      yield x
    for x in self._dict:
      yield x

  def __len__(self):
    if not hasattr(self._thread_local, 'dict'):
      self._thread_local.dict = dict()
    return self._thread_local.dict.__len__() + self._dict.__len__()


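# Example (illustrative sketch only): when `os.environ` is patched with a
# MockOsEnv (as IndependentWorkerTestBase does below), each simulated worker
# thread can hold its own TF_CONFIG while other keys stay shared:
#
#   mock_env = MockOsEnv()
#   mock_env['PATH'] = '/usr/bin'           # visible to every thread
#   mock_env['TF_CONFIG'] = '{"task": {}}'  # visible only to the calling thread

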
class IndependentWorkerTestBase(test.TestCase):
  """Testing infra for independent workers."""

  def _make_mock_run_std_server(self):

    def _mock_run_std_server(*args, **kwargs):
      """Returns the std server once all threads have started it."""
      with skip_if_grpc_server_cant_be_started(self):
        ret = original_run_std_server(*args, **kwargs)
      # Wait for all std servers to be brought up in order to reduce the chance
      # of remote sessions taking local ports that have been assigned to std
      # servers. Only call this barrier the first time this function is run for
      # each thread.
      if not getattr(self._thread_local, 'server_started', False):
        self._barrier.wait()
      self._thread_local.server_started = True
      return ret

    return _mock_run_std_server

  def setUp(self):
    self._mock_os_env = MockOsEnv()
    self._mock_context = test.mock.patch.object(os, 'environ',
                                                self._mock_os_env)
    self._coord = coordinator.Coordinator()
    super(IndependentWorkerTestBase, self).setUp()
    self._mock_context.__enter__()
    # threading.local object shared by all threads; each thread sees only its
    # own data.
    self._thread_local = threading.local()

  def tearDown(self):
    self._mock_context.__exit__(None, None, None)
    super(IndependentWorkerTestBase, self).tearDown()

  def _task_thread(self, task_fn, tf_config, executing_eagerly, *args,
                   **kwargs):
    with self._coord.stop_on_exception():
      os.environ['TF_CONFIG'] = json.dumps(tf_config)
      # Force the new thread simulating a worker to run in the same context
      # mode as the parent thread does.
      if executing_eagerly:
        with context.eager_mode():
          task_fn(*args, **kwargs)
      else:
        with ops.Graph().as_default(), context.graph_mode():
          task_fn(*args, **kwargs)

  def _run_task_in_thread(self, task_fn, cluster_spec, task_type, task_id,
                          *args, **kwargs):
    """Runs a task in a thread.

    If `tf_config` is provided, it is used for the new thread; if not, one is
    constructed from `cluster_spec`, `task_type`, and `task_id`, and provided
    to the new thread to be set as the `TF_CONFIG` environment variable.

    Args:
      task_fn: The function to run in the new thread.
      cluster_spec: The cluster spec.
      task_type: The task type.
      task_id: The task id.
      *args: Additional positional arguments to provide to the thread's task_fn.
      **kwargs: Additional keyword arguments to provide to the thread's task_fn.
        If `tf_config` is provided, that dict will be used for the TF_CONFIG for
        the new thread.

    Returns:
      The thread that has started.
    """
    tf_config = kwargs.pop('tf_config', None)
    if tf_config is None:
      if task_type:
        tf_config = {
            'cluster': cluster_spec,
            'task': {
                'type': task_type,
                'index': task_id
            }
        }
      else:
        tf_config = {
            'cluster': cluster_spec,
        }
    t = threading.Thread(
        target=self._task_thread,
        args=(task_fn, tf_config, context.executing_eagerly()) + args,
        kwargs=kwargs)
    t.start()
    return t

  def run_multiple_tasks_in_threads(self, task_fn, cluster_spec, *args,
                                    **kwargs):
    # The task_fn should create std_server by itself.
    threads = {}
    for task_type in cluster_spec.keys():
      threads[task_type] = []
      for task_id in range(len(cluster_spec[task_type])):
        t = self._run_task_in_thread(task_fn, cluster_spec, task_type, task_id,
                                     *args, **kwargs)
        threads[task_type].append(t)
    return threads

  def join_independent_workers(self, worker_threads):
    with skip_if_grpc_server_cant_be_started(self):
      self._coord.join(worker_threads)


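# Example (illustrative sketch only; `task_fn` is a hypothetical per-task
# function): an IndependentWorkerTestBase subclass typically builds a cluster
# spec, runs one thread per task, and joins them:
#
#   cluster_spec = create_cluster_spec(num_workers=2)
#   threads = self.run_multiple_tasks_in_threads(task_fn, cluster_spec)
#   self.join_independent_workers(threads['worker'])

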
class MultiWorkerMultiProcessTest(test.TestCase):
  """Testing infra for independent workers using multiple processes."""

  def _run_task_in_process(self, cmd_args, cluster_spec, task_type, task_id):
    env = os.environ.copy()
    env['TF_CONFIG'] = json.dumps({
        'cluster': cluster_spec,
        'task': {
            'type': task_type,
            'index': task_id
        }
    })
    return subprocess.Popen(
        cmd_args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env)

  @deprecation.deprecated(
      None, '`run_multiple_tasks_in_processes` is deprecated; any new test '
      'requiring multiple processes should use `multi_process_runner` for '
      'better support of log printing, streaming, and more functionality.')
  def run_multiple_tasks_in_processes(self, cmd_args, cluster_spec):
    """Run `cmd_args` in a process for each task in `cluster_spec`."""
    processes = {}
    for task_type in cluster_spec.keys():
      processes[task_type] = []
      for task_id in range(len(cluster_spec[task_type])):
        p = self._run_task_in_process(cmd_args, cluster_spec, task_type,
                                      task_id)
        processes[task_type].append(p)
    return processes

  @deprecation.deprecated(
      None, '`join_independent_workers` is deprecated; any new test '
      'requiring multiple processes should use `multi_process_runner` for '
      'better support of log printing, streaming, and more functionality.')
  def join_independent_workers(self, worker_processes):
    return_codes = []
    for p in nest.flatten(worker_processes):
      try:
        # Calling p.wait() will hang if we don't consume its output.
        p.communicate()
      except ValueError:
        # The output of the process may have been consumed, in which case
        # calling `p.communicate()` will raise a ValueError.
        pass
      finally:
        return_codes.append(p.returncode)
    for return_code in return_codes:
      self.assertEqual(return_code, 0)

  @deprecation.deprecated(
      None, '`stream_stderr` is deprecated; any new test '
      'requiring multiple processes should use `multi_process_runner` for '
      'better support of log printing, streaming, and more functionality.')
  def stream_stderr(self, processes, print_only_first=False):
    """Consume stderr of all processes and print to stdout.

    To reduce the amount of logging, the caller can set print_only_first to
    True. In that case, this function only prints stderr from the first
    process of each type.

    Args:
      processes: A dictionary from process type string -> list of processes.
      print_only_first: If true, only print output from the first process of
        each type.
    """

    def _stream_stderr_single_process(process, type_string, index,
                                      print_to_stdout):
      """Consume a single process's stderr and optionally print to stdout."""
      while True:
        output = process.stderr.readline()
        if not output and process.poll() is not None:
          break
        if output and print_to_stdout:
          print('{}{} {}'.format(type_string, index, output.strip()))
          sys.stdout.flush()

    stream_threads = []
    for process_type, process_list in six.iteritems(processes):
      for i in range(len(process_list)):
        print_to_stdout = (not print_only_first) or (i == 0)
        thread = threading.Thread(
            target=_stream_stderr_single_process,
            args=(process_list[i], process_type, i, print_to_stdout))
        thread.start()
        stream_threads.append(thread)
    for thread in stream_threads:
      thread.join()


def get_tf_config_task():
  return json.loads(os.environ['TF_CONFIG'])['task']


def get_tf_config_cluster_spec():
  return json.loads(os.environ['TF_CONFIG'])['cluster']


def get_task_type():
  return get_tf_config_task()['type']


def get_task_index():
  return get_tf_config_task()['index']


def is_chief():
  return ('chief' not in get_tf_config_cluster_spec()
          and get_task_type() == 'worker'
          and get_task_index() == 0)
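

# Example (illustrative sketch only; the TF_CONFIG value below is made up):
# with
#
#   os.environ['TF_CONFIG'] = json.dumps({
#       'cluster': {'worker': ['localhost:12345', 'localhost:12346']},
#       'task': {'type': 'worker', 'index': 0}})
#
# get_task_type() returns 'worker', get_task_index() returns 0, and is_chief()
# returns True because the cluster spec has no dedicated 'chief' task and this
# task is worker 0.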