# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A component for running distributed TensorFlow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import json
import os
import threading
import time

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.distribute import distribute_coordinator_context
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import coordinator
from tensorflow.python.training import monitored_session
from tensorflow.python.training import server_lib


_thread_local = threading.local()


class _TaskType(object):
  PS = "ps"
  WORKER = "worker"
  CHIEF = "chief"
  EVALUATOR = "evaluator"
  CLIENT = "client"


# TODO(yuefengz): support another mode where the client colocates with one
# worker.
class CoordinatorMode(object):
  """Specify how distribute coordinator runs."""
  # The default mode where the distribute coordinator runs as a standalone
  # client and connects to remote servers for training. Each remote server can
  # use the distribute coordinator binary with task_type set correctly, which
  # will then turn into standard servers.
  STANDALONE_CLIENT = "standalone_client"

  # The distribute coordinator runs on each worker. It will run a standard
  # server on each worker and optionally run the `worker_fn` that is configured
  # to talk to its standard server.
  INDEPENDENT_WORKER = "independent_worker"


class _Barrier(object):
  """A reusable barrier class for worker synchronization."""

  def __init__(self, num_participants):
    """Initializes the barrier object.

    Args:
      num_participants: an integer, the expected number of calls to `wait`
        that must be made before the barrier releases.
    """
    self._num_participants = num_participants
    self._counter = 0
    self._flag = False
    self._local_sense = threading.local()
    self._lock = threading.Lock()
    self._condition = threading.Condition()

  def wait(self):
    """Waits until all other callers reach the same wait call."""
    self._local_sense.value = not self._flag
    with self._lock:
      self._counter += 1
      if self._counter == self._num_participants:
        self._counter = 0
        self._flag = self._local_sense.value
    with self._condition:
      while self._flag != self._local_sense.value:
        self._condition.wait()
      self._condition.notify_all()


def _get_num_workers(cluster_spec):
  """Gets number of workers including chief."""
  if not cluster_spec:
    return 0
  return len(cluster_spec.as_dict().get(_TaskType.WORKER, [])) + len(
      cluster_spec.as_dict().get(_TaskType.CHIEF, []))
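

# Example (illustrative sketch): `_Barrier` releases all callers once the
# expected number of `wait` calls has arrived and can then be reused for the
# next round. `_synced_step` below is a hypothetical thread target.
#
#   barrier = _Barrier(2)
#
#   def _synced_step(name):
#     logging.info("%s: before barrier", name)
#     barrier.wait()  # Blocks until both participants have called wait().
#     logging.info("%s: after barrier", name)
#
#   workers = [threading.Thread(target=_synced_step, args=("worker_%d" % i,))
#              for i in range(2)]
#   for t in workers:
#     t.start()
#   for t in workers:
#     t.join()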


class _WorkerContext(object):
  """The worker context class.

  This context object provides configuration information for each task. One
  context manager with a worker context object will be created per invocation
  of the `worker_fn`, where `get_current_worker_context` can be called to
  access the worker context object.
  """

  def __init__(self,
               strategy,
               cluster_spec,
               task_type,
               task_id,
               session_config=None,
               rpc_layer="grpc",
               worker_barrier=None):
    """Initialize the worker context object.

    Args:
      strategy: a `DistributionStrategy` object.
      cluster_spec: a ClusterSpec object. It can be empty or None in the local
        training case.
      task_type: a string indicating the role of the corresponding task, such
        as "worker" or "ps". It can be None if it is local training or
        in-graph replicated training.
      task_id: an integer indicating id of the corresponding task. It can be
        None if it is local training or in-graph replicated training.
      session_config: an optional `tf.ConfigProto` object.
      rpc_layer: optional string specifying the RPC protocol for communication
        with worker masters. If None or empty, hosts in the `cluster_spec`
        will be used directly.
      worker_barrier: optional, the barrier object for worker synchronization.
    """
    self._strategy = strategy
    self._cluster_spec = cluster_spec
    self._task_type = task_type
    self._task_id = task_id
    self._session_config = session_config
    self._worker_barrier = worker_barrier
    self._rpc_layer = rpc_layer
    self._master_target = self._get_master_target()
    self._num_workers = _get_num_workers(cluster_spec)
    self._is_chief_node = self._is_chief()

  def _debug_message(self):
    if self._cluster_spec:
      return "[cluster_spec: %r, task_type: %r, task_id: %r]" % (
          self._cluster_spec, self.task_type, self.task_id)
    else:
      return "[local]"

  def __enter__(self):
    old_context = distribute_coordinator_context.get_current_worker_context()
    if old_context:
      raise ValueError(
          "You cannot run distribute coordinator in a `worker_fn`.\t" +
          self._debug_message())
    # pylint: disable=protected-access
    distribute_coordinator_context._worker_context.current = self

  def __exit__(self, unused_exception_type, unused_exception_value,
               unused_traceback):
    # pylint: disable=protected-access
    distribute_coordinator_context._worker_context.current = None

  def _get_master_target(self):
    """Return the master target for a task."""
    # If cluster_spec is None or empty, we use local master.
    if not self._cluster_spec:
      return ""

    # If task_type is None, then it is in-graph replicated training. In this
    # case we use the chief or first worker's master target.
    if not self._task_type:
      if _TaskType.CHIEF in self._cluster_spec.jobs:
        task_type = _TaskType.CHIEF
        task_id = 0
      else:
        assert _TaskType.WORKER in self._cluster_spec.jobs
        task_type = _TaskType.WORKER
        task_id = 0
    else:
      task_type = self._task_type
      task_id = self._task_id

    prefix = ""
    if self._rpc_layer:
      prefix = self._rpc_layer + "://"
    return prefix + self._cluster_spec.job_tasks(task_type)[task_id or 0]

  def _is_chief(self):
    """Return whether the task is the chief worker."""
    if (not self._cluster_spec or
        self._task_type in [_TaskType.CHIEF, _TaskType.EVALUATOR, None]):
      return True

    # If not local and chief not in the cluster_spec, use the first worker as
    # chief.
    if (_TaskType.CHIEF not in self._cluster_spec.jobs and
        self._task_type == _TaskType.WORKER and self._task_id == 0):
      return True
    return False

  def wait_for_other_workers(self):
    """Waits for other workers to reach the same call to this method.

    Raises:
      ValueError: if `worker_barrier` is not passed to the __init__ method.
    """
    if not self._worker_barrier:
      # TODO(yuefengz): we should throw an error in independent worker mode.
      return
    self._worker_barrier.wait()

  def session_creator(self,
                      scaffold=None,
                      config=None,
                      checkpoint_dir=None,
                      checkpoint_filename_with_path=None,
                      max_wait_secs=7200):
    """Returns a session creator.

    The returned session creator will be configured with the correct master
    target and session configs. It will also run either init ops or ready ops
    by querying the `strategy` object when `create_session` is called on it.

    Args:
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified a default one is created. It's used to finalize the
        graph.
      config: `ConfigProto` proto used to configure the session.
      checkpoint_dir: A string. Optional path to a directory where to restore
        variables.
      checkpoint_filename_with_path: Full file name path to the checkpoint
        file. Only one of `checkpoint_dir` or `checkpoint_filename_with_path`
        can be specified.
      max_wait_secs: Maximum time to wait for the session to become available.

    Returns:
      a descendant of SessionCreator.
    """
    if config:
      session_config = copy.deepcopy(config)
      session_config.MergeFrom(self._session_config)
    else:
      session_config = self._session_config

    if not self._strategy or self._strategy.extended.experimental_should_init:
      logging.info("Creating chief session creator with config: %r", config)
      return monitored_session.ChiefSessionCreator(
          scaffold,
          master=self.master_target,
          config=session_config,
          checkpoint_dir=checkpoint_dir,
          checkpoint_filename_with_path=checkpoint_filename_with_path)
    else:
      logging.info("Creating worker session creator with config: %r", config)
      return monitored_session.WorkerSessionCreator(
          scaffold,
          master=self.master_target,
          config=session_config,
          max_wait_secs=max_wait_secs)

  @property
  def session_config(self):
    return copy.deepcopy(self._session_config)

  @property
  def has_barrier(self):
    """Whether the barrier is set or not."""
    return self._worker_barrier is not None

  @property
  def distributed_mode(self):
    """Whether it is distributed training or not."""
    return bool(self._cluster_spec) and self._task_type != _TaskType.EVALUATOR

  @property
  def cluster_spec(self):
    """Returns a copy of the cluster_spec object."""
    return copy.deepcopy(self._cluster_spec)

  @property
  def task_type(self):
    """Returns the role of the corresponding task."""
    return self._task_type

  @property
  def task_id(self):
    """Returns the id or index of the corresponding task."""
    return self._task_id

  @property
  def master_target(self):
    """Returns the session master for the corresponding task to connect to."""
    return self._master_target

  @property
  def is_chief(self):
    """Returns whether the task is a chief node."""
    return self._is_chief_node

  @property
  def num_workers(self):
    """Returns number of workers in the cluster, including chief."""
    return self._num_workers

  @property
  def experimental_should_init(self):
    """Whether to run init ops."""
    return self._strategy.extended.experimental_should_init

  @property
  def should_checkpoint(self):
    """Whether to save checkpoint."""
    return self._strategy.extended.should_checkpoint

  @property
  def should_save_summary(self):
    """Whether to save summaries."""
    return self._strategy.extended.should_save_summary
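

# Example (illustrative sketch): a `worker_fn` can get its `_WorkerContext`
# via `distribute_coordinator_context.get_current_worker_context()` and use
# `session_creator()` to build a `MonitoredSession` pointing at the right
# master. `build_train_op` is a hypothetical, user-provided function.
#
#   def _example_worker_fn(strategy):
#     context = distribute_coordinator_context.get_current_worker_context()
#     logging.info("Task %s/%s, master = %s, is_chief = %s",
#                  context.task_type, context.task_id, context.master_target,
#                  context.is_chief)
#     with strategy.scope():
#       train_op = build_train_op()  # hypothetical graph construction
#     with monitored_session.MonitoredSession(
#         session_creator=context.session_creator()) as sess:
#       sess.run(train_op)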


def _run_single_worker(worker_fn,
                       strategy,
                       cluster_spec,
                       task_type,
                       task_id,
                       session_config,
                       rpc_layer="",
                       worker_barrier=None,
                       coord=None):
  """Runs a single worker by calling `worker_fn` under context."""
  session_config = copy.deepcopy(session_config)
  strategy = copy.deepcopy(strategy)
  # If there is an EVALUATOR task, we run single-machine eval on that task.
  if task_type == _TaskType.EVALUATOR:
    # It is possible to not have a strategy object for the EVALUATOR task.
    if strategy:
      strategy.configure(session_config)
  else:
    assert strategy
    strategy.configure(session_config, cluster_spec, task_type, task_id)

  context = _WorkerContext(
      strategy,
      cluster_spec,
      task_type,
      task_id,
      session_config=session_config,
      rpc_layer=rpc_layer,
      worker_barrier=worker_barrier)
  with context:
    if coord:
      with coord.stop_on_exception():
        return worker_fn(strategy)
    else:
      return worker_fn(strategy)


def _split_cluster_for_evaluator(cluster_spec, task_type):
  """Split the cluster for evaluator since it needn't talk to other tasks."""
  # Splitting the cluster is important to prevent the evaluator from talking
  # to other tasks in the cluster. We allow the evaluator not to use a
  # distribution strategy, so ops in the evaluator task may have unspecified
  # devices; those ops may end up on other tasks if we don't split the
  # cluster.
  new_cluster_spec = multi_worker_util.normalize_cluster_spec(
      cluster_spec).as_dict()
  if task_type == _TaskType.EVALUATOR:
    assert _TaskType.EVALUATOR in new_cluster_spec
    new_cluster_spec = {
        _TaskType.EVALUATOR: new_cluster_spec[_TaskType.EVALUATOR]
    }
  else:
    new_cluster_spec.pop(_TaskType.EVALUATOR, None)
  return multi_worker_util.normalize_cluster_spec(new_cluster_spec)


def _run_std_server(cluster_spec=None,
                    task_type=None,
                    task_id=None,
                    session_config=None,
                    rpc_layer=None,
                    environment=None):
  """Runs a standard server."""
  # Check if the Server is already running. If so, assert that no
  # configuration options have changed, and return the existing Server. This
  # allows us to call `run_distribute_coordinator` multiple times.
  if getattr(_thread_local, "server", None) is not None:
    assert _thread_local.cluster_spec == cluster_spec
    assert _thread_local.task_type == task_type
    assert _thread_local.task_id == task_id
    assert _thread_local.session_config_str == repr(session_config)
    assert _thread_local.rpc_layer == rpc_layer
    assert _thread_local.environment == environment
    return _thread_local.server
  else:
    # This method is not thread-safe.
    _thread_local.server_started = True
    _thread_local.cluster_spec = cluster_spec
    _thread_local.task_type = task_type
    _thread_local.task_id = task_id
    _thread_local.session_config_str = repr(session_config)
    _thread_local.rpc_layer = rpc_layer
    _thread_local.environment = environment

  assert cluster_spec
  target = cluster_spec.task_address(task_type, task_id)
  if rpc_layer:
    target = rpc_layer + "://" + target

  class _FakeServer(object):
    """A fake server that runs a master session."""

    def start(self):
      # A TensorFlow server starts when a remote session is created.
      logging.info(
          "Creating a remote session to start a TensorFlow server, "
          "target = %r, session_config=%r", target, session_config)
      session.Session(target=target, config=session_config)

    def join(self):
      while True:
        time.sleep(5)

  if environment == "google":
    server = _FakeServer()
  else:
    if session_config:
      logging.info(
          "Starting standard TensorFlow server, target = %r, session_config= "
          "%r", target, session_config)
    else:
      logging.info("Starting standard TensorFlow server, target = %r", target)
    cluster_spec = _split_cluster_for_evaluator(cluster_spec, task_type)
    server = server_lib.Server(
        cluster_spec,
        job_name=task_type,
        task_index=task_id,
        config=session_config,
        protocol=rpc_layer)

  server.start()
  _thread_local.server = server
  return server


def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                              cluster_spec, session_config, rpc_layer):
  """Runs a standalone client for between-graph replication."""
  coord = coordinator.Coordinator()
  eval_thread = None
  if _TaskType.EVALUATOR in cluster_spec.jobs:
    eval_thread = threading.Thread(
        target=_run_single_worker,
        args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0,
              session_config),
        kwargs={
            "rpc_layer": rpc_layer,
            "coord": coord,
        })
    eval_thread.start()

  threads = []
  worker_barrier = _Barrier(_get_num_workers(cluster_spec))
  for task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
    for task_id in range(len(cluster_spec.as_dict().get(task_type, []))):
      t = threading.Thread(
          target=_run_single_worker,
          args=(worker_fn, strategy, cluster_spec, task_type, task_id,
                session_config),
          kwargs={
              "rpc_layer": rpc_layer,
              "worker_barrier": worker_barrier,
              "coord": coord,
          })
      t.start()
      threads.append(t)

  if eval_thread:
    # TODO(yuefengz): is it necessary to join eval thread?
    threads_to_join = threads + [eval_thread]
  else:
    threads_to_join = threads
  coord.join(threads_to_join)

  # TODO(yuefengz): we probably want to return results from all workers?
  return None


def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
                         cluster_spec, session_config, rpc_layer):
  """Runs a standalone client for in-graph replication."""
  coord = coordinator.Coordinator()
  eval_thread = None
  if _TaskType.EVALUATOR in cluster_spec.jobs:
    eval_thread = threading.Thread(
        target=_run_single_worker,
        args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0,
              session_config),
        kwargs={
            "rpc_layer": rpc_layer,
            "coord": coord,
        })
    eval_thread.start()

  worker_result = _run_single_worker(
      worker_fn,
      strategy,
      cluster_spec,
      None,
      None,
      session_config,
      rpc_layer=rpc_layer,
      coord=coord)

  if eval_thread:
    coord.join([eval_thread])

  return worker_result


def _configure_session_config_for_std_servers(
    strategy, eval_strategy, session_config, cluster_spec, task_type, task_id):
  # pylint: disable=g-doc-args
  """Call strategy's `configure` to mutate the session_config.

  The session_config is currently needed as default config for a TensorFlow
  server. In the future, we should be able to remove this method and only
  pass the session config to a client session.
  """
  if task_type == _TaskType.EVALUATOR:
    if eval_strategy:
      eval_strategy.configure(session_config=session_config)
  else:
    # The strategy may be shared in standalone client mode.
    strategy = copy.deepcopy(strategy)
    strategy.configure(
        session_config=session_config,
        cluster_spec=cluster_spec,
        task_type=task_type,
        task_id=task_id)
  # Remove the device filters specific to the strategy, so that the
  # TensorFlow server brought up with one strategy can be used by other
  # strategies. The device filters can be set in the client side as well.
  del session_config.device_filters[:]


def run_standard_tensorflow_server(session_config=None):
  """Starts a standard TensorFlow server.

  This method parses configurations from the "TF_CONFIG" environment variable
  and starts a TensorFlow server. The "TF_CONFIG" is typically a json string
  and must contain information about the cluster and the role of the server
  in the cluster. One example is:

  TF_CONFIG='{
      "cluster": {
          "worker": ["host1:2222", "host2:2222", "host3:2222"],
          "ps": ["host4:2222", "host5:2222"]
      },
      "task": {"type": "worker", "index": 1}
  }'

  This "TF_CONFIG" specifies there are 3 workers and 2 ps tasks in the cluster
  and the current role is worker 1.

  Valid task types are "chief", "worker", "ps" and "evaluator" and you can
  have at most one "chief" and at most one "evaluator".

  An optional top-level key, "rpc_layer", can also be specified. The default
  value is "grpc".

  Args:
    session_config: an optional `tf.ConfigProto` object. Users can pass in
      the session config object to configure server-local devices.

  Returns:
    a `tf.train.Server` object which has already been started.

  Raises:
    ValueError: if the "TF_CONFIG" environment variable is not complete.
  """
  tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
  if "cluster" not in tf_config:
    raise ValueError("\"cluster\" is not found in TF_CONFIG.")
  cluster_spec = multi_worker_util.normalize_cluster_spec(tf_config["cluster"])
  if "task" not in tf_config:
    raise ValueError("\"task\" is not found in TF_CONFIG.")
  task_env = tf_config["task"]
  if "type" not in task_env:
    raise ValueError(
        "\"task_type\" is not found in the `task` part of TF_CONFIG.")
  task_type = task_env["type"]
  task_id = int(task_env.get("index", 0))

  rpc_layer = tf_config.get("rpc_layer", "grpc")

  session_config = session_config or config_pb2.ConfigProto()
  # Set the collective group leader for collective ops so that they can be
  # initialized when the server starts.
  if "chief" in cluster_spec.jobs:
    session_config.experimental.collective_group_leader = (
        "/job:chief/replica:0/task:0")
  else:
    if "worker" not in cluster_spec.jobs:
      raise ValueError(
          "You must have `chief` or `worker` jobs in the `cluster_spec`.")
    session_config.experimental.collective_group_leader = (
        "/job:worker/replica:0/task:0")

  server = _run_std_server(
      cluster_spec=cluster_spec,
      task_type=task_type,
      task_id=task_id,
      session_config=session_config,
      rpc_layer=rpc_layer)
  server.start()
  return server
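

# Example (illustrative sketch): on a non-client task, a program can bring up
# the standard server from "TF_CONFIG" and then block forever. The TF_CONFIG
# value below is hypothetical.
#
#   os.environ["TF_CONFIG"] = json.dumps({
#       "cluster": {
#           "worker": ["host1:2222", "host2:2222"],
#           "ps": ["host3:2222"]
#       },
#       "task": {"type": "ps", "index": 0}
#   })
#   server = run_standard_tensorflow_server()
#   server.join()  # Block and serve requests from the cluster.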


# TODO(yuefengz): propagate cluster_spec in the STANDALONE_CLIENT mode.
# TODO(yuefengz): we may need a smart way to figure out whether the current
# task is the special task when we support cluster_spec propagation.
def run_distribute_coordinator(worker_fn,
                               strategy,
                               eval_fn=None,
                               eval_strategy=None,
                               mode=CoordinatorMode.STANDALONE_CLIENT,
                               cluster_spec=None,
                               task_type=None,
                               task_id=None,
                               session_config=None,
                               rpc_layer="grpc"):
  """Runs the coordinator for distributed TensorFlow.

  This function runs a split coordinator for distributed TensorFlow in its
  default mode, i.e. the STANDALONE_CLIENT mode. Given a `cluster_spec`
  specifying server addresses and their roles in a cluster, this coordinator
  will figure out how to set them up, give the underlying function the right
  targets for master sessions via a scope object and coordinate their
  training. The cluster consisting of standard servers needs to be brought up
  either with the standard server binary or with a binary running distribute
  coordinator with `task_type` set to a non-client type, which will then turn
  into standard servers.

  In addition to being the distribute coordinator, this is also the source of
  configurations for each job in the distributed training. As there are
  multiple ways to configure a distributed TensorFlow cluster, its context
  object provides these configurations so that users or higher-level APIs
  don't have to figure out the configuration for each job by themselves.

  In the between-graph replicated training, this coordinator will create
  multiple threads and each calls the `worker_fn` which is supposed to create
  its own graph and connect to one worker master given by its context object.
  In the in-graph replicated training, it has only one thread calling this
  `worker_fn`.

  Another mode is the INDEPENDENT_WORKER mode where each server runs a
  distribute coordinator which will start a standard server and optionally
  runs `worker_fn` depending on whether it is between-graph training or
  in-graph replicated training.

  The `strategy` object is expected to be a DistributionStrategy object which
  has implemented the methods needed by the distribute coordinator, such as
  `configure(session_config, cluster_spec, task_type, task_id)` which
  configures the strategy object for a specific task, and an
  `experimental_should_init` property which instructs the distribute
  coordinator whether to run init ops for a task. The distribute coordinator
  will make a copy of the `strategy` object, call its `configure` method and
  pass it to `worker_fn` as an argument.

  The `worker_fn` defines the training logic and is called under its own
  worker context which can be accessed via `get_current_worker_context`. A
  worker context provides access to configurations for each task, e.g. the
  task_type, task_id, master target and so on. Since `worker_fn` will be
  called in a thread and possibly multiple times, the caller should be
  careful when it accesses global data. For example, it is unsafe to define
  flags in a `worker_fn` or to define different environment variables for
  different `worker_fn`s.

  The `worker_fn` for the between-graph replication is defined as if there is
  only one worker corresponding to the `worker_fn` and possibly ps jobs. For
  example, when training with parameter servers, it assigns variables to
  parameter servers and all other operations to that worker. In the in-graph
  replication case, the `worker_fn` has to define operations for all worker
  jobs. Using a distribution strategy can simplify the `worker_fn` by not
  having to worry about the replication and device assignment of variables
  and operations.

  This method is intended to be invoked by high-level APIs so that users
  don't have to explicitly call it to run this coordinator. For those who
  don't use high-level APIs, to change a program to use this coordinator,
  wrap everything in the program after global data definitions, such as
  command-line flag definitions, into the `worker_fn` and get task-specific
  configurations from the worker context.

  The `cluster_spec` can be either passed by the argument or parsed from the
  "TF_CONFIG" environment variable. Example of a TF_CONFIG:
  ```
    cluster = {'chief': ['host0:2222'],
               'ps': ['host1:2222', 'host2:2222'],
               'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
    os.environ['TF_CONFIG'] = json.dumps({'cluster': cluster})
  ```

  If `cluster_spec` is not given in any format, it becomes local training and
  this coordinator will connect to a local session.

  For evaluation, if "evaluator" exists in the cluster_spec, a separate
  thread will be created to call `eval_fn` with its `task_type` set to
  "evaluator". If `eval_fn` is not defined, fall back to `worker_fn`. This
  implies that evaluation will be done on a single machine if there is an
  "evaluator" task. If "evaluator" doesn't exist in the cluster_spec, it
  entirely depends on the `worker_fn` for how to do evaluation.

  Args:
    worker_fn: the function to be called. The function should accept a
      `strategy` object and will be given access to a context object via a
      context manager scope.
    strategy: a DistributionStrategy object specifying whether it should
      run between-graph replicated training or not, whether to run init ops,
      etc. This object will also be configured given `session_config`,
      `cluster_spec`, `task_type` and `task_id`.
    eval_fn: optional function for the "evaluator" task. If `eval_fn` is not
      passed in but an "evaluator" task is found in the `cluster_spec`, the
      `worker_fn` will be used for this task.
    eval_strategy: optional DistributionStrategy object for the "evaluator"
      task.
    mode: in which mode this distribute coordinator runs.
    cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and
      roles in a cluster. If not set or empty, fall back to local training.
    task_type: the current task type, optional if this is a client.
    task_id: the current task id, optional if this is a client.
    session_config: an optional `tf.ConfigProto` object which will be passed
      to `strategy`'s `configure` method and used to create a session.
    rpc_layer: optional string, the protocol for RPC, e.g. "grpc".

  Raises:
    ValueError: if `cluster_spec` is supplied but not a dict or a ClusterDef
      or a ClusterSpec.

  Returns:
    In the client job, return the value returned by `worker_fn` if
    it is in-graph replication or INDEPENDENT_WORKER mode; return None
    otherwise.
  """
  tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
  if not cluster_spec:
    cluster_spec = tf_config.get("cluster", {})
    task_env = tf_config.get("task", {})
    if task_env:
      task_type = task_env.get("type", task_type)
      task_id = int(task_env.get("index", task_id))

  if cluster_spec:
    cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
    # TODO(yuefengz): validate cluster_spec.

  rpc_layer = tf_config.get("rpc_layer", rpc_layer)
  environment = tf_config.get("environment", None)

  # Setting the session config is necessary for some strategies such as
  # CollectiveAllReduceStrategy.
  session_config = session_config or config_pb2.ConfigProto(
      allow_soft_placement=True)

  if cluster_spec:
    logging.info(
        "Running Distribute Coordinator with mode = %r, cluster_spec = %r, "
        "task_type = %r, task_id = %r, environment = %r, rpc_layer = %r", mode,
        cluster_spec.as_dict(), task_type, task_id, environment, rpc_layer)

  if not cluster_spec:
    # `mode` is ignored in the local case.
    logging.info("Running local Distribute Coordinator.")
    _run_single_worker(worker_fn, strategy, None, None, None, session_config,
                       rpc_layer)
    if eval_fn:
      _run_single_worker(eval_fn, eval_strategy, None, None, None,
                         session_config, rpc_layer)
    else:
      logging.warning("Skipped evaluation since `eval_fn` is not passed in.")
  elif mode == CoordinatorMode.STANDALONE_CLIENT:
    if not eval_fn:
      logging.warning("`eval_fn` is not passed in. The `worker_fn` will be "
                      "used if an \"evaluator\" task exists in the cluster.")
    eval_fn = eval_fn or worker_fn
    if not eval_strategy:
      logging.warning("`eval_strategy` is not passed in. No distribution "
                      "strategy will be used for evaluation.")

    # The client must know the cluster but servers in the cluster don't have
    # to know the client.
    if task_type in [_TaskType.CLIENT, None]:
      if strategy.extended.experimental_between_graph:
        return _run_between_graph_client(worker_fn, strategy, eval_fn,
                                         eval_strategy, cluster_spec,
                                         session_config, rpc_layer)
      else:
        return _run_in_graph_client(worker_fn, strategy, eval_fn,
                                    eval_strategy, cluster_spec,
                                    session_config, rpc_layer)
    else:
      # If not a client job, run the standard server.
      _configure_session_config_for_std_servers(strategy, eval_strategy,
                                                session_config, cluster_spec,
                                                task_type, task_id)
      server = _run_std_server(
          cluster_spec=cluster_spec,
          task_type=task_type,
          task_id=task_id,
          session_config=session_config,
          rpc_layer=rpc_layer,
          environment=environment)
      server.join()
  else:
    if mode != CoordinatorMode.INDEPENDENT_WORKER:
      raise ValueError("Unexpected coordinator mode: %r" % mode)

    if not eval_fn:
      logging.warning("`eval_fn` is not passed in. The `worker_fn` will be "
                      "used if an \"evaluator\" task exists in the cluster.")
    eval_fn = eval_fn or worker_fn
    if not eval_strategy:
      logging.warning("`eval_strategy` is not passed in. No distribution "
                      "strategy will be used for evaluation.")

    # Everyone starts a standard server and gets the session config from the
    # `configure` method.
    _configure_session_config_for_std_servers(strategy, eval_strategy,
                                              session_config, cluster_spec,
                                              task_type, task_id)

    if not getattr(strategy.extended, "_std_server_started", False):
      # Right now, with eager mode, context is configured with a std server at
      # the very beginning while with graph mode the std server is started
      # when distribute coordinator is called. We should consolidate these two
      # paths.
      server = _run_std_server(
          cluster_spec=cluster_spec,
          task_type=task_type,
          task_id=task_id,
          session_config=session_config,
          rpc_layer=rpc_layer,
          environment=environment)
    if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
      if strategy.extended.experimental_between_graph:
        # All jobs run `worker_fn` if between-graph.
        return _run_single_worker(worker_fn, strategy, cluster_spec, task_type,
                                  task_id, session_config, rpc_layer)
      else:
        # Only one node runs `worker_fn` if in-graph.
        context = _WorkerContext(strategy, cluster_spec, task_type, task_id)
        if context.is_chief:
          return _run_single_worker(worker_fn, strategy, cluster_spec, None,
                                    None, session_config, rpc_layer)
        else:
          server.join()
    elif task_type == _TaskType.EVALUATOR:
      return _run_single_worker(eval_fn, eval_strategy, cluster_spec, task_type,
                                task_id, session_config, rpc_layer)
    else:
      if task_type != _TaskType.PS:
        raise ValueError("Unexpected task_type: %r" % task_type)
      server.join()
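

# Example (illustrative sketch): invoking the coordinator directly in
# INDEPENDENT_WORKER mode on one of the worker tasks. `_example_worker_fn` is
# the hypothetical `worker_fn` sketched above, and the strategy shown is just
# one possible choice; high-level APIs normally make this call on the user's
# behalf.
#
#   from tensorflow.python.distribute import collective_all_reduce_strategy
#
#   strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
#   run_distribute_coordinator(
#       _example_worker_fn,
#       strategy,
#       mode=CoordinatorMode.INDEPENDENT_WORKER,
#       cluster_spec={"worker": ["host1:2222", "host2:2222"]},
#       task_type="worker",
#       task_id=0,
#       rpc_layer="grpc")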