# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""A component for running distributed TensorFlow."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import copy
22import json
23import os
24import threading
25import time
26
27from tensorflow.core.protobuf import config_pb2
28from tensorflow.python.client import session
29from tensorflow.python.distribute import distribute_coordinator_context
30from tensorflow.python.distribute import multi_worker_util
31from tensorflow.python.platform import tf_logging as logging
32from tensorflow.python.training import coordinator
33from tensorflow.python.training import monitored_session
34from tensorflow.python.training import server_lib
35
36
37_thread_local = threading.local()
38
39
class _TaskType(object):
  PS = "ps"
  WORKER = "worker"
  CHIEF = "chief"
  EVALUATOR = "evaluator"
  CLIENT = "client"


# TODO(yuefengz): support another mode where the client colocates with one
# worker.
class CoordinatorMode(object):
  """Specifies how the distribute coordinator runs."""
  # The default mode, where the distribute coordinator runs as a standalone
  # client and connects to remote servers for training. Each remote server can
  # run the distribute coordinator binary with `task_type` set correctly,
  # which turns it into a standard server.
  STANDALONE_CLIENT = "standalone_client"

  # The distribute coordinator runs on each worker. It starts a standard
  # server on each worker and optionally runs the `worker_fn` that is
  # configured to talk to its standard server.
  INDEPENDENT_WORKER = "independent_worker"


class _Barrier(object):
  """A reusable barrier class for worker synchronization."""

  def __init__(self, num_participants):
    """Initializes the barrier object.

    Args:
      num_participants: an integer, the expected number of calls to `wait`
        that pass through this barrier.
    """
    self._num_participants = num_participants
    self._counter = 0
    self._flag = False
    self._local_sense = threading.local()
    self._lock = threading.Lock()
    self._condition = threading.Condition()

  def wait(self):
    """Waits until all other callers reach the same wait call."""
    self._local_sense.value = not self._flag
    with self._lock:
      self._counter += 1
      if self._counter == self._num_participants:
        self._counter = 0
        self._flag = self._local_sense.value
    with self._condition:
      while self._flag != self._local_sense.value:
        self._condition.wait()
      self._condition.notify_all()


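# Illustrative sketch only (not part of the library): three threads
# synchronizing on the reusable `_Barrier` above. The helper name
# `_example_barrier_usage` is hypothetical and is never called by this module.
def _example_barrier_usage():
  barrier = _Barrier(num_participants=3)
  arrivals = []

  def _work(idx):
    arrivals.append(idx)
    barrier.wait()  # Blocks until all three participants have arrived.
    # After the barrier, every participant observes all three arrivals.
    assert len(arrivals) == 3

  threads = [threading.Thread(target=_work, args=(i,)) for i in range(3)]
  for t in threads:
    t.start()
  for t in threads:
    t.join()

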
def _get_num_workers(cluster_spec):
  """Gets number of workers including chief."""
  if not cluster_spec:
    return 0
  return len(cluster_spec.as_dict().get(_TaskType.WORKER, [])) + len(
      cluster_spec.as_dict().get(_TaskType.CHIEF, []))


class _WorkerContext(object):
  """The worker context class.

  This context object provides configuration information for each task. One
  context manager with a worker context object is created per invocation of
  the `worker_fn`, inside which `get_current_worker_context` can be called to
  access the worker context object.
  """

  def __init__(self,
               strategy,
               cluster_spec,
               task_type,
               task_id,
               session_config=None,
               rpc_layer="grpc",
               worker_barrier=None):
    """Initializes the worker context object.

    Args:
      strategy: a `DistributionStrategy` object.
      cluster_spec: a ClusterSpec object. It can be empty or None in the local
        training case.
      task_type: a string indicating the role of the corresponding task, such
        as "worker" or "ps". It can be None for local training or in-graph
        replicated training.
      task_id: an integer indicating the id of the corresponding task. It can
        be None for local training or in-graph replicated training.
      session_config: an optional `tf.compat.v1.ConfigProto` object.
      rpc_layer: optional string specifying the RPC protocol for communication
        with worker masters. If None or empty, hosts in the `cluster_spec` will
        be used directly.
      worker_barrier: optional, the barrier object for worker synchronization.
    """
    self._strategy = strategy
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id
    self._session_config = session_config
    self._worker_barrier = worker_barrier
    self._rpc_layer = rpc_layer
    self._master_target = self._get_master_target()
    self._num_workers = _get_num_workers(cluster_spec)
    self._is_chief_node = self._is_chief()

  def _debug_message(self):
    if self._cluster_spec:
      return "[cluster_spec: %r, task_type: %r, task_id: %r]" % (
          self._cluster_spec, self.task_type, self.task_id)
    else:
      return "[local]"

  def __enter__(self):
    old_context = distribute_coordinator_context.get_current_worker_context()
    if old_context:
      raise ValueError(
          "You cannot run distribute coordinator in a `worker_fn`.\t" +
          self._debug_message())
    # pylint: disable=protected-access
    distribute_coordinator_context._worker_context.current = self

  def __exit__(self, unused_exception_type, unused_exception_value,
               unused_traceback):
    # pylint: disable=protected-access
    distribute_coordinator_context._worker_context.current = None

  def _get_master_target(self):
    """Return the master target for a task."""
    # If cluster_spec is None or empty, we use local master.
    if not self._cluster_spec or self._task_type == _TaskType.EVALUATOR:
      return ""

    # If task_type is None, then it is in-graph replicated training. In this
    # case we use the chief or first worker's master target.
    if not self._task_type:
      if _TaskType.CHIEF in self._cluster_spec.jobs:
        task_type = _TaskType.CHIEF
        task_id = 0
      else:
        assert _TaskType.WORKER in self._cluster_spec.jobs
        task_type = _TaskType.WORKER
        task_id = 0
    else:
      task_type = self._task_type
      task_id = self._task_id

    prefix = ""
    if self._rpc_layer:
      prefix = self._rpc_layer + "://"
    return prefix + self._cluster_spec.job_tasks(task_type)[task_id or 0]

  def _is_chief(self):
    """Return whether the task is the chief worker."""
    if (not self._cluster_spec or
        self._task_type in [_TaskType.CHIEF, _TaskType.EVALUATOR, None]):
      return True

    # If not local and chief not in the cluster_spec, use the first worker as
    # chief.
    if (_TaskType.CHIEF not in self._cluster_spec.jobs and
        self._task_type == _TaskType.WORKER and self._task_id == 0):
      return True
    return False

  def wait_for_other_workers(self):
    """Waits for other workers to reach the same call to this method.

    Raises:
      ValueError: if `worker_barrier` is not passed to the __init__ method.
    """
    if not self._worker_barrier:
      # TODO(yuefengz): we should throw an error in independent worker mode.
      return
    self._worker_barrier.wait()

  def session_creator(self,
                      scaffold=None,
                      config=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      max_wait_secs=7200):
    """Returns a session creator.

    The returned session creator will be configured with the correct master
    target and session configs. It will also run either init ops or ready ops
    by querying the `strategy` object when `create_session` is called on it.

    Args:
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified, a default one is created. It's used to finalize the
        graph.
      config: `ConfigProto` proto used to configure the session.
      checkpoint_dir: A string. Optional path to a directory from which to
        restore variables.
      checkpoint_filename_with_path: Full file name path to the checkpoint file.
        Only one of `checkpoint_dir` or `checkpoint_filename_with_path` can be
        specified.
      max_wait_secs: Maximum time to wait for the session to become available.

    Returns:
      a descendant of SessionCreator.
    """
    if config:
      session_config = copy.deepcopy(config)
      session_config.MergeFrom(self._session_config)
    else:
      session_config = self._session_config

    if not self._strategy or self._strategy.extended.experimental_should_init:
      logging.info("Creating chief session creator with config: %r", config)
      return monitored_session.ChiefSessionCreator(
          scaffold,
          master=self.master_target,
          config=session_config,
          checkpoint_dir=checkpoint_dir,
          checkpoint_filename_with_path=checkpoint_filename_with_path)
    else:
      logging.info("Creating worker session creator with config: %r", config)
      return monitored_session.WorkerSessionCreator(
          scaffold,
          master=self.master_target,
          config=session_config,
          max_wait_secs=max_wait_secs)

  @property
  def session_config(self):
    return copy.deepcopy(self._session_config)

  @property
  def has_barrier(self):
    """Whether the barrier is set or not."""
    return self._worker_barrier is not None

  @property
  def distributed_mode(self):
    """Whether it is distributed training or not."""
    return bool(self._cluster_spec) and self._task_type != _TaskType.EVALUATOR

  @property
  def cluster_spec(self):
    """Returns a copy of the cluster_spec object."""
    return copy.deepcopy(self._cluster_spec)

  @property
  def task_type(self):
    """Returns the role of the corresponding task."""
    return self._task_type

  @property
  def task_id(self):
    """Returns the id or index of the corresponding task."""
    return self._task_id

  @property
  def master_target(self):
    """Returns the session master for the corresponding task to connect to."""
    return self._master_target

  @property
  def is_chief(self):
    """Returns whether the task is a chief node."""
    return self._is_chief_node

  @property
  def num_workers(self):
    """Returns number of workers in the cluster, including chief."""
    return self._num_workers

  @property
  def experimental_should_init(self):
    """Whether to run init ops."""
    return self._strategy.extended.experimental_should_init

  @property
  def should_checkpoint(self):
    """Whether to save checkpoint."""
    return self._strategy.extended.should_checkpoint

  @property
  def should_save_summary(self):
    """Whether to save summaries."""
    return self._strategy.extended.should_save_summary


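# Illustrative sketch only (not part of the library): how a `worker_fn` might
# read its `_WorkerContext` and build a `MonitoredSession` against the correct
# master. The helper name `_example_worker_fn` and the empty training loop are
# hypothetical placeholders; this function is never called by this module.
def _example_worker_fn(strategy):
  del strategy  # The worker context below already has the configured strategy.
  context = distribute_coordinator_context.get_current_worker_context()
  logging.info("Task %r/%r uses master %r with %d workers.", context.task_type,
               context.task_id, context.master_target, context.num_workers)
  with monitored_session.MonitoredSession(
      session_creator=context.session_creator()):
    pass  # A real `worker_fn` would run its training ops here.
  if context.has_barrier:
    context.wait_for_other_workers()

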
def _run_single_worker(worker_fn,
                       strategy,
                       cluster_spec,
                       task_type,
                       task_id,
                       session_config,
                       rpc_layer="",
                       worker_barrier=None,
                       coord=None):
  """Runs a single worker by calling `worker_fn` under context."""
  session_config = copy.deepcopy(session_config)
  strategy = copy.deepcopy(strategy)
  # If there is an EVALUATOR task, we run single-machine eval on that task.
  if task_type == _TaskType.EVALUATOR:
    # It is possible to not have a strategy object for EVALUATOR task.
    if strategy:
      strategy.configure(session_config)
  else:
    assert strategy
    strategy.configure(session_config, cluster_spec, task_type, task_id)

  context = _WorkerContext(
      strategy,
      cluster_spec,
      task_type,
      task_id,
      session_config=session_config,
      rpc_layer=rpc_layer,
      worker_barrier=worker_barrier)
  with context:
    if coord:
      with coord.stop_on_exception():
        return worker_fn(strategy)
    else:
      return worker_fn(strategy)


def _split_cluster_for_evaluator(cluster_spec, task_type):
  """Split the cluster for evaluator since it needn't talk to other tasks."""
  # Splitting the cluster is important to prevent the evaluator from talking to
  # other tasks in the cluster. Because we allow the evaluator not to use a
  # distribution strategy, ops in the evaluator task may have unspecified
  # devices, and those ops may end up on other tasks if we don't split the
  # cluster.
  # Note: if you bypass distribute coordinator and bring up the cluster
  # yourself, you can equivalently set device filters to split clusters. This
  # is already done by distribution strategy's `update_config_proto` method.
  new_cluster_spec = multi_worker_util.normalize_cluster_spec(
      cluster_spec).as_dict()
  if task_type == _TaskType.EVALUATOR:
    assert _TaskType.EVALUATOR in new_cluster_spec
    new_cluster_spec = {
        _TaskType.EVALUATOR: new_cluster_spec[_TaskType.EVALUATOR]
    }
  else:
    new_cluster_spec.pop(_TaskType.EVALUATOR, None)
  return multi_worker_util.normalize_cluster_spec(new_cluster_spec)


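# Illustrative sketch only (not part of the library): for a hypothetical
# three-job cluster, the evaluator task keeps only its own job while any other
# task drops the "evaluator" job from its view of the cluster. The helper name
# `_example_split_cluster` is hypothetical and never called by this module.
def _example_split_cluster():
  cluster = {
      "chief": ["host0:2222"],
      "worker": ["host1:2222", "host2:2222"],
      "evaluator": ["host3:2222"],
  }
  # Contains only the "evaluator" job.
  evaluator_view = _split_cluster_for_evaluator(cluster, _TaskType.EVALUATOR)
  # Contains the "chief" and "worker" jobs but no "evaluator" job.
  worker_view = _split_cluster_for_evaluator(cluster, _TaskType.WORKER)
  return evaluator_view, worker_view

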
def _run_std_server(cluster_spec=None,
                    task_type=None,
                    task_id=None,
                    session_config=None,
                    rpc_layer=None,
                    environment=None):
  """Runs a standard server."""
  # Check if the Server is already running. If so, assert that no configuration
  # options have changed, and return the existing Server. This allows us to
  # call `run_distribute_coordinator` multiple times.
  if getattr(_thread_local, "server", None) is not None:
    assert _thread_local.cluster_spec == cluster_spec
    assert _thread_local.task_type == task_type
    assert _thread_local.task_id == task_id
    assert _thread_local.session_config_str == repr(session_config)
    assert _thread_local.rpc_layer == rpc_layer
    assert _thread_local.environment == environment
    return _thread_local.server
  else:
    # This method is not thread-safe.
    _thread_local.server_started = True
    _thread_local.cluster_spec = cluster_spec
    _thread_local.task_type = task_type
    _thread_local.task_id = task_id
    _thread_local.session_config_str = repr(session_config)
    _thread_local.rpc_layer = rpc_layer
    _thread_local.environment = environment

  assert cluster_spec
  target = cluster_spec.task_address(task_type, task_id)
  if rpc_layer:
    target = rpc_layer + "://" + target

  class _FakeServer(object):
    """A fake server that runs a master session."""

    def start(self):
      # A TensorFlow server starts when a remote session is created.
      logging.info(
          "Creating a remote session to start a TensorFlow server, "
          "target = %r, session_config=%r", target, session_config)
      session.Session(target=target, config=session_config)

    def join(self):
      while True:
        time.sleep(5)

  if environment == "google":
    server = _FakeServer()
  else:
    if session_config:
      logging.info(
          "Starting standard TensorFlow server, target = %r, session_config= "
          "%r", target, session_config)
    else:
      logging.info("Starting standard TensorFlow server, target = %r", target)
    cluster_spec = _split_cluster_for_evaluator(cluster_spec, task_type)
    server = server_lib.Server(
        cluster_spec,
        job_name=task_type,
        task_index=task_id,
        config=session_config,
        protocol=rpc_layer)

  server.start()
  _thread_local.server = server
  return server


def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                              cluster_spec, session_config, rpc_layer):
  """Runs a standalone client for between-graph replication."""
  coord = coordinator.Coordinator()
  eval_thread = None
  if _TaskType.EVALUATOR in cluster_spec.jobs:
    eval_thread = threading.Thread(
        target=_run_single_worker,
        args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0,
              session_config),
        kwargs={
            "rpc_layer": rpc_layer,
            "coord": coord,
        })
    eval_thread.start()

  threads = []
  worker_barrier = _Barrier(_get_num_workers(cluster_spec))
  for task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
    for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
      t = threading.Thread(
          target=_run_single_worker,
          args=(worker_fn, strategy, cluster_spec, task_type, task_id,
                session_config),
          kwargs={
              "rpc_layer": rpc_layer,
              "worker_barrier": worker_barrier,
              "coord": coord,
          })
      t.start()
      threads.append(t)

  if eval_thread:
    # TODO(yuefengz): is it necessary to join eval thread?
    threads_to_join = threads + [eval_thread]
  else:
    threads_to_join = threads
  coord.join(threads_to_join)

  # TODO(yuefengz): we probably want to return results from all workers?
  return None


def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                         cluster_spec, session_config, rpc_layer):
  """Runs a standalone client for in-graph replication."""
  coord = coordinator.Coordinator()
  eval_thread = None
  if _TaskType.EVALUATOR in cluster_spec.jobs:
    eval_thread = threading.Thread(
        target=_run_single_worker,
        args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0,
              session_config),
        kwargs={
            "rpc_layer": rpc_layer,
            "coord": coord,
        })
    eval_thread.start()

  worker_result = _run_single_worker(
      worker_fn,
      strategy,
      cluster_spec,
      None,
      None,
      session_config,
      rpc_layer=rpc_layer,
      coord=coord)

  if eval_thread:
    coord.join([eval_thread])

  return worker_result


def _configure_session_config_for_std_servers(
    strategy, eval_strategy, session_config, cluster_spec, task_type, task_id):
  # pylint: disable=g-doc-args
  """Calls the strategy's `configure` method to mutate the session_config.

  The session_config is currently needed as the default config for a
  TensorFlow server. In the future, we should be able to remove this method
  and only pass the session config to a client session.
  """
  if task_type == _TaskType.EVALUATOR:
    if eval_strategy:
      eval_strategy.configure(session_config=session_config)
  else:
    # The strategy may be shared in standalone client mode.
    strategy = copy.deepcopy(strategy)
    strategy.configure(
        session_config=session_config,
        cluster_spec=cluster_spec,
        task_type=task_type,
        task_id=task_id)
  # Remove the device filters specific to the strategy, so that the
  # TensorFlow server brought up with one strategy can be used by other
  # strategies. The device filters can be set on the client side as well.
  del session_config.device_filters[:]


def run_standard_tensorflow_server(session_config=None):
  """Starts a standard TensorFlow server.

  This method parses configurations from the "TF_CONFIG" environment variable
  and starts a TensorFlow server. "TF_CONFIG" is typically a JSON string and
  must contain information about the cluster and the role of the server in the
  cluster. One example is:

  TF_CONFIG='{
      "cluster": {
          "worker": ["host1:2222", "host2:2222", "host3:2222"],
          "ps": ["host4:2222", "host5:2222"]
      },
      "task": {"type": "worker", "index": 1}
  }'

  This "TF_CONFIG" specifies that there are 3 workers and 2 ps tasks in the
  cluster and that the current role is worker 1.

  Valid task types are "chief", "worker", "ps" and "evaluator", and you can
  have at most one "chief" and at most one "evaluator".

  An optional key that can be specified is "rpc_layer". The default value is
  "grpc".

  Args:
    session_config: an optional `tf.compat.v1.ConfigProto` object. Users can
      pass in the session config object to configure server-local devices.

  Returns:
    a `tf.distribute.Server` object which has already been started.

  Raises:
    ValueError: if the "TF_CONFIG" environment variable is not complete.
  """
  tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
  if "cluster" not in tf_config:
    raise ValueError("\"cluster\" is not found in TF_CONFIG.")
  cluster_spec = multi_worker_util.normalize_cluster_spec(tf_config["cluster"])
  if "task" not in tf_config:
    raise ValueError("\"task\" is not found in TF_CONFIG.")
  task_env = tf_config["task"]
  if "type" not in task_env:
    raise ValueError(
        "\"task_type\" is not found in the `task` part of TF_CONFIG.")
  task_type = task_env["type"]
  task_id = int(task_env.get("index", 0))

  rpc_layer = tf_config.get("rpc_layer", "grpc")

  session_config = session_config or config_pb2.ConfigProto()
  # Set the collective group leader for collective ops to initialize collective
  # ops when server starts.
  if "chief" in cluster_spec.jobs:
    session_config.experimental.collective_group_leader = (
        "/job:chief/replica:0/task:0")
  else:
    if "worker" not in cluster_spec.jobs:
      raise ValueError(
          "You must have `chief` or `worker` jobs in the `cluster_spec`.")
    session_config.experimental.collective_group_leader = (
        "/job:worker/replica:0/task:0")

  server = _run_std_server(
      cluster_spec=cluster_spec,
      task_type=task_type,
      task_id=task_id,
      session_config=session_config,
      rpc_layer=rpc_layer)
  server.start()
  return server


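# Illustrative sketch only (not part of the library): starting one task of a
# hypothetical two-worker cluster. The host names are placeholders and the
# helper name `_example_standard_server` is never called by this module.
def _example_standard_server():
  os.environ["TF_CONFIG"] = json.dumps({
      "cluster": {"worker": ["host1:2222", "host2:2222"]},
      "task": {"type": "worker", "index": 0},
  })
  server = run_standard_tensorflow_server()
  server.join()  # Blocks while the server keeps serving the cluster.

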
# TODO(yuefengz): propagate cluster_spec in the STANDALONE_CLIENT mode.
# TODO(yuefengz): we may need a smart way to figure out whether the current task
# is the special task when we support cluster_spec propagation.
def run_distribute_coordinator(worker_fn,
                               strategy,
                               eval_fn=None,
                               eval_strategy=None,
                               mode=CoordinatorMode.STANDALONE_CLIENT,
                               cluster_spec=None,
                               task_type=None,
                               task_id=None,
                               session_config=None,
                               rpc_layer="grpc"):
  """Runs the coordinator for distributed TensorFlow.

  This function runs a split coordinator for distributed TensorFlow in its
  default mode, i.e. the STANDALONE_CLIENT mode. Given a `cluster_spec`
  specifying server addresses and their roles in a cluster, this coordinator
  will figure out how to set them up, give the underlying function the right
  targets for master sessions via a scope object and coordinate their training.
  The cluster consisting of standard servers needs to be brought up either with
  the standard server binary or with a binary running distribute coordinator
  with `task_type` set to a non-client type, which will then turn into a
  standard server.

  In addition to being the distribute coordinator, this is also the source of
  configurations for each job in the distributed training. As there are
  multiple ways to configure a distributed TensorFlow cluster, its context
  object provides these configurations so that users or higher-level APIs don't
  have to figure out the configuration for each job by themselves.

  In between-graph replicated training, this coordinator will create multiple
  threads and each calls the `worker_fn`, which is supposed to create its own
  graph and connect to one worker master given by its context object. In
  in-graph replicated training, it has only one thread calling this
  `worker_fn`.

  Another mode is the INDEPENDENT_WORKER mode, where each server runs a
  distribute coordinator which will start a standard server and optionally run
  `worker_fn`, depending on whether it is between-graph training or in-graph
  replicated training.

  The `strategy` object is expected to be a DistributionStrategy object which
  has implemented the methods needed by the distribute coordinator, such as
  `configure(session_config, cluster_spec, task_type, task_id)`, which
  configures the strategy object for a specific task, and the
  `experimental_should_init` property, which instructs the distribute
  coordinator whether to run init ops for a task. The distribute coordinator
  will make a copy of the `strategy` object, call its `configure` method and
  pass it to `worker_fn` as an argument.

  The `worker_fn` defines the training logic and is called under its own
  worker context which can be accessed via `get_current_worker_context`. A
  worker context provides access to configurations for each task, e.g. the
  task_type, task_id, master target and so on. Since `worker_fn` will be called
  in a thread and possibly multiple times, the caller should be careful when it
  accesses global data. For example, it is unsafe to define flags in a
  `worker_fn` or to define different environment variables for different
  `worker_fn`s.

  The `worker_fn` for between-graph replication is defined as if there is only
  one worker corresponding to the `worker_fn` and possibly ps jobs. For
  example, when training with parameter servers, it assigns variables to
  parameter servers and all other operations to that worker. In the in-graph
  replication case, the `worker_fn` has to define operations for all worker
  jobs. Using a distribution strategy can simplify the `worker_fn` by not
  having to worry about the replication and device assignment of variables and
  operations.

  This method is intended to be invoked by high-level APIs so that users don't
  have to explicitly call it to run this coordinator. For those who don't use
  high-level APIs, to change a program to use this coordinator, wrap everything
  in the program after global data definitions, such as command-line flag
  definitions, into the `worker_fn` and get task-specific configurations from
  the worker context.

  The `cluster_spec` can be either passed by the argument or parsed from the
  "TF_CONFIG" environment variable. Example of a TF_CONFIG:
  ```
    cluster = {'chief': ['host0:2222'],
               'ps': ['host1:2222', 'host2:2222'],
               'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
    os.environ['TF_CONFIG'] = json.dumps({'cluster': cluster})
  ```

  If `cluster_spec` is not given in any format, it becomes local training and
  this coordinator will connect to a local session.

  For evaluation, if "evaluator" exists in the cluster_spec, a separate thread
  will be created to call `eval_fn` with its `task_type` set to "evaluator".
  If `eval_fn` is not defined, it falls back to `worker_fn`. This implies that
  evaluation will be done on a single machine if there is an "evaluator" task.
  If "evaluator" doesn't exist in the cluster_spec, it entirely depends on the
  `worker_fn` for how to do evaluation.

  Args:
    worker_fn: the function to be called. The function should accept a
      `strategy` object and will be given access to a context object via a
      context manager scope.
    strategy: a DistributionStrategy object specifying whether it should
      run between-graph replicated training or not, whether to run init ops,
      etc. This object will also be configured given `session_config`,
      `cluster_spec`, `task_type` and `task_id`.
    eval_fn: optional function for the "evaluator" task. If `eval_fn` is not
      passed in but an "evaluator" task is found in the `cluster_spec`, the
      `worker_fn` will be used for this task.
    eval_strategy: optional DistributionStrategy object for the "evaluator"
      task.
    mode: the mode in which this distribute coordinator runs.
    cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles
      in a cluster. If not set or empty, fall back to local training.
    task_type: the current task type, optional if this is a client.
    task_id: the current task id, optional if this is a client.
    session_config: an optional `tf.compat.v1.ConfigProto` object which will be
      passed to `strategy`'s `configure` method and used to create a session.
    rpc_layer: optional string, the protocol for RPC, e.g. "grpc".

  Raises:
    ValueError: if `cluster_spec` is supplied but not a dict or a ClusterDef or
      a ClusterSpec.

  Returns:
    In the client job, return the value returned by `worker_fn` if
    it is in-graph replication or INDEPENDENT_WORKER mode; return None
    otherwise.
  """
  tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
  rpc_layer = tf_config.get("rpc_layer", rpc_layer)
  environment = tf_config.get("environment", None)

  if not cluster_spec:
    cluster_spec = tf_config.get("cluster", {})
    task_env = tf_config.get("task", {})
    if task_env:
      task_type = task_env.get("type", task_type)
      task_id = int(task_env.get("index", task_id))

  if cluster_spec:
    # TODO(yuefengz): validate cluster_spec.
    cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
  elif hasattr(strategy.extended, "_cluster_resolver"):
    cluster_resolver = strategy.extended._cluster_resolver  # pylint: disable=protected-access
    task_type = cluster_resolver.task_type
    task_id = cluster_resolver.task_id
    rpc_layer = cluster_resolver.rpc_layer or rpc_layer
    environment = cluster_resolver.environment
    cluster_spec = cluster_resolver.cluster_spec()

  # Setting the session config is necessary for some strategies such as
  # CollectiveAllReduceStrategy.
  session_config = session_config or config_pb2.ConfigProto(
      allow_soft_placement=True)

  if cluster_spec:
    logging.info(
        "Running Distribute Coordinator with mode = %r, cluster_spec = %r, "
        "task_type = %r, task_id = %r, environment = %r, rpc_layer = %r", mode,
        cluster_spec.as_dict(), task_type, task_id, environment, rpc_layer)

  if not cluster_spec:
    # `mode` is ignored in the local case.
    logging.info("Running local Distribute Coordinator.")
    _run_single_worker(worker_fn, strategy, None, None, None, session_config,
                       rpc_layer)
    if eval_fn:
      _run_single_worker(eval_fn, eval_strategy, None, None, None,
                         session_config, rpc_layer)
    else:
      logging.warning("Skipped evaluation since `eval_fn` is not passed in.")
  elif mode == CoordinatorMode.STANDALONE_CLIENT:
    if not eval_fn:
      logging.warning("`eval_fn` is not passed in. The `worker_fn` will be "
                      "used if an \"evaluator\" task exists in the cluster.")
    eval_fn = eval_fn or worker_fn
    if not eval_strategy:
      logging.warning("`eval_strategy` is not passed in. No distribution "
                      "strategy will be used for evaluation.")

    # The client must know the cluster but servers in the cluster don't have to
    # know the client.
    if task_type in [_TaskType.CLIENT, None]:
      if strategy.extended.experimental_between_graph:
        return _run_between_graph_client(worker_fn, strategy, eval_fn,
                                         eval_strategy, cluster_spec,
                                         session_config, rpc_layer)
      else:
        return _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                                    cluster_spec, session_config, rpc_layer)
    else:
      # If not a client job, run the standard server.
      _configure_session_config_for_std_servers(strategy, eval_strategy,
                                                session_config, cluster_spec,
                                                task_type, task_id)
      server = _run_std_server(
          cluster_spec=cluster_spec,
          task_type=task_type,
          task_id=task_id,
          session_config=session_config,
          rpc_layer=rpc_layer,
          environment=environment)
      server.join()
  else:
    if mode != CoordinatorMode.INDEPENDENT_WORKER:
      raise ValueError("Unexpected coordinator mode: %r" % mode)

    if not eval_fn:
      logging.warning("`eval_fn` is not passed in. The `worker_fn` will be "
                      "used if an \"evaluator\" task exists in the cluster.")
    eval_fn = eval_fn or worker_fn
    if not eval_strategy:
      logging.warning("`eval_strategy` is not passed in. No distribution "
                      "strategy will be used for evaluation.")

    # Everyone starts a standard server and gets the session config from the
    # `configure` method.
    _configure_session_config_for_std_servers(strategy, eval_strategy,
                                              session_config, cluster_spec,
                                              task_type, task_id)

    if (task_type != _TaskType.EVALUATOR and
        not getattr(strategy.extended, "_std_server_started", False)):
      # Right now, with eager mode, context is configured with a std server at
      # the very beginning while with graph mode the std server is started when
      # distribute coordinator is called. We should consolidate these two paths.
      server = _run_std_server(
          cluster_spec=cluster_spec,
          task_type=task_type,
          task_id=task_id,
          session_config=session_config,
          rpc_layer=rpc_layer,
          environment=environment)
    if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
      if strategy.extended.experimental_between_graph:
        # All jobs run `worker_fn` if between-graph.
        return _run_single_worker(worker_fn, strategy, cluster_spec, task_type,
                                  task_id, session_config, rpc_layer)
      else:
        # Only one node runs `worker_fn` if in-graph.
        context = _WorkerContext(strategy, cluster_spec, task_type, task_id)
        if context.is_chief:
          return _run_single_worker(worker_fn, strategy, cluster_spec, None,
                                    None, session_config, rpc_layer)
        else:
          server.join()
    elif task_type == _TaskType.EVALUATOR:
      return _run_single_worker(eval_fn, eval_strategy, cluster_spec, task_type,
                                task_id, session_config, rpc_layer)
    else:
      if task_type != _TaskType.PS:
        raise ValueError("Unexpected task_type: %r" % task_type)
      server.join()

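
# Illustrative sketch only (not part of the library): how a high-level API
# might invoke the coordinator in standalone-client mode. The helper name
# `_example_run_coordinator`, the host names and the `strategy` argument are
# hypothetical placeholders; `worker_fn` receives the per-task copy of the
# strategy and reads its configuration from the current worker context.
def _example_run_coordinator(strategy):

  def _worker_fn(strategy):
    del strategy  # A real `worker_fn` would build and run its graph here.
    context = distribute_coordinator_context.get_current_worker_context()
    logging.info("Task %r/%r uses master %r.", context.task_type,
                 context.task_id, context.master_target)

  cluster_spec = {
      "chief": ["host0:2222"],
      "worker": ["host1:2222", "host2:2222"],
  }
  return run_distribute_coordinator(
      _worker_fn,
      strategy,
      mode=CoordinatorMode.STANDALONE_CLIENT,
      cluster_spec=cluster_spec,
      rpc_layer="grpc")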