# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A component for running distributed TensorFlow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import json
import os
import threading
import time

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator
from tensorflow.python.training import monitored_session
from tensorflow.python.training import server_lib


_thread_local = threading.local()


class _TaskType(object):
  PS = "ps"
  WORKER = "worker"
  CHIEF = "chief"
  EVALUATOR = "evaluator"
  CLIENT = "client"


# TODO(yuefengz): support another mode where the client colocates with one
# worker.
class CoordinatorMode(object):
  """Specify how distribute coordinator runs."""
  # The default mode where the distribute coordinator runs as a standalone
  # client and connects to remote servers for training. Each remote server can
  # run the distribute coordinator binary with `task_type` set correctly, which
  # then turns it into a standard server.
  STANDALONE_CLIENT = "standalone_client"

  # The distribute coordinator runs on each worker. It will run a standard
  # server on each worker and optionally run the `worker_fn` that is configured
  # to talk to its standard server.
  INDEPENDENT_WORKER = "independent_worker"

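# Example (illustrative sketch, not executed): in STANDALONE_CLIENT mode a
# single client process might call
#
#   run_distribute_coordinator(worker_fn, strategy,
#                              mode=CoordinatorMode.STANDALONE_CLIENT,
#                              cluster_spec=cluster_spec)
#
# while in INDEPENDENT_WORKER mode every task runs the same call with
# `mode=CoordinatorMode.INDEPENDENT_WORKER` plus its own `task_type` and
# `task_id` (typically parsed from TF_CONFIG). Here `worker_fn`, `strategy`
# and `cluster_spec` are placeholders for user-provided values.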

class _Barrier(object):
  """A reusable barrier class for worker synchronization."""

  def __init__(self, num_participants):
    """Initializes the barrier object.

    Args:
      num_participants: an integer which is the expected number of calls to
        `wait` that pass through this barrier.
    """
    self._num_participants = num_participants
    self._counter = 0
    self._flag = False
    self._local_sense = threading.local()
    self._lock = threading.Lock()
    self._condition = threading.Condition()

  def wait(self):
    """Waits until all other callers reach the same wait call."""
    self._local_sense.value = not self._flag
    with self._lock:
      self._counter += 1
      if self._counter == self._num_participants:
        self._counter = 0
        self._flag = self._local_sense.value
    with self._condition:
      while self._flag != self._local_sense.value:
        self._condition.wait()
      self._condition.notify_all()


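# Example (illustrative sketch, not executed): the barrier above is reusable
# across synchronization points, e.g. with two participating threads:
#
#   barrier = _Barrier(2)
#
#   def _worker():
#     ...              # per-worker work
#     barrier.wait()   # blocks until both workers reach this point
#     ...              # continues once everyone has arrived
#
#   for _ in range(2):
#     threading.Thread(target=_worker).start()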
def _get_num_workers(cluster_spec):
  """Gets number of workers including chief."""
  if not cluster_spec:
    return 0
  return len(cluster_spec.as_dict().get(_TaskType.WORKER, [])) + len(
      cluster_spec.as_dict().get(_TaskType.CHIEF, []))

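# For instance, a cluster_spec built from
# {"chief": ["host0:2222"], "worker": ["host1:2222", "host2:2222"]} makes
# _get_num_workers return 3: two "worker" tasks plus the "chief".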

class _WorkerContext(object):
  """The worker context class.

  This context object provides configuration information for each task. One
  context manager with a worker context object will be created per invocation
  of the `worker_fn`, in which `get_current_worker_context` can be called to
  access the worker context object.
  """

  def __init__(self,
               strategy,
               cluster_spec,
               task_type,
               task_id,
               session_config=None,
               rpc_layer="grpc",
               worker_barrier=None):
    """Initialize the worker context object.

    Args:
      strategy: a `DistributionStrategy` object.
      cluster_spec: a ClusterSpec object. It can be empty or None in the local
        training case.
      task_type: a string indicating the role of the corresponding task, such as
        "worker" or "ps". It can be None if it is local training or in-graph
        replicated training.
      task_id: an integer indicating id of the corresponding task. It can be
        None if it is local training or in-graph replicated training.
      session_config: an optional `tf.ConfigProto` object.
      rpc_layer: optional string specifying the RPC protocol for communication
        with worker masters. If None or empty, hosts in the `cluster_spec` will
        be used directly.
      worker_barrier: optional, the barrier object for worker synchronization.
    """
    self._strategy = strategy
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id
    self._session_config = session_config
    self._worker_barrier = worker_barrier
    self._rpc_layer = rpc_layer
    self._master_target = self._get_master_target()
    self._num_workers = _get_num_workers(cluster_spec)
    self._is_chief_node = self._is_chief()

  def _debug_message(self):
    if self._cluster_spec:
      return "[cluster_spec: %r, task_type: %r, task_id: %r]" % (
          self._cluster_spec, self.task_type, self.task_id)
    else:
      return "[local]"

  def __enter__(self):
    old_context = distribute_coordinator_context.get_current_worker_context()
    if old_context:
      raise ValueError(
          "You cannot run distribute coordinator in a `worker_fn`.\t" +
          self._debug_message())
    # pylint: disable=protected-access
    distribute_coordinator_context._worker_context.current = self

  def __exit__(self, unused_exception_type, unused_exception_value,
               unused_traceback):
    # pylint: disable=protected-access
    distribute_coordinator_context._worker_context.current = None

  def _get_master_target(self):
    """Return the master target for a task."""
    # If cluster_spec is None or empty, we use local master.
    if not self._cluster_spec:
      return ""

    # If task_type is None, then it is in-graph replicated training. In this
    # case we use the chief or first worker's master target.
    if not self._task_type:
      if _TaskType.CHIEF in self._cluster_spec.jobs:
        task_type = _TaskType.CHIEF
        task_id = 0
      else:
        assert _TaskType.WORKER in self._cluster_spec.jobs
        task_type = _TaskType.WORKER
        task_id = 0
    else:
      task_type = self._task_type
      task_id = self._task_id

    prefix = ""
    if self._rpc_layer:
      prefix = self._rpc_layer + "://"
    return prefix + self._cluster_spec.job_tasks(task_type)[task_id or 0]

  def _is_chief(self):
    """Return whether the task is the chief worker."""
    if (not self._cluster_spec or
        self._task_type in [_TaskType.CHIEF, _TaskType.EVALUATOR, None]):
      return True

    # If not local and chief not in the cluster_spec, use the first worker as
    # chief.
    if (_TaskType.CHIEF not in self._cluster_spec.jobs and
        self._task_type == _TaskType.WORKER and self._task_id == 0):
      return True
    return False

  def wait_for_other_workers(self):
    """Waits for other workers to reach the same call to this method.

    Raises:
      ValueError: if `worker_barrier` is not passed to the __init__ method.
    """
    if not self._worker_barrier:
      # TODO(yuefengz): we should throw an error in independent worker mode.
      return
    self._worker_barrier.wait()

  def session_creator(self,
                      scaffold=None,
                      config=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      max_wait_secs=7200):
    """Returns a session creator.

    The returned session creator will be configured with the correct master
    target and session configs. It will also run either init ops or ready ops
    by querying the `strategy` object when `create_session` is called on it.

    Args:
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified a default one is created. It's used to finalize the graph.
      config: `ConfigProto` proto used to configure the session.
      checkpoint_dir: A string. Optional path to a directory where to restore
        variables.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
        Only one of `checkpoint_dir` or `checkpoint_filename_with_path` can be
        specified.
      max_wait_secs: Maximum time to wait for the session to become available.

    Returns:
      a descendant of SessionCreator.
    """
    if config:
      session_config = copy.deepcopy(config)
      session_config.MergeFrom(self._session_config)
    else:
      session_config = self._session_config

    if not self._strategy or self._strategy.extended.experimental_should_init:
      logging.info("Creating chief session creator with config: %r", config)
      return monitored_session.ChiefSessionCreator(
          scaffold,
          master=self.master_target,
          config=session_config,
          checkpoint_dir=checkpoint_dir,
          checkpoint_filename_with_path=checkpoint_filename_with_path)
    else:
      logging.info("Creating worker session creator with config: %r", config)
      return monitored_session.WorkerSessionCreator(
          scaffold,
          master=self.master_target,
          config=session_config,
          max_wait_secs=max_wait_secs)

  @property
  def session_config(self):
    return copy.deepcopy(self._session_config)

  @property
  def has_barrier(self):
    """Whether the barrier is set or not."""
    return self._worker_barrier is not None

  @property
  def distributed_mode(self):
    """Whether it is distributed training or not."""
    return bool(self._cluster_spec) and self._task_type != _TaskType.EVALUATOR

  @property
  def cluster_spec(self):
    """Returns a copy of the cluster_spec object."""
    return copy.deepcopy(self._cluster_spec)

  @property
  def task_type(self):
    """Returns the role of the corresponding task."""
    return self._task_type

  @property
  def task_id(self):
    """Returns the id or index of the corresponding task."""
    return self._task_id

  @property
  def master_target(self):
    """Returns the session master for the corresponding task to connect to."""
    return self._master_target

  @property
  def is_chief(self):
    """Returns whether the task is a chief node."""
    return self._is_chief_node

  @property
  def num_workers(self):
    """Returns number of workers in the cluster, including chief."""
    return self._num_workers

  @property
  def experimental_should_init(self):
    """Whether to run init ops."""
    return self._strategy.extended.experimental_should_init

  @property
  def should_checkpoint(self):
    """Whether to save checkpoint."""
    return self._strategy.extended.should_checkpoint

  @property
  def should_save_summary(self):
    """Whether to save summaries."""
    return self._strategy.extended.should_save_summary


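# Example (illustrative sketch, not executed): inside a `worker_fn`, the
# context installed by `_WorkerContext.__enter__` is typically consumed like
# this; `build_graph` and `run_training_steps` are hypothetical user functions:
#
#   def worker_fn(strategy):
#     context = distribute_coordinator_context.get_current_worker_context()
#     build_graph()  # define the model before the graph is finalized
#     with monitored_session.MonitoredSession(
#         session_creator=context.session_creator()) as sess:
#       run_training_steps(sess, is_chief=context.is_chief)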
def _run_single_worker(worker_fn,
                       strategy,
                       cluster_spec,
                       task_type,
                       task_id,
                       session_config,
                       rpc_layer="",
                       worker_barrier=None,
                       coord=None):
  """Runs a single worker by calling `worker_fn` under context."""
  session_config = copy.deepcopy(session_config)
  strategy = copy.deepcopy(strategy)
  # If there is an EVALUATOR task, we run single-machine eval on that task.
  if task_type == _TaskType.EVALUATOR:
    # It is possible to not have a strategy object for EVALUATOR task.
    if strategy:
      strategy.configure(session_config)
  else:
    assert strategy
    strategy.configure(session_config, cluster_spec, task_type, task_id)

  context = _WorkerContext(
      strategy,
      cluster_spec,
      task_type,
      task_id,
      session_config=session_config,
      rpc_layer=rpc_layer,
      worker_barrier=worker_barrier)
  with context:
    if coord:
      with coord.stop_on_exception():
        return worker_fn(strategy)
    else:
      return worker_fn(strategy)


def _split_cluster_for_evaluator(cluster_spec, task_type):
  """Split the cluster for evaluator since it needn't talk to other tasks."""
  # Splitting the cluster is important to prevent the evaluator from talking to
  # other tasks in the cluster. Since we allow the evaluator to not use a
  # distribution strategy, ops in the evaluator task may have unspecified
  # devices, and they may end up on other tasks if we don't split the cluster.
  new_cluster_spec = multi_worker_util.normalize_cluster_spec(
      cluster_spec).as_dict()
  if task_type == _TaskType.EVALUATOR:
    assert _TaskType.EVALUATOR in new_cluster_spec
    new_cluster_spec = {
        _TaskType.EVALUATOR: new_cluster_spec[_TaskType.EVALUATOR]
    }
  else:
    new_cluster_spec.pop(_TaskType.EVALUATOR, None)
  return multi_worker_util.normalize_cluster_spec(new_cluster_spec)


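# For example, splitting
#   {"chief": ["host0:2222"], "worker": ["host1:2222"],
#    "evaluator": ["host2:2222"]}
# for task_type "evaluator" keeps only {"evaluator": ["host2:2222"]}, while
# any other task_type gets the same cluster with the "evaluator" job removed.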
def _run_std_server(cluster_spec=None,
                    task_type=None,
                    task_id=None,
                    session_config=None,
                    rpc_layer=None,
                    environment=None):
  """Runs a standard server."""
  # Check if the Server is already running. If so, assert that no configuration
  # options have changed, and return the existing Server. This allows us to
  # call `run_distribute_coordinator` multiple times.
  if getattr(_thread_local, "server", None) is not None:
    assert _thread_local.cluster_spec == cluster_spec
    assert _thread_local.task_type == task_type
    assert _thread_local.task_id == task_id
    assert _thread_local.session_config_str == repr(session_config)
    assert _thread_local.rpc_layer == rpc_layer
    assert _thread_local.environment == environment
    return _thread_local.server
  else:
    # This method is not thread-safe.
    _thread_local.server_started = True
    _thread_local.cluster_spec = cluster_spec
    _thread_local.task_type = task_type
    _thread_local.task_id = task_id
    _thread_local.session_config_str = repr(session_config)
    _thread_local.rpc_layer = rpc_layer
    _thread_local.environment = environment

  assert cluster_spec
  target = cluster_spec.task_address(task_type, task_id)
  if rpc_layer:
    target = rpc_layer + "://" + target

  class _FakeServer(object):
    """A fake server that runs a master session."""

    def start(self):
      # A TensorFlow server starts when a remote session is created.
      logging.info(
          "Creating a remote session to start a TensorFlow server, "
          "target = %r, session_config=%r", target, session_config)
      session.Session(target=target, config=session_config)

    def join(self):
      while True:
        time.sleep(5)

  if environment == "google":
    server = _FakeServer()
  else:
    if session_config:
      logging.info(
          "Starting standard TensorFlow server, target = %r, session_config= "
          "%r", target, session_config)
    else:
      logging.info("Starting standard TensorFlow server, target = %r", target)
    cluster_spec = _split_cluster_for_evaluator(cluster_spec, task_type)
    server = server_lib.Server(
        cluster_spec,
        job_name=task_type,
        task_index=task_id,
        config=session_config,
        protocol=rpc_layer)

  server.start()
  _thread_local.server = server
  return server


def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                              cluster_spec, session_config, rpc_layer):
  """Runs a standalone client for between-graph replication."""
  coord = coordinator.Coordinator()
  eval_thread = None
  if _TaskType.EVALUATOR in cluster_spec.jobs:
    eval_thread = threading.Thread(
        target=_run_single_worker,
        args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0,
              session_config),
        kwargs={
            "rpc_layer": rpc_layer,
            "coord": coord,
        })
    eval_thread.start()

  threads = []
  worker_barrier = _Barrier(_get_num_workers(cluster_spec))
  for task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
    for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
      t = threading.Thread(
          target=_run_single_worker,
          args=(worker_fn, strategy, cluster_spec, task_type, task_id,
                session_config),
          kwargs={
              "rpc_layer": rpc_layer,
              "worker_barrier": worker_barrier,
              "coord": coord,
          })
      t.start()
      threads.append(t)

  if eval_thread:
    # TODO(yuefengz): is it necessary to join eval thread?
    threads_to_join = threads + [eval_thread]
  else:
    threads_to_join = threads
  coord.join(threads_to_join)

  # TODO(yuefengz): we probably want to return results from all workers?
  return None


def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                         cluster_spec, session_config, rpc_layer):
  """Runs a standalone client for in-graph replication."""
  coord = coordinator.Coordinator()
  eval_thread = None
  if _TaskType.EVALUATOR in cluster_spec.jobs:
    eval_thread = threading.Thread(
        target=_run_single_worker,
        args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0,
              session_config),
        kwargs={
            "rpc_layer": rpc_layer,
            "coord": coord,
        })
    eval_thread.start()

  worker_result = _run_single_worker(
      worker_fn,
      strategy,
      cluster_spec,
      None,
      None,
      session_config,
      rpc_layer=rpc_layer,
      coord=coord)

  if eval_thread:
    coord.join([eval_thread])

  return worker_result


def _configure_session_config_for_std_servers(
    strategy, eval_strategy, session_config, cluster_spec, task_type, task_id):
  # pylint: disable=g-doc-args
  """Call strategy's `configure` to mutate the session_config.

  The session_config is currently needed as default config for a TensorFlow
  server. In the future, we should be able to remove this method and only pass
  the session config to a client session.
  """
  if task_type == _TaskType.EVALUATOR:
    if eval_strategy:
      eval_strategy.configure(session_config=session_config)
  else:
    # The strategy may be shared in standalone client mode.
    strategy = copy.deepcopy(strategy)
    strategy.configure(
        session_config=session_config,
        cluster_spec=cluster_spec,
        task_type=task_type,
        task_id=task_id)
  # Remove the device filters specific to the strategy, so that the
  # TensorFlow server brought up with one strategy can be used by other
  # strategies. The device filters can also be set on the client side.
  del session_config.device_filters[:]


def run_standard_tensorflow_server(session_config=None):
  """Starts a standard TensorFlow server.

  This method parses configurations from the "TF_CONFIG" environment variable
  and starts a TensorFlow server. The "TF_CONFIG" is typically a json string
  which must contain information about the cluster and the role of the server
  in the cluster. One example is:

  TF_CONFIG='{
      "cluster": {
          "worker": ["host1:2222", "host2:2222", "host3:2222"],
          "ps": ["host4:2222", "host5:2222"]
      },
      "task": {"type": "worker", "index": 1}
  }'

  This "TF_CONFIG" specifies that there are 3 worker and 2 ps tasks in the
  cluster and that the current role is worker 1.

  Valid task types are "chief", "worker", "ps" and "evaluator" and you can have
  at most one "chief" and at most one "evaluator".

  An optional key that can be specified is "rpc_layer". The default value is
  "grpc".

  Args:
    session_config: an optional `tf.ConfigProto` object. Users can pass in
      the session config object to configure server-local devices.

  Returns:
    a `tf.train.Server` object which has already been started.

  Raises:
    ValueError: if the "TF_CONFIG" environment variable is not complete.
  """
  tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
  if "cluster" not in tf_config:
    raise ValueError("\"cluster\" is not found in TF_CONFIG.")
  cluster_spec = multi_worker_util.normalize_cluster_spec(tf_config["cluster"])
  if "task" not in tf_config:
    raise ValueError("\"task\" is not found in TF_CONFIG.")
  task_env = tf_config["task"]
  if "type" not in task_env:
    raise ValueError(
        "\"task_type\" is not found in the `task` part of TF_CONFIG.")
  task_type = task_env["type"]
  task_id = int(task_env.get("index", 0))

  rpc_layer = tf_config.get("rpc_layer", "grpc")

  session_config = session_config or config_pb2.ConfigProto()
  # Set the collective group leader for collective ops to initialize collective
  # ops when the server starts.
  if "chief" in cluster_spec.jobs:
    session_config.experimental.collective_group_leader = (
        "/job:chief/replica:0/task:0")
  else:
    if "worker" not in cluster_spec.jobs:
      raise ValueError(
          "You must have `chief` or `worker` jobs in the `cluster_spec`.")
    session_config.experimental.collective_group_leader = (
        "/job:worker/replica:0/task:0")

  server = _run_std_server(
      cluster_spec=cluster_spec,
      task_type=task_type,
      task_id=task_id,
      session_config=session_config,
      rpc_layer=rpc_layer)
  server.start()
  return server

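# Example (illustrative sketch, not executed): on each machine in the cluster
# described in the docstring above, one might set TF_CONFIG and block on the
# returned server:
#
#   os.environ["TF_CONFIG"] = json.dumps({
#       "cluster": {"worker": ["host1:2222", "host2:2222", "host3:2222"],
#                   "ps": ["host4:2222", "host5:2222"]},
#       "task": {"type": "worker", "index": 1},
#   })
#   server = run_standard_tensorflow_server()
#   server.join()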

# TODO(yuefengz): propagate cluster_spec in the STANDALONE_CLIENT mode.
# TODO(yuefengz): we may need a smart way to figure out whether the current task
# is the special task when we support cluster_spec propagation.
def run_distribute_coordinator(worker_fn,
                               strategy,
                               eval_fn=None,
                               eval_strategy=None,
                               mode=CoordinatorMode.STANDALONE_CLIENT,
                               cluster_spec=None,
                               task_type=None,
                               task_id=None,
                               session_config=None,
                               rpc_layer="grpc"):
  """Runs the coordinator for distributed TensorFlow.

  This function runs a split coordinator for distributed TensorFlow in its
  default mode, i.e. the STANDALONE_CLIENT mode. Given a `cluster_spec`
  specifying server addresses and their roles in a cluster, this coordinator
  will figure out how to set them up, give the underlying function the right
  targets for master sessions via a scope object and coordinate their training.
  The cluster consisting of standard servers needs to be brought up either with
  the standard server binary or with a binary running distribute coordinator
  with `task_type` set to a non-client type, which will then turn them into
  standard servers.

  In addition to being the distribute coordinator, this is also the source of
  configurations for each job in the distributed training. As there are multiple
  ways to configure a distributed TensorFlow cluster, its context object
  provides these configurations so that users or higher-level APIs don't have to
  figure out the configuration for each job by themselves.

  In between-graph replicated training, this coordinator will create multiple
  threads and each calls the `worker_fn` which is supposed to create its own
  graph and connect to one worker master given by its context object. In
  in-graph replicated training, there is only one thread calling this
  `worker_fn`.

  Another mode is the INDEPENDENT_WORKER mode where each server runs a
  distribute coordinator which will start a standard server and optionally runs
  `worker_fn` depending on whether it is between-graph training or in-graph
  replicated training.

  The `strategy` object is expected to be a DistributionStrategy object which
  has implemented methods needed by the distribute coordinator, such as
  `configure(session_config, cluster_spec, task_type, task_id)` which configures
  the strategy object for a specific task and the `experimental_should_init`
  property which instructs the distribute coordinator whether to run init ops
  for a task. The distribute coordinator will make a copy of the `strategy`
  object, call its `configure` method and pass it to `worker_fn` as an argument.

  The `worker_fn` defines the training logic and is called under its own
  worker context which can be accessed via `get_current_worker_context`. A
  worker context provides access to configurations for each task, e.g. the
  task_type, task_id, master target and so on. Since `worker_fn` will be called
  in a thread and possibly multiple times, the caller should be careful when it
  accesses global data. For example, it is unsafe to define flags in a
  `worker_fn` or to define different environment variables for different
  `worker_fn`s.

  The `worker_fn` for between-graph replication is defined as if there is
  only one worker corresponding to the `worker_fn` and possibly ps jobs. For
  example, when training with parameter servers, it assigns variables to
  parameter servers and all other operations to that worker. In the in-graph
  replication case, the `worker_fn` has to define operations for all worker
  jobs. Using a distribution strategy can simplify the `worker_fn` by not having
  to worry about the replication and device assignment of variables and
  operations.

  This method is intended to be invoked by high-level APIs so that users don't
  have to explicitly call it to run this coordinator. For those who don't use
  high-level APIs, to change a program to use this coordinator, wrap everything
  in the program after global data definitions, such as command-line flag
  definitions, into the `worker_fn` and get task-specific configurations from
  the worker context.

  The `cluster_spec` can be either passed by the argument or parsed from the
  "TF_CONFIG" environment variable. Example of a TF_CONFIG:
  ```
    cluster = {'chief': ['host0:2222'],
               'ps': ['host1:2222', 'host2:2222'],
               'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
    os.environ['TF_CONFIG'] = json.dumps({'cluster': cluster})
  ```

  If `cluster_spec` is not given in any format, it becomes local training and
  this coordinator will connect to a local session.

  For evaluation, if "evaluator" exists in the cluster_spec, a separate thread
  will be created to call `eval_fn` with its `task_type` set to "evaluator". If
  `eval_fn` is not defined, fall back to `worker_fn`. This implies that
  evaluation will be done on a single machine if there is an "evaluator" task.
  If "evaluator" doesn't exist in the cluster_spec, it entirely depends on the
  `worker_fn` for how to do evaluation.

  Args:
    worker_fn: the function to be called. The function should accept a
      `strategy` object and will be given access to a context object via a
      context manager scope.
    strategy: a DistributionStrategy object specifying whether it should
      run between-graph replicated training or not, whether to run init ops,
      etc. This object will also be configured given `session_config`,
      `cluster_spec`, `task_type` and `task_id`.
    eval_fn: optional function for the "evaluator" task. If `eval_fn` is not
      passed in but an "evaluator" task is found in the `cluster_spec`, the
      `worker_fn` will be used for this task.
    eval_strategy: optional DistributionStrategy object for the "evaluator"
      task.
    mode: in which mode this distribute coordinator runs.
    cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles
      in a cluster. If not set or empty, fall back to local training.
    task_type: the current task type, optional if this is a client.
    task_id: the current task id, optional if this is a client.
    session_config: an optional `tf.ConfigProto` object which will be passed
      to `strategy`'s `configure` method and used to create a session.
    rpc_layer: optional string, the protocol for RPC, e.g. "grpc".

  Raises:
    ValueError: if `cluster_spec` is supplied but not a dict or a ClusterDef or
      a ClusterSpec.

  Returns:
    In the client job, return the value returned by `worker_fn` if
    it is in-graph replication or INDEPENDENT_WORKER mode; return None
    otherwise.
  """
  tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
  if not cluster_spec:
    cluster_spec = tf_config.get("cluster", {})
    task_env = tf_config.get("task", {})
    if task_env:
      task_type = task_env.get("type", task_type)
      task_id = int(task_env.get("index", task_id))

  if cluster_spec:
    cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
    # TODO(yuefengz): validate cluster_spec.

  rpc_layer = tf_config.get("rpc_layer", rpc_layer)
  environment = tf_config.get("environment", None)

  # Setting the session config is necessary for some strategies such as
  # CollectiveAllReduceStrategy.
  session_config = session_config or config_pb2.ConfigProto(
      allow_soft_placement=True)

  if cluster_spec:
    logging.info(
        "Running Distribute Coordinator with mode = %r, cluster_spec = %r, "
        "task_type = %r, task_id = %r, environment = %r, rpc_layer = %r", mode,
        cluster_spec.as_dict(), task_type, task_id, environment, rpc_layer)

  if not cluster_spec:
    # `mode` is ignored in the local case.
    logging.info("Running local Distribute Coordinator.")
    _run_single_worker(worker_fn, strategy, None, None, None, session_config,
                       rpc_layer)
    if eval_fn:
      _run_single_worker(eval_fn, eval_strategy, None, None, None,
                         session_config, rpc_layer)
    else:
      logging.warning("Skipped evaluation since `eval_fn` is not passed in.")
  elif mode == CoordinatorMode.STANDALONE_CLIENT:
    if not eval_fn:
      logging.warning("`eval_fn` is not passed in. The `worker_fn` will be "
                      "used if an \"evaluator\" task exists in the cluster.")
    eval_fn = eval_fn or worker_fn
    if not eval_strategy:
      logging.warning("`eval_strategy` is not passed in. No distribution "
                      "strategy will be used for evaluation.")

    # The client must know the cluster but servers in the cluster don't have to
    # know the client.
    if task_type in [_TaskType.CLIENT, None]:
      if strategy.extended.experimental_between_graph:
        return _run_between_graph_client(worker_fn, strategy, eval_fn,
                                         eval_strategy, cluster_spec,
                                         session_config, rpc_layer)
      else:
        return _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                                    cluster_spec, session_config, rpc_layer)
    else:
      # If not a client job, run the standard server.
      _configure_session_config_for_std_servers(strategy, eval_strategy,
                                                session_config, cluster_spec,
                                                task_type, task_id)
      server = _run_std_server(
          cluster_spec=cluster_spec,
          task_type=task_type,
          task_id=task_id,
          session_config=session_config,
          rpc_layer=rpc_layer,
          environment=environment)
      server.join()
  else:
    if mode != CoordinatorMode.INDEPENDENT_WORKER:
      raise ValueError("Unexpected coordinator mode: %r" % mode)

    if not eval_fn:
      logging.warning("`eval_fn` is not passed in. The `worker_fn` will be "
                      "used if an \"evaluator\" task exists in the cluster.")
    eval_fn = eval_fn or worker_fn
    if not eval_strategy:
      logging.warning("`eval_strategy` is not passed in. No distribution "
                      "strategy will be used for evaluation.")

    # Everyone starts a standard server, getting the session config from the
    # `configure` method.
    _configure_session_config_for_std_servers(strategy, eval_strategy,
                                              session_config, cluster_spec,
                                              task_type, task_id)

    if not getattr(strategy.extended, "_std_server_started", False):
      # Right now, with eager mode, context is configured with a std server at
      # the very beginning while with graph mode the std server is started when
      # distribute coordinator is called. We should consolidate these two paths.
      server = _run_std_server(
          cluster_spec=cluster_spec,
          task_type=task_type,
          task_id=task_id,
          session_config=session_config,
          rpc_layer=rpc_layer,
          environment=environment)
    if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
      if strategy.extended.experimental_between_graph:
        # All jobs run `worker_fn` if between-graph.
        return _run_single_worker(worker_fn, strategy, cluster_spec, task_type,
                                  task_id, session_config, rpc_layer)
      else:
        # Only one node runs `worker_fn` if in-graph.
        context = _WorkerContext(strategy, cluster_spec, task_type, task_id)
        if context.is_chief:
          return _run_single_worker(worker_fn, strategy, cluster_spec, None,
                                    None, session_config, rpc_layer)
        else:
          server.join()
    elif task_type == _TaskType.EVALUATOR:
      return _run_single_worker(eval_fn, eval_strategy, cluster_spec, task_type,
                                task_id, session_config, rpc_layer)
    else:
      if task_type != _TaskType.PS:
        raise ValueError("Unexpected task_type: %r" % task_type)
      server.join()
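

# Example (illustrative sketch, not executed): a minimal program driven by this
# coordinator in INDEPENDENT_WORKER mode, where `build_and_train` and
# `some_strategy` are hypothetical user-provided pieces:
#
#   def worker_fn(strategy):
#     context = distribute_coordinator_context.get_current_worker_context()
#     with strategy.scope():
#       build_and_train(master=context.master_target,
#                       config=context.session_config,
#                       is_chief=context.is_chief)
#
#   run_distribute_coordinator(
#       worker_fn,
#       strategy=some_strategy,
#       mode=CoordinatorMode.INDEPENDENT_WORKER,
#       rpc_layer="grpc")
#
# Each task runs this same program and picks up its role from the TF_CONFIG
# environment variable.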