# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cluster Resolvers are used for dynamic cluster IP/hostname resolution."""

import abc

import collections

import six

from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import ops
from tensorflow.python.training.server_lib import ClusterSpec
from tensorflow.python.util.tf_export import tf_export


def format_master_url(master, rpc_layer=None):
  if rpc_layer:
    return '%s://%s' % (rpc_layer, master)
  else:
    return master


def get_accelerator_devices(master, config_proto):
  """Returns accelerator devices given a master and a configuration."""
  if context.executing_eagerly():
    logical_devices = config.list_logical_devices()
    devices = []
    for d in logical_devices:
      if d.device_type == 'CPU' or d.device_type == 'XLA_CPU':  # Filter CPUs
        continue
      devices.append(session._DeviceAttributes(d.name, d.device_type, 0, 0))  # pylint: disable=protected-access
    return devices
  else:
    with ops.Graph().as_default():
      with session.Session(master, config=config_proto) as s:
        devices = s.list_devices()
    return devices


@tf_export('distribute.cluster_resolver.ClusterResolver')
@six.add_metaclass(abc.ABCMeta)
class ClusterResolver(object):
  """Abstract class for all implementations of ClusterResolvers.

  This defines the skeleton for all implementations of ClusterResolvers.
  ClusterResolvers are a way for TensorFlow to communicate with various cluster
  management systems (e.g. GCE, AWS, etc...) and give TensorFlow the necessary
  information to set up distributed training.

  By letting TensorFlow communicate with these systems, we will be able to
  automatically discover and resolve IP addresses for various TensorFlow
  workers. This will eventually allow us to automatically recover from
  underlying machine failures and scale TensorFlow worker clusters up and down.

  Note to implementors of `tf.distribute.cluster_resolver.ClusterResolver`
  subclasses: In addition to these abstract methods, when task_type, task_id,
  and rpc_layer attributes are applicable, you should also implement them
  either as properties with getters or setters, or directly set the attributes
  `self._task_type`, `self._task_id`, or `self._rpc_layer` so the base class'
  getters and setters are used. See
  `tf.distribute.cluster_resolver.SimpleClusterResolver.__init__` for an
  example.
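
  As an illustrative sketch (with a hypothetical `FixedClusterResolver`, not
  part of the public API), such a resolver could set those attributes in its
  constructor and rebuild its cluster spec on every call:

  ```python
  class FixedClusterResolver(tf.distribute.cluster_resolver.ClusterResolver):
    # Hypothetical resolver for a fixed single-worker cluster.

    def __init__(self):
      # Setting these attributes lets the base class' default `task_type` and
      # `task_id` getters and setters work without further overrides.
      self._task_type = 'worker'
      self._task_id = 0
      self._rpc_layer = 'grpc'

    def cluster_spec(self):
      # Rebuilt on every call rather than cached.
      return tf.train.ClusterSpec({'worker': ['localhost:2222']})

    def master(self, task_type=None, task_id=None, rpc_layer=None):
      if task_type is not None and task_id is not None:
        master = self.cluster_spec().task_address(task_type, task_id)
      else:
        master = 'localhost:2222'
      return '%s://%s' % (rpc_layer or self._rpc_layer, master)
  ```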

  In general, multi-client tf.distribute strategies such as
  `tf.distribute.experimental.MultiWorkerMirroredStrategy` require task_type
  and task_id properties to be available in the `ClusterResolver` they are
  using. On the other hand, these concepts are not applicable in single-client
  strategies, such as `tf.distribute.experimental.TPUStrategy`, because the
  program is only expected to be run on one task, so there should not be a
  need to have code branches according to task type and task id.

  - task_type is the name of the server's current named job (e.g. 'worker',
    'ps' in a distributed parameterized training job).
  - task_id is the ordinal index of the server within the task type.
  - rpc_layer is the protocol used by TensorFlow to communicate with other
    TensorFlow servers in a distributed environment.
  """

  @abc.abstractmethod
  def cluster_spec(self):
    """Retrieve the current state of the cluster and return a `tf.train.ClusterSpec`.

    Returns:
      A `tf.train.ClusterSpec` representing the state of the cluster at the
      moment this function is called.

    Implementors of this function must take care in ensuring that the
    ClusterSpec returned is up-to-date at the time of calling this function.
    This usually means retrieving the information from the underlying cluster
    management system every time this function is invoked and reconstructing
    a cluster_spec, rather than attempting to cache anything.
    """
    raise NotImplementedError()

  @abc.abstractmethod
  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Retrieves the name or URL of the session master.

    Note: this is only useful for TensorFlow 1.x.

    Args:
      task_type: (Optional) The type of the TensorFlow task of the master.
      task_id: (Optional) The index of the TensorFlow task of the master.
      rpc_layer: (Optional) The RPC protocol for the given cluster.

    Returns:
      The name or URL of the session master.

    Implementors of this function must take care in ensuring that the master
    returned is up-to-date at the time of calling this function. This usually
    means retrieving the master every time this function is invoked.
    """
    raise NotImplementedError()

  def num_accelerators(self,
                       task_type=None,
                       task_id=None,
                       config_proto=None):
    """Returns the number of accelerator cores per worker.

    This returns the number of accelerator cores (such as GPUs and TPUs)
    available per worker.

    Optionally, callers can specify the task_type and task_id to target a
    specific TensorFlow task when querying the number of accelerators. This is
    to support heterogeneous environments, where the number of accelerator
    cores per host is different.
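
    For example, a resolver pointed at a worker host that happens to have two
    GPUs attached might report (hypothetical device counts, for illustration
    only):

    ```python
    resolver.num_accelerators(task_type='worker', task_id=0)
    # ==> {'GPU': 2}
    ```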

    Args:
      task_type: (Optional) The type of the TensorFlow task of the machine we
        want to query.
      task_id: (Optional) The index of the TensorFlow task of the machine we
        want to query.
      config_proto: (Optional) Configuration for starting a new session to
        query how many accelerator cores it has.

    Returns:
      A map of accelerator types to number of cores.
    """
    master = self.master(task_type, task_id)
    # TODO(b/126786766): in eager mode, we should check whether
    # `tf.config.experimental_connect_to_cluster` is called or not.
    devices = get_accelerator_devices(master, config_proto)
    mapping = collections.defaultdict(int)
    for device in devices:
      if task_type is not None and task_id is not None:
        job_path = '/job:%s' % task_type
        task_path = '/task:%s' % task_id
        if job_path not in device.name or task_path not in device.name:
          continue
      mapping[device.device_type] += 1
    return mapping

  @property
  def environment(self):
    """Returns the current environment which TensorFlow is running in.

    There are two possible return values, "google" (when TensorFlow is running
    in a Google-internal environment) or an empty string (when TensorFlow is
    running elsewhere).

    If you are implementing a ClusterResolver that works in both the Google
    environment and the open-source world (for instance, a TPU ClusterResolver
    or similar), you will have to return the appropriate string depending on
    the environment, which you will have to detect.

    Otherwise, if you are implementing a ClusterResolver that will only work
    in open-source TensorFlow, you do not need to implement this property.
    """
    return ''

  @property
  def task_type(self):
    """Returns the task type this `ClusterResolver` indicates.

    In a TensorFlow distributed environment, each job may have an applicable
    task type. Valid task types in TensorFlow include
    'chief': a worker that is designated with more responsibility,
    'worker': a regular worker for training/evaluation,
    'ps': a parameter server, or
    'evaluator': an evaluator that evaluates the checkpoints for metrics.

    See [Multi-worker configuration](
    https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras#multi-worker_configuration)
    for more information about the 'chief' and 'worker' task types, which are
    the most commonly used.

    Having access to such information is useful when a user needs to run
    specific code according to task types. For example,

    ```python
    cluster_spec = tf.train.ClusterSpec({
        "ps": ["localhost:2222", "localhost:2223"],
        "worker": ["localhost:2224", "localhost:2225", "localhost:2226"]
    })

    # SimpleClusterResolver is used here for illustration; other cluster
    # resolvers may be used for other sources of task type/id.
    simple_resolver = SimpleClusterResolver(cluster_spec, task_type="worker",
                                            task_id=1)

    ...

    if simple_resolver.task_type == 'worker':
      # Perform something that's only applicable on workers. This block
      # will run on this particular instance since we've specified this task
      # to be a worker in the above cluster resolver.
    elif simple_resolver.task_type == 'ps':
      # Perform something that's only applicable on parameter servers. This
      # block will not run on this particular instance.
    ```

    Returns `None` if such information is not available or is not applicable
    in the current distributed environment, such as training with
    `tf.distribute.experimental.TPUStrategy`.

    For more information, please see
    `tf.distribute.cluster_resolver.ClusterResolver`'s class doc.
    """
    return getattr(self, '_task_type', None)

  @property
  def task_id(self):
    """Returns the task id this `ClusterResolver` indicates.

    In a TensorFlow distributed environment, each job may have an applicable
    task id, which is the index of the instance within its task type. This is
    useful when a user needs to run specific code according to the task index.
    For example,

    ```python
    cluster_spec = tf.train.ClusterSpec({
        "ps": ["localhost:2222", "localhost:2223"],
        "worker": ["localhost:2224", "localhost:2225", "localhost:2226"]
    })

    # SimpleClusterResolver is used here for illustration; other cluster
    # resolvers may be used for other sources of task type/id.
    simple_resolver = SimpleClusterResolver(cluster_spec, task_type="worker",
                                            task_id=0)

    ...

    if simple_resolver.task_type == 'worker' and simple_resolver.task_id == 0:
      # Perform something that's only applicable on 'worker' type, id 0. This
      # block will run on this particular instance since we've specified this
      # task to be a 'worker', id 0 in the above cluster resolver.
    else:
      # Perform something that's only applicable on other ids. This block will
      # not run on this particular instance.
    ```

    Returns `None` if such information is not available or is not applicable
    in the current distributed environment, such as training with
    `tf.distribute.cluster_resolver.TPUClusterResolver`.

    For more information, please see
    `tf.distribute.cluster_resolver.ClusterResolver`'s class docstring.
    """
    return getattr(self, '_task_id', None)

  @task_type.setter
  def task_type(self, task_type):
    """Setter of `task_type` property. See `task_type` property doc."""
    self._task_type = task_type

  @task_id.setter
  def task_id(self, task_id):
    """Setter of `task_id` property. See `task_id` property doc."""
    self._task_id = task_id


@tf_export('distribute.cluster_resolver.SimpleClusterResolver')
class SimpleClusterResolver(ClusterResolver):
  """Simple implementation of ClusterResolver that accepts all attributes.

  Please see the base class for documentation of arguments of its constructor.

  It is useful if you want to specify some or all attributes.

  Usage example with `tf.distribute.Strategy`:

  ```Python
  cluster = tf.train.ClusterSpec({"worker": ["worker0.example.com:2222",
                                             "worker1.example.com:2222"]})

  # On worker 0
  cluster_resolver = SimpleClusterResolver(cluster, task_type="worker",
                                           task_id=0,
                                           num_accelerators={"GPU": 8},
                                           rpc_layer="grpc")
  strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
      cluster_resolver=cluster_resolver)

  # On worker 1
  cluster_resolver = SimpleClusterResolver(cluster, task_type="worker",
                                           task_id=1,
                                           num_accelerators={"GPU": 8},
                                           rpc_layer="grpc")
  strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy(
      cluster_resolver=cluster_resolver)
  ```
  """

  def __init__(self, cluster_spec, master='', task_type=None, task_id=None,
               environment='', num_accelerators=None,
               rpc_layer=None):
    """Creates a SimpleClusterResolver from a ClusterSpec."""
    super(SimpleClusterResolver, self).__init__()

    self._task_type = task_type
    self._task_id = task_id
    self._environment = environment

    self._num_accelerators = num_accelerators
    self._rpc_layer = rpc_layer

    if not isinstance(cluster_spec, ClusterSpec):
      raise TypeError('cluster_spec must be a `tf.train.ClusterSpec`.')
    self._cluster_spec = cluster_spec

    if not isinstance(master, str):
      raise TypeError('master must be a string.')
    self._master = master

  def cluster_spec(self):
    """Returns the ClusterSpec passed into the constructor."""
    return self._cluster_spec

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Returns the master address to use when creating a session.

    Note: this is only useful for TensorFlow 1.x.

    Args:
      task_type: (Optional) The type of the TensorFlow task of the master.
      task_id: (Optional) The index of the TensorFlow task of the master.
      rpc_layer: (Optional) The RPC used by distributed TensorFlow.

    Returns:
      The name or URL of the session master.

    If a task_type and task_id are given, they will override the `master`
    string passed into the initialization function.
    """
    if task_type is not None and task_id is not None:
      master = self.cluster_spec().task_address(task_type, task_id)
    else:
      master = self._master

    return format_master_url(master, rpc_layer=rpc_layer or self._rpc_layer)

  @property
  def task_type(self):
    return self._task_type

  @property
  def task_id(self):
    return self._task_id

  @task_type.setter
  def task_type(self, task_type):
    self._task_type = task_type

  @task_id.setter
  def task_id(self, task_id):
    self._task_id = task_id

  @property
  def environment(self):
    return self._environment

  def num_accelerators(self,
                       task_type=None,
                       task_id=None,
                       config_proto=None):
    """Returns the number of accelerator cores per worker.

    The SimpleClusterResolver does not do automatic detection of accelerators,
    so all arguments are unused and we simply return the value provided in the
    constructor.

    Args:
      task_type: Unused.
      task_id: Unused.
      config_proto: Unused.

    Returns:
      A map of accelerator types to number of cores, as provided in the
      constructor, or an empty dict if no value was provided.
    """
    # Unused
    del task_type, task_id, config_proto
    if self._num_accelerators is None:
      return {}
    return self._num_accelerators

  @property
  def rpc_layer(self):
    return self._rpc_layer

  @rpc_layer.setter
  def rpc_layer(self, rpc_layer):
    self._rpc_layer = rpc_layer


@tf_export('distribute.cluster_resolver.UnionResolver')
class UnionClusterResolver(ClusterResolver):
  """Performs a union on underlying ClusterResolvers.

  This class performs a union given two or more existing ClusterResolvers. It
  merges the underlying ClusterResolvers, and returns one unified ClusterSpec
  when cluster_spec is called. The details of the merge function are
  documented in the cluster_spec function.

  For additional ClusterResolver properties such as task type, task index,
  rpc layer, environment, etc..., we will return the value from the first
  ClusterResolver in the union.

  An example to combine two cluster resolvers:

  ```Python
  cluster_0 = tf.train.ClusterSpec({"worker": ["worker0.example.com:2222",
                                               "worker1.example.com:2222"]})
  cluster_resolver_0 = SimpleClusterResolver(cluster_0, task_type="worker",
                                             task_id=0,
                                             rpc_layer="grpc")

  cluster_1 = tf.train.ClusterSpec({"ps": ["ps0.example.com:2222",
                                           "ps1.example.com:2222"]})
  cluster_resolver_1 = SimpleClusterResolver(cluster_1, task_type="ps",
                                             task_id=0,
                                             rpc_layer="grpc")

  # Its task type would be "worker".
  cluster_resolver = UnionClusterResolver(cluster_resolver_0,
                                          cluster_resolver_1)
  ```

  An example to override the number of GPUs in a TFConfigClusterResolver
  instance:

  ```Python
  tf_config = TFConfigClusterResolver()
  gpu_override = SimpleClusterResolver(tf_config.cluster_spec(),
                                       num_accelerators={"GPU": 1})
  cluster_resolver = UnionClusterResolver(gpu_override, tf_config)
  ```
  """

  def __init__(self, *args, **kwargs):
    """Initializes a UnionClusterResolver with other ClusterResolvers.

    Args:
      *args: `ClusterResolver` objects to be unionized.
      **kwargs:
        rpc_layer - (Optional) Override value for the RPC layer used by
          TensorFlow.
        task_type - (Optional) Override value for the current task type.
        task_id - (Optional) Override value for the current task index.

    Raises:
      TypeError: If any argument is not an instance of `ClusterResolver`.
      ValueError: If there are no arguments passed.
    """
    super(UnionClusterResolver, self).__init__()

    self._rpc_layer = kwargs.pop('rpc_layer', None)
    self._task_type = kwargs.pop('task_type', None)
    self._task_id = kwargs.pop('task_id', None)

    if kwargs:
      raise ValueError('Unexpected kwargs provided {!r}'.format(kwargs))

    if not args:
      raise ValueError('At least one ClusterResolver is required.')

    for cluster_resolver in args:
      if not isinstance(cluster_resolver, ClusterResolver):
        raise TypeError('All arguments must be a sub-class of '
                        '`ClusterResolver`.')
    self._cluster_resolvers = args

  def cluster_spec(self):
    """Returns a union of all the ClusterSpecs from the ClusterResolvers.

    Returns:
      A ClusterSpec containing host information merged from all the underlying
      ClusterResolvers.

    Raises:
      KeyError: If there are conflicting keys detected when merging two or
        more dictionaries, this exception is raised.

    Note: If there are multiple ClusterResolvers exposing ClusterSpecs with
    the same job name, we will merge the list/dict of workers.

    If *all* underlying ClusterSpecs expose the set of workers as lists, we
    will concatenate the lists of workers, starting with the list of workers
    from the first ClusterResolver passed into the constructor.

    If *any* of the ClusterSpecs expose the set of workers as a dict, we will
    treat all the sets of workers as dicts (even if they are returned as
    lists) and will only merge them into a dict if there are no conflicting
    keys. If there is a conflicting key, we will raise a `KeyError`.
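
    For illustration, merging two list-based ClusterSpecs (hypothetical hosts)
    concatenates tasks that share a job name and keeps the remaining jobs
    as-is:

    ```python
    resolver_a = SimpleClusterResolver(tf.train.ClusterSpec(
        {"worker": ["worker0.example.com:2222"]}))
    resolver_b = SimpleClusterResolver(tf.train.ClusterSpec(
        {"worker": ["worker1.example.com:2222"],
         "ps": ["ps0.example.com:2222"]}))

    union = UnionClusterResolver(resolver_a, resolver_b)
    union.cluster_spec().as_dict()
    # ==> {'worker': ['worker0.example.com:2222', 'worker1.example.com:2222'],
    #      'ps': ['ps0.example.com:2222']}
    ```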
    """

    merged_cluster = {}

    # We figure out whether it is all lists for a particular job, or whether
    # there are dicts inside.
    for cluster_resolver in self._cluster_resolvers:
      cluster_spec = cluster_resolver.cluster_spec()
      cluster_dict = cluster_spec.as_dict()

      for job_name, tasks in cluster_dict.items():
        if job_name in merged_cluster:
          # If we see a dict, then we write a dict out regardless.
          if isinstance(tasks, dict):
            merged_cluster[job_name] = {}
        else:
          # We take whichever type is present.
          if isinstance(tasks, list):
            merged_cluster[job_name] = []
          else:
            merged_cluster[job_name] = {}

    # We then do the merge as appropriate in merged_cluster[job].
    for cluster_resolver in self._cluster_resolvers:
      cluster_spec = cluster_resolver.cluster_spec()
      cluster_dict = cluster_spec.as_dict()

      for job_name, tasks in cluster_dict.items():
        if isinstance(merged_cluster[job_name], list):
          # All resolvers expose lists for this job, so we can just
          # concatenate and be done.
          merged_cluster[job_name].extend(tasks)
        else:
          if isinstance(tasks, list):
            # We convert to a dictionary if the type is a list.
            task_dict = dict(zip(range(0, len(tasks)), tasks))
          else:
            # We can simply make a copy (for update) and be done.
            task_dict = tasks.copy()

          # We detect if there are duplicates, and raise an error if so.
          task_keys = set(task_dict)
          merged_keys = set(merged_cluster[job_name].keys())
          intersected_keys = task_keys.intersection(merged_keys)
          if intersected_keys:
            raise KeyError('Duplicate keys detected when merging two '
                           'ClusterSpecs: %s' % repr(intersected_keys))

          # We do the merge after all the processing.
          merged_cluster[job_name].update(task_dict)

    return ClusterSpec(merged_cluster)

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Returns the master address to use when creating a session.

    This usually returns the master from the first ClusterResolver passed in,
    but you can override this by specifying the task_type and task_id.

    Note: this is only useful for TensorFlow 1.x.

    Args:
      task_type: (Optional) The type of the TensorFlow task of the master.
      task_id: (Optional) The index of the TensorFlow task of the master.
      rpc_layer: (Optional) The RPC protocol for the given cluster.

    Returns:
      The name or URL of the session master.
    """
    if task_type is not None and task_id is not None:
      master = self.cluster_spec().task_address(task_type, task_id)
      return format_master_url(master, rpc_layer or self._rpc_layer)

    return self._cluster_resolvers[0].master(rpc_layer=rpc_layer)

  @property
  def task_type(self):
    return self._task_type or self._cluster_resolvers[0].task_type

  @property
  def task_id(self):
    return self._task_id or self._cluster_resolvers[0].task_id

  @task_type.setter
  def task_type(self, task_type):
    self._task_type = task_type

  @task_id.setter
  def task_id(self, task_id):
    self._task_id = task_id

  @property
  def environment(self):
    return self._cluster_resolvers[0].environment

  def num_accelerators(self,
                       task_type=None,
                       task_id=None,
                       config_proto=None):
    return self._cluster_resolvers[0].num_accelerators(
        task_type, task_id, config_proto)

  @property
  def rpc_layer(self):
    return self._rpc_layer or self._cluster_resolvers[0].rpc_layer

  @rpc_layer.setter
  def rpc_layer(self, rpc_layer):
    self._rpc_layer = rpc_layer