1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A component for running distributed TensorFlow.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import copy 22import json 23import os 24import threading 25import time 26 27from tensorflow.core.protobuf import config_pb2 28from tensorflow.python.client import session 29from tensorflow.python.distribute import distribute_coordinator_context 30from tensorflow.python.distribute import multi_worker_util 31from tensorflow.python.platform import tf_logging as logging 32from tensorflow.python.training import coordinator 33from tensorflow.python.training import monitored_session 34from tensorflow.python.training import server_lib 35 36 37_thread_local = threading.local() 38 39 40class _TaskType(object): 41 PS = "ps" 42 WORKER = "worker" 43 CHIEF = "chief" 44 EVALUATOR = "evaluator" 45 CLIENT = "client" 46 47 48# TODO(yuefengz): support another mode where the client colocates with one 49# worker. 50class CoordinatorMode(object): 51 """Specify how distribute coordinator runs.""" 52 # The default mode where distribute coordinator will run as a standalone 53 # client and connects to remote servers for training. Each remote server can 54 # use the distribute coordinator binary with task_type set correctly which 55 # will then turn into standard servers. 56 STANDALONE_CLIENT = "standalone_client" 57 58 # The distribute coordinator runs on each worker. It will run a standard 59 # server on each worker and optionally run the `worker_fn` that is configured 60 # to talk to its standard server. 61 INDEPENDENT_WORKER = "independent_worker" 62 63 64class _Barrier(object): 65 """A reusable barrier class for worker synchronization.""" 66 67 def __init__(self, num_participants): 68 """Initializes the barrier object. 69 70 Args: 71 num_participants: an integer which is the expected number of calls of 72 `wait` pass to through this barrier. 73 """ 74 self._num_participants = num_participants 75 self._counter = 0 76 self._flag = False 77 self._local_sense = threading.local() 78 self._lock = threading.Lock() 79 self._condition = threading.Condition() 80 81 def wait(self): 82 """Waits until all other callers reach the same wait call.""" 83 self._local_sense.value = not self._flag 84 with self._lock: 85 self._counter += 1 86 if self._counter == self._num_participants: 87 self._counter = 0 88 self._flag = self._local_sense.value 89 with self._condition: 90 while self._flag != self._local_sense.value: 91 self._condition.wait() 92 self._condition.notify_all() 93 94 95def _get_num_workers(cluster_spec): 96 """Gets number of workers including chief.""" 97 if not cluster_spec: 98 return 0 99 return len(cluster_spec.as_dict().get(_TaskType.WORKER, [])) + len( 100 cluster_spec.as_dict().get(_TaskType.CHIEF, [])) 101 102 103class _WorkerContext(object): 104 """The worker context class. 105 106 This context object provides configuration information for each task. One 107 context manager with a worker context object will be created per 108 invocation to the `worker_fn` where `get_current_worker_context` can be called 109 to access the worker context object. 110 """ 111 112 def __init__(self, 113 strategy, 114 cluster_spec, 115 task_type, 116 task_id, 117 session_config=None, 118 rpc_layer="grpc", 119 worker_barrier=None): 120 """Initialize the worker context object. 121 122 Args: 123 strategy: a `DistributionStrategy` object. 124 cluster_spec: a ClusterSpec object. It can be empty or None in the local 125 training case. 126 task_type: a string indicating the role of the corresponding task, such as 127 "worker" or "ps". It can be None if it is local training or in-graph 128 replicated training. 129 task_id: an integer indicating id of the corresponding task. It can be 130 None if it is local training or in-graph replicated training. 131 session_config: an optional `tf.compat.v1.ConfigProto` object. 132 rpc_layer: optional string specifying the RPC protocol for communication 133 with worker masters. If None or empty, hosts in the `cluster_spec` will 134 be used directly. 135 worker_barrier: optional, the barrier object for worker synchronization. 136 """ 137 self._strategy = strategy 138 self._cluster_spec = cluster_spec 139 self._task_type = task_type 140 self._task_id = task_id 141 self._session_config = session_config 142 self._worker_barrier = worker_barrier 143 self._rpc_layer = rpc_layer 144 self._master_target = self._get_master_target() 145 self._num_workers = _get_num_workers(cluster_spec) 146 self._is_chief_node = self._is_chief() 147 148 def _debug_message(self): 149 if self._cluster_spec: 150 return "[cluster_spec: %r, task_type: %r, task_id: %r]" % ( 151 self._cluster_spec, self.task_type, self.task_id) 152 else: 153 return "[local]" 154 155 def __enter__(self): 156 old_context = distribute_coordinator_context.get_current_worker_context() 157 if old_context: 158 raise ValueError( 159 "You cannot run distribute coordinator in a `worker_fn`.\t" + 160 self._debug_message()) 161 # pylint: disable=protected-access 162 distribute_coordinator_context._worker_context.current = self 163 164 def __exit__(self, unused_exception_type, unused_exception_value, 165 unused_traceback): 166 # pylint: disable=protected-access 167 distribute_coordinator_context._worker_context.current = None 168 169 def _get_master_target(self): 170 """Return the master target for a task.""" 171 # If cluster_spec is None or empty, we use local master. 172 if not self._cluster_spec or self._task_type == _TaskType.EVALUATOR: 173 return "" 174 175 # If task_type is None, then it is in-graph replicated training. In this 176 # case we use the chief or first worker's master target. 177 if not self._task_type: 178 if _TaskType.CHIEF in self._cluster_spec.jobs: 179 task_type = _TaskType.CHIEF 180 task_id = 0 181 else: 182 assert _TaskType.WORKER in self._cluster_spec.jobs 183 task_type = _TaskType.WORKER 184 task_id = 0 185 else: 186 task_type = self._task_type 187 task_id = self._task_id 188 189 prefix = "" 190 if self._rpc_layer: 191 prefix = self._rpc_layer + "://" 192 return prefix + self._cluster_spec.job_tasks(task_type)[task_id or 0] 193 194 def _is_chief(self): 195 """Return whether the task is the chief worker.""" 196 if (not self._cluster_spec or 197 self._task_type in [_TaskType.CHIEF, _TaskType.EVALUATOR, None]): 198 return True 199 200 # If not local and chief not in the cluster_spec, use the first worker as 201 # chief. 202 if (_TaskType.CHIEF not in self._cluster_spec.jobs and 203 self._task_type == _TaskType.WORKER and self._task_id == 0): 204 return True 205 return False 206 207 def wait_for_other_workers(self): 208 """Waits for other workers to reach the same call to this method. 209 210 Raises: 211 ValueError: if `worker_barrier` is not passed to the __init__ method. 212 """ 213 if not self._worker_barrier: 214 # TODO(yuefengz): we should throw an error in independent worker mode. 215 return 216 self._worker_barrier.wait() 217 218 def session_creator(self, 219 scaffold=None, 220 config=None, 221 checkpoint_dir=None, 222 checkpoint_filename_with_path=None, 223 max_wait_secs=7200): 224 """Returns a session creator. 225 226 The returned session creator will be configured with the correct master 227 target and session configs. It will also run either init ops or ready ops 228 by querying the `strategy` object when `create_session` is called on it. 229 230 Args: 231 scaffold: A `Scaffold` used for gathering or building supportive ops. If 232 not specified a default one is created. It's used to finalize the graph. 233 config: `ConfigProto` proto used to configure the session. 234 checkpoint_dir: A string. Optional path to a directory where to restore 235 variables. 236 checkpoint_filename_with_path: Full file name path to the checkpoint file. 237 Only one of `checkpoint_dir` or `checkpoint_filename_with_path` can be 238 specified. 239 max_wait_secs: Maximum time to wait for the session to become available. 240 241 Returns: 242 a descendant of SessionCreator. 243 """ 244 if config: 245 session_config = copy.deepcopy(config) 246 session_config.MergeFrom(self._session_config) 247 else: 248 session_config = self._session_config 249 250 if not self._strategy or self._strategy.extended.experimental_should_init: 251 logging.info("Creating chief session creator with config: %r", config) 252 return monitored_session.ChiefSessionCreator( 253 scaffold, 254 master=self.master_target, 255 config=session_config, 256 checkpoint_dir=checkpoint_dir, 257 checkpoint_filename_with_path=checkpoint_filename_with_path) 258 else: 259 logging.info("Creating worker session creator with config: %r", config) 260 return monitored_session.WorkerSessionCreator( 261 scaffold, 262 master=self.master_target, 263 config=session_config, 264 max_wait_secs=max_wait_secs) 265 266 @property 267 def session_config(self): 268 return copy.deepcopy(self._session_config) 269 270 @property 271 def has_barrier(self): 272 """Whether the barrier is set or not.""" 273 return self._worker_barrier is not None 274 275 @property 276 def distributed_mode(self): 277 """Whether it is distributed training or not.""" 278 return bool(self._cluster_spec) and self._task_type != _TaskType.EVALUATOR 279 280 @property 281 def cluster_spec(self): 282 """Returns a copy of the cluster_spec object.""" 283 return copy.deepcopy(self._cluster_spec) 284 285 @property 286 def task_type(self): 287 """Returns the role of the corresponding task.""" 288 return self._task_type 289 290 @property 291 def task_id(self): 292 """Returns the id or index of the corresponding task.""" 293 return self._task_id 294 295 @property 296 def master_target(self): 297 """Returns the session master for the corresponding task to connect to.""" 298 return self._master_target 299 300 @property 301 def is_chief(self): 302 """Returns whether the task is a chief node.""" 303 return self._is_chief_node 304 305 @property 306 def num_workers(self): 307 """Returns number of workers in the cluster, including chief.""" 308 return self._num_workers 309 310 @property 311 def experimental_should_init(self): 312 """Whether to run init ops.""" 313 return self._strategy.extended.experimental_should_init 314 315 @property 316 def should_checkpoint(self): 317 """Whether to save checkpoint.""" 318 return self._strategy.extended.should_checkpoint 319 320 @property 321 def should_save_summary(self): 322 """Whether to save summaries.""" 323 return self._strategy.extended.should_save_summary 324 325 326def _run_single_worker(worker_fn, 327 strategy, 328 cluster_spec, 329 task_type, 330 task_id, 331 session_config, 332 rpc_layer="", 333 worker_barrier=None, 334 coord=None): 335 """Runs a single worker by calling `worker_fn` under context.""" 336 session_config = copy.deepcopy(session_config) 337 strategy = copy.deepcopy(strategy) 338 # If there is an EVALUATOR task, we run single-machine eval on that task. 339 if task_type == _TaskType.EVALUATOR: 340 # It is possible to not have a strategy object for EVALUATOR task. 341 if strategy: 342 strategy.configure(session_config) 343 else: 344 assert strategy 345 strategy.configure(session_config, cluster_spec, task_type, task_id) 346 347 context = _WorkerContext( 348 strategy, 349 cluster_spec, 350 task_type, 351 task_id, 352 session_config=session_config, 353 rpc_layer=rpc_layer, 354 worker_barrier=worker_barrier) 355 with context: 356 if coord: 357 with coord.stop_on_exception(): 358 return worker_fn(strategy) 359 else: 360 return worker_fn(strategy) 361 362 363def _split_cluster_for_evaluator(cluster_spec, task_type): 364 """Split the cluster for evaluator since it needn't talk to other tasks.""" 365 # Splitting the cluster is important to prevent the evaluator from talking to 366 # other tasks in the cluster. Since we allow evaluator not to use 367 # distribution strategies and as a result ops in the evaluator task may have 368 # unspecified devices. Those ops may end up on other tasks if we don't split 369 # the cluster. 370 # Note: if you bypass distribute coordinator and bring the cluster yourself, 371 # you can equivalently set device filters to split clusters. This is already 372 # done by distribution strategy's `update_config_proto` method. 373 new_cluster_spec = multi_worker_util.normalize_cluster_spec( 374 cluster_spec).as_dict() 375 if task_type == _TaskType.EVALUATOR: 376 assert _TaskType.EVALUATOR in new_cluster_spec 377 new_cluster_spec = { 378 _TaskType.EVALUATOR: new_cluster_spec[_TaskType.EVALUATOR] 379 } 380 else: 381 new_cluster_spec.pop(_TaskType.EVALUATOR, None) 382 return multi_worker_util.normalize_cluster_spec(new_cluster_spec) 383 384 385def _run_std_server(cluster_spec=None, 386 task_type=None, 387 task_id=None, 388 session_config=None, 389 rpc_layer=None, 390 environment=None): 391 """Runs a standard server.""" 392 # Check if the Server is already running. If so, assert that no configuration 393 # options have changed, and return the existing Server. This allows us to 394 # call `run_distribute_coordinator` multiple times. 395 if getattr(_thread_local, "server", None) is not None: 396 assert _thread_local.cluster_spec == cluster_spec 397 assert _thread_local.task_type == task_type 398 assert _thread_local.task_id == task_id 399 assert _thread_local.session_config_str == repr(session_config) 400 assert _thread_local.rpc_layer == rpc_layer 401 assert _thread_local.environment == environment 402 return _thread_local.server 403 else: 404 # This method is not thread-safe. 405 _thread_local.server_started = True 406 _thread_local.cluster_spec = cluster_spec 407 _thread_local.task_type = task_type 408 _thread_local.task_id = task_id 409 _thread_local.session_config_str = repr(session_config) 410 _thread_local.rpc_layer = rpc_layer 411 _thread_local.environment = environment 412 413 assert cluster_spec 414 target = cluster_spec.task_address(task_type, task_id) 415 if rpc_layer: 416 target = rpc_layer + "://" + target 417 418 class _FakeServer(object): 419 """A fake server that runs a master session.""" 420 421 def start(self): 422 # A tensorflow server starts when a remote session is created. 423 logging.info( 424 "Creating a remote session to start a TensorFlow server, " 425 "target = %r, session_config=%r", target, session_config) 426 session.Session(target=target, config=session_config) 427 428 def join(self): 429 while True: 430 time.sleep(5) 431 432 if environment == "google": 433 server = _FakeServer() 434 else: 435 if session_config: 436 logging.info( 437 "Starting standard TensorFlow server, target = %r, session_config= " 438 "%r", target, session_config) 439 else: 440 logging.info("Starting standard TensorFlow server, target = %r", target) 441 cluster_spec = _split_cluster_for_evaluator(cluster_spec, task_type) 442 server = server_lib.Server( 443 cluster_spec, 444 job_name=task_type, 445 task_index=task_id, 446 config=session_config, 447 protocol=rpc_layer) 448 449 server.start() 450 _thread_local.server = server 451 return server 452 453 454def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy, 455 cluster_spec, session_config, rpc_layer): 456 """Runs a standalone client for between-graph replication.""" 457 coord = coordinator.Coordinator() 458 eval_thread = None 459 if _TaskType.EVALUATOR in cluster_spec.jobs: 460 eval_thread = threading.Thread( 461 target=_run_single_worker, 462 args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0, 463 session_config), 464 kwargs={ 465 "rpc_layer": rpc_layer, 466 "coord": coord, 467 }) 468 eval_thread.start() 469 470 threads = [] 471 worker_barrier = _Barrier(_get_num_workers(cluster_spec)) 472 for task_type in [_TaskType.CHIEF, _TaskType.WORKER]: 473 for task_id in range(len(cluster_spec.as_dict().get(task_type, []))): 474 t = threading.Thread( 475 target=_run_single_worker, 476 args=(worker_fn, strategy, cluster_spec, task_type, task_id, 477 session_config), 478 kwargs={ 479 "rpc_layer": rpc_layer, 480 "worker_barrier": worker_barrier, 481 "coord": coord, 482 }) 483 t.start() 484 threads.append(t) 485 486 if eval_thread: 487 # TODO(yuefengz): is it necessary to join eval thread? 488 threads_to_join = threads + [eval_thread] 489 else: 490 threads_to_join = threads 491 coord.join(threads_to_join) 492 493 # TODO(yuefengz): we probably want to return results from all workers? 494 return None 495 496 497def _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy, 498 cluster_spec, session_config, rpc_layer): 499 """Runs a standalone client for in-graph replication.""" 500 coord = coordinator.Coordinator() 501 eval_thread = None 502 if _TaskType.EVALUATOR in cluster_spec.jobs: 503 eval_thread = threading.Thread( 504 target=_run_single_worker, 505 args=(eval_fn, eval_strategy, cluster_spec, _TaskType.EVALUATOR, 0, 506 session_config), 507 kwargs={ 508 "rpc_layer": rpc_layer, 509 "coord": coord, 510 }) 511 eval_thread.start() 512 513 worker_result = _run_single_worker( 514 worker_fn, 515 strategy, 516 cluster_spec, 517 None, 518 None, 519 session_config, 520 rpc_layer=rpc_layer, 521 coord=coord) 522 523 if eval_thread: 524 coord.join([eval_thread]) 525 526 return worker_result 527 528 529def _configure_session_config_for_std_servers( 530 strategy, eval_strategy, session_config, cluster_spec, task_type, task_id): 531 # pylint: disable=g-doc-args 532 """Call strategy's `configure` to mutate the session_config. 533 534 The session_config is currently needed as default config for a TensorFlow 535 server. In the future, we should be able to remove this method and only pass 536 the session config to a client session. 537 """ 538 if task_type == _TaskType.EVALUATOR: 539 if eval_strategy: 540 eval_strategy.configure(session_config=session_config) 541 else: 542 # The strategy may be shared in standalone client mode. 543 strategy = copy.deepcopy(strategy) 544 strategy.configure( 545 session_config=session_config, 546 cluster_spec=cluster_spec, 547 task_type=task_type, 548 task_id=task_id) 549 # Remove the device filters specific to the strategy, so that the 550 # TensorFlow server brought up with one strategy can be used by other 551 # strategies. The device filters can be set in the client side as well. 552 del session_config.device_filters[:] 553 554 555def run_standard_tensorflow_server(session_config=None): 556 """Starts a standard TensorFlow server. 557 558 This method parses configurations from "TF_CONFIG" environment variable and 559 starts a TensorFlow server. The "TF_CONFIG" is typically a json string and 560 must have information of the cluster and the role of the server in the 561 cluster. One example is: 562 563 TF_CONFIG='{ 564 "cluster": { 565 "worker": ["host1:2222", "host2:2222", "host3:2222"], 566 "ps": ["host4:2222", "host5:2222"] 567 }, 568 "task": {"type": "worker", "index": 1} 569 }' 570 571 This "TF_CONFIG" specifies there are 3 workers and 2 ps tasks in the cluster 572 and the current role is worker 1. 573 574 Valid task types are "chief", "worker", "ps" and "evaluator" and you can have 575 at most one "chief" and at most one "evaluator". 576 577 An optional key-value can be specified is "rpc_layer". The default value is 578 "grpc". 579 580 Args: 581 session_config: an optional `tf.compat.v1.ConfigProto` object. Users can 582 pass in the session config object to configure server-local devices. 583 584 Returns: 585 a `tf.distribute.Server` object which has already been started. 586 587 Raises: 588 ValueError: if the "TF_CONFIG" environment is not complete. 589 """ 590 tf_config = json.loads(os.environ.get("TF_CONFIG", "{}")) 591 if "cluster" not in tf_config: 592 raise ValueError("\"cluster\" is not found in TF_CONFIG.") 593 cluster_spec = multi_worker_util.normalize_cluster_spec(tf_config["cluster"]) 594 if "task" not in tf_config: 595 raise ValueError("\"task\" is not found in TF_CONFIG.") 596 task_env = tf_config["task"] 597 if "type" not in task_env: 598 raise ValueError( 599 "\"task_type\" is not found in the `task` part of TF_CONFIG.") 600 task_type = task_env["type"] 601 task_id = int(task_env.get("index", 0)) 602 603 rpc_layer = tf_config.get("rpc_layer", "grpc") 604 605 session_config = session_config or config_pb2.ConfigProto() 606 # Set the collective group leader for collective ops to initialize collective 607 # ops when server starts. 608 if "chief" in cluster_spec.jobs: 609 session_config.experimental.collective_group_leader = ( 610 "/job:chief/replica:0/task:0") 611 else: 612 if "worker" not in cluster_spec.jobs: 613 raise ValueError( 614 "You must have `chief` or `worker` jobs in the `cluster_spec`.") 615 session_config.experimental.collective_group_leader = ( 616 "/job:worker/replica:0/task:0") 617 618 server = _run_std_server( 619 cluster_spec=cluster_spec, 620 task_type=task_type, 621 task_id=task_id, 622 session_config=session_config, 623 rpc_layer=rpc_layer) 624 server.start() 625 return server 626 627 628# TODO(yuefengz): propagate cluster_spec in the STANDALONE_CLIENT mode. 629# TODO(yuefengz): we may need a smart way to figure out whether the current task 630# is the special task when we support cluster_spec propagation. 631def run_distribute_coordinator(worker_fn, 632 strategy, 633 eval_fn=None, 634 eval_strategy=None, 635 mode=CoordinatorMode.STANDALONE_CLIENT, 636 cluster_spec=None, 637 task_type=None, 638 task_id=None, 639 session_config=None, 640 rpc_layer="grpc"): 641 """Runs the coordinator for distributed TensorFlow. 642 643 This function runs a split coordinator for distributed TensorFlow in its 644 default mode, i.e the STANDALONE_CLIENT mode. Given a `cluster_spec` 645 specifying server addresses and their roles in a cluster, this coordinator 646 will figure out how to set them up, give the underlying function the right 647 targets for master sessions via a scope object and coordinate their training. 648 The cluster consisting of standard servers needs to be brought up either with 649 the standard server binary or with a binary running distribute coordinator 650 with `task_type` set to non-client type which will then turn into standard 651 servers. 652 653 In addition to be the distribute coordinator, this is also the source of 654 configurations for each job in the distributed training. As there are multiple 655 ways to configure a distributed TensorFlow cluster, its context object 656 provides these configurations so that users or higher-level APIs don't have to 657 figure out the configuration for each job by themselves. 658 659 In the between-graph replicated training, this coordinator will create 660 multiple threads and each calls the `worker_fn` which is supposed to create 661 its own graph and connect to one worker master given by its context object. In 662 the in-graph replicated training, it has only one thread calling this 663 `worker_fn`. 664 665 Another mode is the INDEPENDENT_WORKER mode where each server runs a 666 distribute coordinator which will start a standard server and optionally runs 667 `worker_fn` depending whether it is between-graph training or in-graph 668 replicated training. 669 670 The `strategy` object is expected to be a DistributionStrategy object which 671 has implemented methods needed by distributed coordinator such as 672 `configure(session_config, cluster_spec, task_type, task_id)` which configures 673 the strategy object for a specific task and `experimental_should_init` 674 property which instructs the distribute coordinator whether to run init ops 675 for a task. The distribute coordinator will make a copy of the `strategy` 676 object, call its `configure` method and pass it to `worker_fn` as an argument. 677 678 The `worker_fn` defines the training logic and is called under its own 679 worker context which can be accessed to via `get_current_worker_context`. A 680 worker context provides access to configurations for each task, e.g. the 681 task_type, task_id, master target and so on. Since `worker_fn` will be called 682 in a thread and possibly multiple times, caller should be careful when it 683 accesses global data. For example, it is unsafe to define flags in a 684 `worker_fn` or to define different environment variables for different 685 `worker_fn`s. 686 687 The `worker_fn` for the between-graph replication is defined as if there is 688 only one worker corresponding to the `worker_fn` and possibly ps jobs. For 689 example, when training with parameter servers, it assigns variables to 690 parameter servers and all other operations to that worker. In the in-graph 691 replication case, the `worker_fn` has to define operations for all worker 692 jobs. Using a distribution strategy can simplify the `worker_fn` by not having 693 to worry about the replication and device assignment of variables and 694 operations. 695 696 This method is intended to be invoked by high-level APIs so that users don't 697 have to explicitly call it to run this coordinator. For those who don't use 698 high-level APIs, to change a program to use this coordinator, wrap everything 699 in a the program after global data definitions such as commandline flag 700 definition into the `worker_fn` and get task-specific configurations from 701 the worker context. 702 703 The `cluster_spec` can be either passed by the argument or parsed from the 704 "TF_CONFIG" environment variable. Example of a TF_CONFIG: 705 ``` 706 cluster = {'chief': ['host0:2222'], 707 'ps': ['host1:2222', 'host2:2222'], 708 'worker': ['host3:2222', 'host4:2222', 'host5:2222']} 709 os.environ['TF_CONFIG'] = json.dumps({'cluster': cluster}) 710 ``` 711 712 If `cluster_spec` is not given in any format, it becomes local training and 713 this coordinator will connect to a local session. 714 715 For evaluation, if "evaluator" exists in the cluster_spec, a separate thread 716 will be created to call `eval_fn` with its `task_type` set to "evaluator". If 717 `eval_fn` is not defined, fall back to `worker_fn`. This implies that 718 evaluation will be done on a single machine if there is an "evaluator" task. 719 If "evaluator" doesn't exist in the cluster_spec, it entirely depends on the 720 `worker_fn` for how to do evaluation. 721 722 Args: 723 worker_fn: the function to be called. The function should accept a 724 `strategy` object and will be given access to a context object via a 725 context manager scope. 726 strategy: a DistributionStrategy object specifying whether it should 727 run between-graph replicated training or not, whether to run init ops, 728 etc. This object will also be configured given `session_config`, 729 `cluster_spec`, `task_type` and `task_id`. 730 eval_fn: optional function for "evaluator" task. If `eval_fn` is not passed 731 in but a "evaluator" task is found in the `cluster_spec`, the `worker_fn` 732 will be used for this task. 733 eval_strategy: optional DistributionStrategy object for "evaluator" task. 734 mode: in which mode this distribute coordinator runs. 735 cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and roles 736 in a cluster. If not set or empty, fall back to local training. 737 task_type: the current task type, optional if this is a client. 738 task_id: the current task id, optional if this is a client. 739 session_config: an optional `tf.compat.v1.ConfigProto` object which will be 740 passed to `strategy`'s `configure` method and used to create a session. 741 rpc_layer: optional string, the protocol for RPC, e.g. "grpc". 742 743 Raises: 744 ValueError: if `cluster_spec` is supplied but not a dict or a ClusterDef or 745 a ClusterSpec. 746 747 Returns: 748 In the client job, return the value returned by `worker_fn` if 749 it is in-graph replication or INDEPENDENT_WORKER mode; return None 750 otherwise. 751 """ 752 tf_config = json.loads(os.environ.get("TF_CONFIG", "{}")) 753 rpc_layer = tf_config.get("rpc_layer", rpc_layer) 754 environment = tf_config.get("environment", None) 755 756 if not cluster_spec: 757 cluster_spec = tf_config.get("cluster", {}) 758 task_env = tf_config.get("task", {}) 759 if task_env: 760 task_type = task_env.get("type", task_type) 761 task_id = int(task_env.get("index", task_id)) 762 763 if cluster_spec: 764 # TODO(yuefengz): validate cluster_spec. 765 cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec) 766 elif hasattr(strategy.extended, "_cluster_resolver"): 767 cluster_resolver = strategy.extended._cluster_resolver # pylint: disable=protected-access 768 task_type = cluster_resolver.task_type 769 task_id = cluster_resolver.task_id 770 rpc_layer = cluster_resolver.rpc_layer or rpc_layer 771 environment = cluster_resolver.environment 772 cluster_spec = cluster_resolver.cluster_spec() 773 774 # Setting the session config is necessary for some strategies such as 775 # CollectiveAllReduceStrategy. 776 session_config = session_config or config_pb2.ConfigProto( 777 allow_soft_placement=True) 778 779 if cluster_spec: 780 logging.info( 781 "Running Distribute Coordinator with mode = %r, cluster_spec = %r, " 782 "task_type = %r, task_id = %r, environment = %r, rpc_layer = %r", mode, 783 cluster_spec.as_dict(), task_type, task_id, environment, rpc_layer) 784 785 if not cluster_spec: 786 # `mode` is ignored in the local case. 787 logging.info("Running local Distribute Coordinator.") 788 _run_single_worker(worker_fn, strategy, None, None, None, session_config, 789 rpc_layer) 790 if eval_fn: 791 _run_single_worker(eval_fn, eval_strategy, None, None, None, 792 session_config, rpc_layer) 793 else: 794 logging.warning("Skipped evaluation since `eval_fn` is not passed in.") 795 elif mode == CoordinatorMode.STANDALONE_CLIENT: 796 if not eval_fn: 797 logging.warning("`eval_fn` is not passed in. The `worker_fn` will be " 798 "used if an \"evaluator\" task exists in the cluster.") 799 eval_fn = eval_fn or worker_fn 800 if not eval_strategy: 801 logging.warning("`eval_strategy` is not passed in. No distribution " 802 "strategy will be used for evaluation.") 803 804 # The client must know the cluster but servers in the cluster don't have to 805 # know the client. 806 if task_type in [_TaskType.CLIENT, None]: 807 if strategy.extended.experimental_between_graph: 808 return _run_between_graph_client(worker_fn, strategy, eval_fn, 809 eval_strategy, cluster_spec, 810 session_config, rpc_layer) 811 else: 812 return _run_in_graph_client(worker_fn, strategy, eval_fn, eval_strategy, 813 cluster_spec, session_config, rpc_layer) 814 else: 815 # If not a client job, run the standard server. 816 _configure_session_config_for_std_servers(strategy, eval_strategy, 817 session_config, cluster_spec, 818 task_type, task_id) 819 server = _run_std_server( 820 cluster_spec=cluster_spec, 821 task_type=task_type, 822 task_id=task_id, 823 session_config=session_config, 824 rpc_layer=rpc_layer, 825 environment=environment) 826 server.join() 827 else: 828 if mode != CoordinatorMode.INDEPENDENT_WORKER: 829 raise ValueError("Unexpected coordinator mode: %r" % mode) 830 831 if not eval_fn: 832 logging.warning("`eval_fn` is not passed in. The `worker_fn` will be " 833 "used if an \"evaluator\" task exists in the cluster.") 834 eval_fn = eval_fn or worker_fn 835 if not eval_strategy: 836 logging.warning("`eval_strategy` is not passed in. No distribution " 837 "strategy will be used for evaluation.") 838 839 # Every one starts a standard server, get session config from `configure` 840 # method. 841 _configure_session_config_for_std_servers(strategy, eval_strategy, 842 session_config, cluster_spec, 843 task_type, task_id) 844 845 if (task_type != _TaskType.EVALUATOR and 846 not getattr(strategy.extended, "_std_server_started", False)): 847 # Right now, with eager mode, context is configured with a std server at 848 # the very beginning while with graph mode the std server is started when 849 # distribute coordinator is called. We should consolidate these two paths. 850 server = _run_std_server( 851 cluster_spec=cluster_spec, 852 task_type=task_type, 853 task_id=task_id, 854 session_config=session_config, 855 rpc_layer=rpc_layer, 856 environment=environment) 857 if task_type in [_TaskType.CHIEF, _TaskType.WORKER]: 858 if strategy.extended.experimental_between_graph: 859 # All jobs run `worker_fn` if between-graph. 860 return _run_single_worker(worker_fn, strategy, cluster_spec, task_type, 861 task_id, session_config, rpc_layer) 862 else: 863 # Only one node runs `worker_fn` if in-graph. 864 context = _WorkerContext(strategy, cluster_spec, task_type, task_id) 865 if context.is_chief: 866 return _run_single_worker(worker_fn, strategy, cluster_spec, None, 867 None, session_config, rpc_layer) 868 else: 869 server.join() 870 elif task_type == _TaskType.EVALUATOR: 871 return _run_single_worker(eval_fn, eval_strategy, cluster_spec, task_type, 872 task_id, session_config, rpc_layer) 873 else: 874 if task_type != _TaskType.PS: 875 raise ValueError("Unexpected task_type: %r" % task_type) 876 server.join() 877