# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cluster Resolvers are used for dynamic cluster IP/hostname resolution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import collections
import re
import six

from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.training.server_lib import ClusterSpec
from tensorflow.python.util.tf_export import tf_export


DEVICE_TYPE_REGEX = re.compile('.*device:([^:]+).*')


def format_master_url(master, rpc_layer=None):
  if rpc_layer:
    return '%s://%s' % (rpc_layer, master)
  else:
    return master


def get_accelerator_devices(master, config_proto):
  """Returns accelerator devices given a master and a configuration."""
  if context.executing_eagerly():
    device_names = context.list_devices()  # list_devices returns list(string)
    devices = []
    for name in device_names:
      device_type = 'GPU'  # default device type is GPU
      device_match = DEVICE_TYPE_REGEX.match(name)
      if device_match:
        device_type = device_match.group(1)
      if device_type == 'CPU' or device_type == 'XLA_CPU':  # Filter CPUs
        continue
      devices.append(session._DeviceAttributes(name, device_type, 0, 0))  # pylint: disable=protected-access
    return devices
  else:
    with ops.Graph().as_default():
      with session.Session(master, config=config_proto) as s:
        devices = s.list_devices()
    return devices


@tf_export('distribute.cluster_resolver.ClusterResolver')
@six.add_metaclass(abc.ABCMeta)
class ClusterResolver(object):
  """Abstract class for all implementations of ClusterResolvers.

  This defines the skeleton for all implementations of ClusterResolvers.
  ClusterResolvers are a way for TensorFlow to communicate with various cluster
  management systems (e.g. GCE, AWS, etc.).

  By letting TensorFlow communicate with these systems, we will be able to
  automatically discover and resolve IP addresses for various TensorFlow
  workers. This will eventually allow us to automatically recover from
  underlying machine failures and scale TensorFlow worker clusters up and down.

  Note to Implementors: In addition to these abstract methods, you must also
  implement the task_type, task_id, and rpc_layer attributes. You may choose
  to implement them either as properties with getters and setters, or by
  setting the attributes directly.

  - task_type is the name of the server's current named job (e.g. 'worker' or
    'ps' in a distributed training job with parameter servers).
  - task_id is the ordinal index of the server within the task type.
  - rpc_layer is the protocol used by TensorFlow to communicate with other
    TensorFlow servers in a distributed environment.
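
  As a minimal sketch only (the class name and addresses below are purely
  illustrative, not part of TensorFlow), a concrete implementation that always
  resolves to a single hard-coded worker might look like:

    class StaticClusterResolver(ClusterResolver):

      def __init__(self):
        # Directly set the required attributes.
        self.task_type = 'worker'
        self.task_id = 0
        self.rpc_layer = 'grpc'

      def cluster_spec(self):
        # Reconstruct the (static) cluster on every call.
        return ClusterSpec({'worker': ['localhost:2222']})

      def master(self, task_type=None, task_id=None, rpc_layer=None):
        return format_master_url('localhost:2222',
                                 rpc_layer or self.rpc_layer)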
89 """ 90 91 @abc.abstractmethod 92 def cluster_spec(self): 93 """Retrieve the current state of the cluster and returns a ClusterSpec. 94 95 Returns: 96 A ClusterSpec representing the state of the cluster at the moment this 97 function is called. 98 99 Implementors of this function must take care in ensuring that the 100 ClusterSpec returned is up-to-date at the time of calling this function. 101 This usually means retrieving the information from the underlying cluster 102 management system every time this function is invoked and reconstructing 103 a cluster_spec, rather than attempting to cache anything. 104 """ 105 raise NotImplementedError() 106 107 @abc.abstractmethod 108 def master(self, task_type=None, task_id=None, rpc_layer=None): 109 """Retrieves the name or URL of the session master. 110 111 Args: 112 task_type: (Optional) The type of the TensorFlow task of the master. 113 task_id: (Optional) The index of the TensorFlow task of the master. 114 rpc_layer: (Optional) The RPC protocol for the given cluster. 115 116 Returns: 117 The name or URL of the session master. 118 119 Implementors of this function must take care in ensuring that the master 120 returned is up-to-date at the time to calling this function. This usually 121 means retrieving the master every time this function is invoked. 122 """ 123 raise NotImplementedError() 124 125 def num_accelerators(self, 126 task_type=None, 127 task_id=None, 128 config_proto=None): 129 """Returns the number of accelerator cores per worker. 130 131 This returns the number of accelerator cores (such as GPUs and TPUs) 132 available per worker. 133 134 Optionally, we allow callers to specify the task_type, and task_id, for 135 if they want to target a specific TensorFlow process to query 136 the number of accelerators. This is to support heterogenous environments, 137 where the number of accelerators cores per host is different. 138 139 Args: 140 task_type: (Optional) The type of the TensorFlow task of the machine we 141 want to query. 142 task_id: (Optional) The index of the TensorFlow task of the machine we 143 want to query. 144 config_proto: (Optional) Configuration for starting a new session to 145 query how many accelerator cores it has. 146 147 Returns: 148 A map of accelerator types to number of cores. 149 """ 150 master = self.master(task_type, task_id) 151 devices = get_accelerator_devices(master, config_proto) 152 mapping = collections.defaultdict(int) 153 for device in devices: 154 if task_type is not None and task_id is not None: 155 job_path = '/job:%s' % task_type 156 task_path = '/task:%s' % task_id 157 if job_path not in device.name or task_path not in device.name: 158 continue 159 mapping[device.device_type] += 1 160 return mapping 161 162 @property 163 def environment(self): 164 """Returns the current environment which TensorFlow is running in. 165 166 There are two possible return values, "google" (when TensorFlow is running 167 in a Google-internal environment) or an empty string (when TensorFlow is 168 running elsewhere). 169 170 If you are implementing a ClusterResolver that works in both the Google 171 environment and the open-source world (for instance, a TPU ClusterResolver 172 or similar), you will have to return the appropriate string depending on the 173 environment, which you will have to detect. 174 175 Otherwise, if you are implementing a ClusterResolver that will only work 176 in open-source TensorFlow, you do not need to implement this property. 
177 """ 178 return '' 179 180 181@tf_export('distribute.cluster_resolver.SimpleClusterResolver') 182class SimpleClusterResolver(ClusterResolver): 183 """Simple implementation of ClusterResolver that accepts a ClusterSpec.""" 184 185 def __init__(self, cluster_spec, master='', task_type=None, task_id=None, 186 environment='', num_accelerators=None, 187 rpc_layer=None): 188 """Creates a SimpleClusterResolver from a ClusterSpec.""" 189 super(SimpleClusterResolver, self).__init__() 190 191 self._task_type = task_type 192 self._task_id = task_id 193 self._environment = environment 194 195 self._num_accelerators = num_accelerators 196 self._rpc_layer = rpc_layer 197 198 if not isinstance(cluster_spec, ClusterSpec): 199 raise TypeError('cluster_spec must be a ClusterSpec.') 200 self._cluster_spec = cluster_spec 201 202 if not isinstance(master, str): 203 raise TypeError('master must be a string.') 204 self._master = master 205 206 def cluster_spec(self): 207 """Returns the ClusterSpec passed into the constructor.""" 208 return self._cluster_spec 209 210 def master(self, task_type=None, task_id=None, rpc_layer=None): 211 """Returns the master address to use when creating a session. 212 213 Args: 214 task_type: (Optional) The type of the TensorFlow task of the master. 215 task_id: (Optional) The index of the TensorFlow task of the master. 216 rpc_layer: (Optional) The RPC used by distributed TensorFlow. 217 218 Returns: 219 The name or URL of the session master. 220 221 If a task_type and task_id is given, this will override the `master` 222 string passed into the initialization function. 223 """ 224 if task_type is not None and task_id is not None: 225 master = self.cluster_spec().task_address(task_type, task_id) 226 else: 227 master = self._master 228 229 return format_master_url(master, rpc_layer=rpc_layer or self._rpc_layer) 230 231 @property 232 def task_type(self): 233 return self._task_type 234 235 @property 236 def task_id(self): 237 return self._task_id 238 239 @task_type.setter 240 def task_type(self, task_type): 241 self._task_type = task_type 242 243 @task_id.setter 244 def task_id(self, task_id): 245 self._task_id = task_id 246 247 @property 248 def environment(self): 249 return self._environment 250 251 def num_accelerators(self, 252 task_type=None, 253 task_id=None, 254 config_proto=None): 255 """Returns the number of accelerator cores per worker. 256 257 The SimpleClusterResolver does not do automatic detection of accelerators, 258 so a TensorFlow session will never be created, and thus all arguments are 259 unused and we simply assume that the type of accelerator is a GPU and return 260 the value in provided to us in the constructor. 261 262 Args: 263 task_type: Unused. 264 task_id: Unused. 265 config_proto: Unused. 266 """ 267 # Unused 268 del task_type, task_id, config_proto 269 if self._num_accelerators is None: 270 return {} 271 return self._num_accelerators 272 273 @property 274 def rpc_layer(self): 275 return self._rpc_layer 276 277 @rpc_layer.setter 278 def rpc_layer(self, rpc_layer): 279 self._rpc_layer = rpc_layer 280 281 282@tf_export('distribute.cluster_resolver.UnionResolver') 283class UnionClusterResolver(ClusterResolver): 284 """Performs a union on underlying ClusterResolvers. 285 286 This class performs a union given two or more existing ClusterResolvers. It 287 merges the underlying ClusterResolvers, and returns one unified ClusterSpec 288 when cluster_spec is called. The details of the merge function is 289 documented in the cluster_spec function. 

  For additional Cluster Resolver properties such as task type, task index,
  rpc layer, environment, etc., we will return the value from the first
  ClusterResolver in the union.
  """

  def __init__(self, *args, **kwargs):
    """Initializes a UnionClusterResolver with other ClusterResolvers.

    Args:
      *args: `ClusterResolver` objects to be unionized.
      **kwargs:
        rpc_layer - (Optional) Override value for the RPC layer used by
          TensorFlow.
        task_type - (Optional) Override value for the current task type.
        task_id - (Optional) Override value for the current task index.

    Raises:
      TypeError: If any argument is not a subclass of `ClusterResolver`.
      ValueError: If no arguments are passed.
    """
    super(UnionClusterResolver, self).__init__()

    self._rpc_layer = kwargs.pop('rpc_layer', None)
    self._task_type = kwargs.pop('task_type', None)
    self._task_id = kwargs.pop('task_id', None)

    if kwargs:
      raise ValueError('Unexpected kwargs provided {!r}'.format(kwargs))

    if not args:
      raise ValueError('At least one ClusterResolver is required.')

    for cluster_resolver in args:
      if not isinstance(cluster_resolver, ClusterResolver):
        raise TypeError('All arguments must be a sub-class of '
                        '`ClusterResolver`.')
    self._cluster_resolvers = args

  def cluster_spec(self):
    """Returns a union of all the ClusterSpecs from the ClusterResolvers.

    Returns:
      A ClusterSpec containing host information merged from all the underlying
      ClusterResolvers.

    Raises:
      KeyError: If there are conflicting keys detected when merging two or
        more dictionaries.

    Note: If there are multiple ClusterResolvers exposing ClusterSpecs with the
    same job name, we will merge the list/dict of workers.

    If *all* underlying ClusterSpecs expose the set of workers as lists, we
    will concatenate the lists of workers, starting with the list of workers
    from the first ClusterResolver passed into the constructor.

    If *any* of the ClusterSpecs expose the set of workers as a dict, we will
    treat all the sets of workers as dicts (even if they are returned as
    lists) and will only merge them into a dict if there are no conflicting
    keys. If there is a conflicting key, we will raise a `KeyError`.
    """

    merged_cluster = {}

    # We figure out whether it is all lists for a particular job, or whether
    # there are dicts inside.
    for cluster_resolver in self._cluster_resolvers:
      cluster_spec = cluster_resolver.cluster_spec()
      cluster_dict = cluster_spec.as_dict()

      for job_name, tasks in cluster_dict.items():
        if job_name in merged_cluster:
          # If we see a dict, then we write a dict out regardless.
          if isinstance(tasks, dict):
            merged_cluster[job_name] = {}
        else:
          # We take whichever type is present.
          if isinstance(tasks, list):
            merged_cluster[job_name] = []
          else:
            merged_cluster[job_name] = {}

    # We then do the merge as appropriate in merged_cluster[job].
    for cluster_resolver in self._cluster_resolvers:
      cluster_spec = cluster_resolver.cluster_spec()
      cluster_dict = cluster_spec.as_dict()

      for job_name, tasks in cluster_dict.items():
        if isinstance(merged_cluster[job_name], list):
          # We all have lists, so we can just concatenate and be done.
          merged_cluster[job_name].extend(tasks)
        else:
          if isinstance(tasks, list):
            # We convert to a dictionary if the type is a list.
            task_dict = dict(zip(range(0, len(tasks)), tasks))
          else:
            # We can simply make a copy (for update) and be done.
            task_dict = tasks.copy()

          # We detect if there are duplicates, and raise an error if so.
          task_keys = set(task_dict)
          merged_keys = set(merged_cluster[job_name].keys())
          intersected_keys = task_keys.intersection(merged_keys)
          if intersected_keys:
            raise KeyError('Duplicate keys detected when merging two '
                           'ClusterSpecs: %s' % repr(intersected_keys))

          # We do the merge after all the processing.
          merged_cluster[job_name].update(task_dict)

    return ClusterSpec(merged_cluster)

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Returns the master address to use when creating a session.

    This usually returns the master from the first ClusterResolver passed in,
    but you can override this by specifying the task_type and task_id.

    Args:
      task_type: (Optional) The type of the TensorFlow task of the master.
      task_id: (Optional) The index of the TensorFlow task of the master.
      rpc_layer: (Optional) The RPC protocol for the given cluster.

    Returns:
      The name or URL of the session master.
    """
    if task_type is not None and task_id is not None:
      master = self.cluster_spec().task_address(task_type, task_id)
      return format_master_url(master, rpc_layer or self._rpc_layer)

    return self._cluster_resolvers[0].master(rpc_layer=rpc_layer)

  @property
  def task_type(self):
    return self._task_type or self._cluster_resolvers[0].task_type

  @property
  def task_id(self):
    return self._task_id or self._cluster_resolvers[0].task_id

  @task_type.setter
  def task_type(self, task_type):
    self._task_type = task_type

  @task_id.setter
  def task_id(self, task_id):
    self._task_id = task_id

  @property
  def environment(self):
    return self._cluster_resolvers[0].environment

  def num_accelerators(self,
                       task_type=None,
                       task_id=None,
                       config_proto=None):
    return self._cluster_resolvers[0].num_accelerators(
        task_type, task_id, config_proto)

  @property
  def rpc_layer(self):
    return self._rpc_layer or self._cluster_resolvers[0].rpc_layer

  @rpc_layer.setter
  def rpc_layer(self, rpc_layer):
    self._rpc_layer = rpc_layer