1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Implementation of Cluster Resolvers for Cloud TPUs.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import os 23import re 24 25from six.moves import urllib 26from six.moves.urllib.error import URLError 27from six.moves.urllib.request import Request 28from six.moves.urllib.request import urlopen 29 30from tensorflow.python.distribute.cluster_resolver.cluster_resolver import ClusterResolver 31from tensorflow.python.distribute.cluster_resolver.cluster_resolver import format_master_url 32from tensorflow.python.distribute.cluster_resolver.cluster_resolver import get_accelerator_devices 33from tensorflow.python.framework import errors 34from tensorflow.python.platform import tf_logging as logging 35from tensorflow.python.training import server_lib 36from tensorflow.python.util import compat 37from tensorflow.python.util.tf_export import tf_export 38 39_GOOGLE_API_CLIENT_INSTALLED = True 40try: 41 from googleapiclient import discovery # pylint: disable=g-import-not-at-top 42 from oauth2client.client import GoogleCredentials # pylint: disable=g-import-not-at-top 43except ImportError: 44 _GOOGLE_API_CLIENT_INSTALLED = False 45 46_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' 47_ENDPOINTS_SEPARATOR = ',' 48_DEFAULT_ENV_VARIABLE = 'TPU_NAME' 49_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL' 50 51_TPU_DEVICE_REGEX = re.compile( 52 r'.*task:(?P<host_id>\d+)/.*device:TPU:(?P<core_id>\d+)$') 53_TPU_CONN_RETRIES = 120 54 55DeviceDetails = collections.namedtuple( 56 'DeviceDetails', ['device_map', 'total_cores']) 57 58 59@tf_export('distribute.cluster_resolver.TPUClusterResolver') 60class TPUClusterResolver(ClusterResolver): 61 """Cluster Resolver for Google Cloud TPUs. 62 63 This is an implementation of cluster resolvers for the Google Cloud TPU 64 service. As Cloud TPUs are in alpha, you will need to specify a API definition 65 file for this to consume, in addition to a list of Cloud TPUs in your Google 66 Cloud Platform project. 67 """ 68 69 def _tpuService(self): 70 """Creates a new Cloud TPU API object. 71 72 This works around an issue where the underlying HTTP connection sometimes 73 times out when the script has been running for too long. Other methods in 74 this object calls this method to get a new API object whenever they need 75 to communicate with the Cloud API. 76 77 Returns: 78 A Google Cloud TPU API object. 79 """ 80 if self._service: 81 return self._service 82 83 credentials = self._credentials 84 if credentials is None or credentials == 'default': 85 credentials = GoogleCredentials.get_application_default() 86 87 if self._discovery_url: 88 return discovery.build( 89 'tpu', 'v1alpha1', 90 credentials=credentials, 91 discoveryServiceUrl=self._discovery_url) 92 else: 93 return discovery.build( 94 'tpu', 'v1alpha1', 95 credentials=credentials) 96 97 def _requestComputeMetadata(self, path): 98 req = Request('http://metadata/computeMetadata/v1/%s' % path, 99 headers={'Metadata-Flavor': 'Google'}) 100 resp = urlopen(req) 101 return compat.as_bytes(resp.read()) 102 103 def _shouldResolve(self): 104 if isinstance(self._should_resolve_override, bool): 105 return self._should_resolve_override 106 if (self._tpu == compat.as_bytes('') or 107 self._tpu == compat.as_bytes('local') or 108 self._tpu.startswith(compat.as_bytes('/bns')) or 109 self._tpu.startswith(compat.as_bytes('localhost:')) or 110 self._tpu.startswith(compat.as_bytes('grpc://')) or 111 self._tpu.startswith(compat.as_bytes('uptc://'))): 112 return False 113 return True 114 115 @staticmethod 116 def _get_device_dict_and_cores(devices): 117 """Returns a dict of hosts to cores and total cores given devices names. 118 119 Returns a namedtuple with two attributes: 120 device_map: A map of host_ids to a list of core_ids. 121 total_cores: The total number of cores within the TPU system. 122 123 Args: 124 devices: A list of devices returned by session.list_devices() 125 """ 126 device_map = collections.defaultdict(list) 127 num_cores = 0 128 for device in devices: 129 match = _TPU_DEVICE_REGEX.match(device.name) 130 if match: 131 host_id = match.group('host_id') 132 core_id = match.group('core_id') 133 device_map[host_id].append(core_id) 134 num_cores += 1 135 return DeviceDetails(device_map, num_cores) 136 137 @staticmethod 138 def _verify_and_return_same_core_count(device_dict): 139 """Verifies that every device in device_dict has the same # of cores.""" 140 num_cores_per_host_set = ( 141 {len(core_ids) for core_ids in device_dict.values()}) 142 if len(num_cores_per_host_set) != 1: 143 raise RuntimeError('TPU cores on each device is not the same. This ' 144 'should never happen. Devices: {}'.format(device_dict)) 145 return num_cores_per_host_set.pop() 146 147 @staticmethod 148 def _inGke(): 149 """When running in GKE, the environment variable will be set.""" 150 return _GKE_ENV_VARIABLE in os.environ 151 152 @staticmethod 153 def _gkeEndpoints(): 154 return os.environ[_GKE_ENV_VARIABLE] 155 156 @staticmethod 157 def _envVarFallback(): 158 if _DEFAULT_ENV_VARIABLE in os.environ: 159 return os.environ[_DEFAULT_ENV_VARIABLE] 160 return None 161 162 @staticmethod 163 def _environmentDiscoveryUrl(): 164 return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE) 165 166 @staticmethod 167 def _isRunningInGCE(): 168 """Checks for GCE presence by attempting to query the metadata service.""" 169 try: 170 req = Request('http://metadata.google.internal/computeMetadata/v1', 171 headers={'Metadata-Flavor': 'Google'}) 172 resp = urllib.request.urlopen(req, timeout=1) 173 info = resp.info() 174 if 'Metadata-Flavor' in info and info['Metadata-Flavor'] == 'Google': 175 return True 176 except URLError: 177 pass 178 return False 179 180 def __init__(self, 181 tpu=None, 182 zone=None, 183 project=None, 184 job_name='worker', 185 coordinator_name=None, 186 coordinator_address=None, 187 credentials='default', 188 service=None, 189 discovery_url=None): 190 """Creates a new TPUClusterResolver object. 191 192 The ClusterResolver will then use the parameters to query the Cloud TPU APIs 193 for the IP addresses and ports of each Cloud TPU listed. 194 195 Args: 196 tpu: A string corresponding to the TPU to use. If the string is the empty 197 string, the string 'local', or a string that begins with 'grpc://' or 198 '/bns', then it is assumed to not correspond with a Cloud TPU and will 199 instead be passed as the session master and no ClusterSpec propagation 200 will be done. In the future, this may also support a list of strings 201 when multiple Cloud TPUs are used. 202 zone: Zone where the TPUs are located. If omitted or empty, we will assume 203 that the zone of the TPU is the same as the zone of the GCE VM, which we 204 will try to discover from the GCE metadata service. 205 project: Name of the GCP project containing Cloud TPUs. If omitted or 206 empty, we will try to discover the project name of the GCE VM from the 207 GCE metadata service. 208 job_name: Name of the TensorFlow job the TPUs belong to. 209 coordinator_name: The name to use for the coordinator. Set to None if the 210 coordinator should not be included in the computed ClusterSpec. 211 coordinator_address: The address of the coordinator (typically an ip:port 212 pair). If set to None, a TF server will be started. If coordinator_name 213 is None, a TF server will not be started even if coordinator_address is 214 None. 215 credentials: GCE Credentials. If None, then we use default credentials 216 from the oauth2client 217 service: The GCE API object returned by the googleapiclient.discovery 218 function. If you specify a custom service object, then the credentials 219 parameter will be ignored. 220 discovery_url: A URL template that points to the location of 221 the discovery service. It should have two parameters {api} and 222 {apiVersion} that when filled in produce an absolute URL to the 223 discovery document for that service. The environment variable 224 'TPU_API_DISCOVERY_URL' will override this. 225 226 Raises: 227 ImportError: If the googleapiclient is not installed. 228 ValueError: If no TPUs are specified. 229 RuntimeError: If an empty TPU name is specified and this is running in a 230 Google Cloud environment. 231 """ 232 if isinstance(tpu, list): 233 if not tpu: 234 raise ValueError('At least one TPU must be specified.') 235 if len(tpu) != 1: 236 raise NotImplementedError( 237 'Using multiple TPUs in a single session is not yet implemented') 238 tpu = tpu[0] 239 240 in_gke = self._inGke() 241 # When using GKE with Cloud TPUs, the env variable will be set. 242 if tpu is None: 243 if in_gke: 244 tpu = self._gkeEndpoints() 245 else: 246 tpu = self._envVarFallback() 247 248 if tpu is None: 249 raise ValueError('Please provide a TPU Name to connect to.') 250 251 self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes 252 253 # If we are running in Cloud and don't specify a TPU name 254 if self._isRunningInGCE() and not self._tpu: 255 raise RuntimeError('You need to specify a TPU Name if you are running in ' 256 'the Google Cloud environment.') 257 258 # By default the task_type is 'worker` and the task_id is 0 (which is the 259 # first worker in the task). 260 self.task_type = job_name 261 self.task_id = 0 262 263 if tpu.startswith('grpc://'): 264 # Cloud environment, where we are using GRPC to communicate to TPUs. 265 self._environment = '' 266 elif tpu == 'local' or not tpu: 267 # Google environment, where the TPU is attached to the host. 268 self._environment = 'google' 269 elif tpu.startswith('/bns') or tpu.startswith('uptc://'): 270 # Google environment, where we reach the TPU through BNS. 271 self._environment = 'google' 272 273 # If TPU is in the Google environment or exists locally, we don't use any 274 # RPC layer. 275 if tpu.startswith('/bns') or tpu.startswith( 276 'uptc://') or tpu == 'local' or not tpu: 277 self.rpc_layer = None 278 else: 279 self.rpc_layer = 'grpc' 280 281 # Setting this overrides the return value of self._shouldResolve() 282 self._should_resolve_override = None 283 284 # We strip out the protocol if it is included, and override the 285 # shouldResolve function to never resolve. We are adding the protocol back 286 # in later in self.master(). 287 if self.rpc_layer is not None and tpu.startswith(self.rpc_layer + '://'): 288 tpu = tpu[len(self.rpc_layer + '://'):] 289 self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes 290 self._should_resolve_override = False 291 292 # Whether we should actually attempt to contact Cloud APIs 293 should_resolve = self._shouldResolve() 294 295 # We error out if we are in a non-Cloud environment which cannot talk to the 296 # Cloud APIs using the standard class and a special object is not passed in. 297 self._service = service 298 if (self._service is None and should_resolve and 299 not _GOOGLE_API_CLIENT_INSTALLED): 300 raise ImportError('googleapiclient and oauth2client must be installed ' 301 'before using the TPU cluster resolver. Execute: ' 302 '`pip install --upgrade google-api-python-client` ' 303 'and `pip install --upgrade oauth2client` to ' 304 'install with pip.') 305 306 # We save user-passed credentials, unless the user didn't pass in anything. 307 self._credentials = credentials 308 if (credentials == 'default' and should_resolve and 309 _GOOGLE_API_CLIENT_INSTALLED): 310 self._credentials = None 311 312 # Automatically detect project and zone if unspecified. 313 if not project and should_resolve: 314 project = compat.as_str( 315 self._requestComputeMetadata('project/project-id')) 316 if not zone and should_resolve: 317 zone_path = compat.as_str(self._requestComputeMetadata('instance/zone')) 318 zone = zone_path.split('/')[-1] 319 self._project = project 320 self._zone = zone 321 322 self._discovery_url = self._environmentDiscoveryUrl() or discovery_url 323 324 self._coordinator_name = coordinator_name 325 if (coordinator_name and not coordinator_address and 326 (should_resolve or in_gke)): 327 self._start_local_server() 328 else: 329 self._coordinator_address = coordinator_address 330 331 def master(self, task_type=None, task_id=None, rpc_layer=None): 332 """Get the Master string to be used for the session. 333 334 In the normal case, this returns the grpc path (grpc://1.2.3.4:8470) of 335 first instance in the ClusterSpec returned by the cluster_spec function. 336 337 If a non-TPU name is used when constructing a TPUClusterResolver, that will 338 be returned instead (e.g. If the tpus argument's value when constructing 339 this TPUClusterResolver was 'grpc://10.240.1.2:8470', 340 'grpc://10.240.1.2:8470' will be returned). 341 342 Args: 343 task_type: (Optional, string) The type of the TensorFlow task of the 344 master. 345 task_id: (Optional, integer) The index of the TensorFlow task of the 346 master. 347 rpc_layer: (Optional, string) The RPC protocol TensorFlow should use to 348 communicate with TPUs. 349 350 Returns: 351 string, the connection string to use when creating a session. 352 353 Raises: 354 ValueError: If none of the TPUs specified exists. 355 """ 356 if self._shouldResolve(): 357 # We are going to communicate with the Cloud TPU APIs to get a Cluster. 358 cluster_spec = self.cluster_spec() 359 if task_type is not None and task_id is not None: 360 # task_type and task_id is from the function parameter 361 master = cluster_spec.task_address(task_type, task_id) 362 elif self.task_type is not None and self.task_id is not None: 363 # task_type and task_id is from the object 364 master = cluster_spec.task_address(self.task_type, self.task_id) 365 else: 366 # by default we take the first item in the cluster with the right name 367 job_tasks = cluster_spec.job_tasks(self.task_type) 368 if not job_tasks: 369 raise ValueError('No TPUs with the specified names exist.') 370 master = job_tasks[0] 371 else: 372 if isinstance(self._tpu, (bytes, bytearray)): 373 master = compat.as_text(self._tpu).split(_ENDPOINTS_SEPARATOR)[0] 374 else: 375 master = self._tpu.split(_ENDPOINTS_SEPARATOR)[0] 376 return format_master_url(master, rpc_layer or self.rpc_layer) 377 378 def get_master(self): 379 return self.master() 380 381 def get_job_name(self): 382 if (self._shouldResolve() or 383 self._isRunningInGCE()): 384 return self.task_type 385 386 def cluster_spec(self): 387 """Returns a ClusterSpec object based on the latest TPU information. 388 389 We retrieve the information from the GCE APIs every time this method is 390 called. 391 392 Returns: 393 A ClusterSpec containing host information returned from Cloud TPUs. 394 395 Raises: 396 RuntimeError: If the provided TPU is not healthy. 397 """ 398 ############################################################################ 399 # There are 5 potential cases this code must handle: 400 # 1. [Normal case.] We should resolve the TPU name to a set of tasks, and 401 # a. Create a ClusterSpec that includes the coordinator job 402 # b. Create a ClusterSpec without the coordinator job. 403 # 2. [GKE / No API Access.] We should not resolve the TPU name to a set of 404 # tasks and 405 # a. Create a ClusterSpec with the coordinator 406 # b. Create a ClusterSpec without the coordinator 407 # 3. [Other (legacy non-gRPC).] We should return an empty ClusterSpec. 408 ############################################################################ 409 410 if self._shouldResolve(): 411 # Case 1. 412 full_name = 'projects/%s/locations/%s/nodes/%s' % ( 413 self._project, self._zone, compat.as_text(self._tpu)) 414 service = self._tpuService() 415 request = service.projects().locations().nodes().get(name=full_name) 416 response = request.execute() 417 418 if 'state' in response and response['state'] != 'READY': 419 raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' % 420 (compat.as_text(self._tpu), response['state'])) 421 422 if 'networkEndpoints' in response: 423 worker_list = [ 424 '%s:%s' % (endpoint['ipAddress'], endpoint['port']) 425 for endpoint in response['networkEndpoints'] 426 ] 427 else: 428 # Fall back to the deprecated response format 429 instance_url = '%s:%s' % (response['ipAddress'], response['port']) 430 worker_list = [instance_url] 431 432 cluster_spec = {self.task_type: worker_list} 433 else: 434 if self.rpc_layer is None: 435 # Case 3. 436 return None 437 # Case 2. 438 tpus = [] 439 for tpu in compat.as_text(self._tpu).split(_ENDPOINTS_SEPARATOR): 440 # We are working around the fact that GKE environment variable that is 441 # supplied to us has the protocol string embedded in it, but we want 442 # to strip it out for the ClusterSpec. 443 if (self.rpc_layer is not None and 444 tpu.startswith(self.rpc_layer + '://')): 445 tpus.append(tpu[len(self.rpc_layer + '://'):]) 446 else: 447 tpus.append(tpu) 448 cluster_spec = {self.task_type: tpus} 449 450 if self._coordinator_address: 451 # {1, 2}.a 452 cluster_spec[self._coordinator_name] = [self._coordinator_address] 453 454 return server_lib.ClusterSpec(cluster_spec) 455 456 def num_accelerators(self, 457 task_type=None, 458 task_id=None, 459 config_proto=None): 460 """Returns the number of TPU cores per worker. 461 462 Connects to the master and list all the devices present in the master, 463 and counts them up. Also verifies that the device counts per host in the 464 cluster is the same before returning the number of TPU cores per host. 465 466 Args: 467 task_type: Unused. 468 task_id: Unused. 469 config_proto: Used to create a connection to a TPU master in order to 470 retrieve the system metadata. 471 472 Raises: 473 RuntimeError: If we cannot talk to a TPU worker after retrying or if the 474 number of TPU devices per host is different. 475 """ 476 retry_count = 1 477 # TODO(b/120564445): Replace with standard library for retries. 478 while True: 479 try: 480 device_details = TPUClusterResolver._get_device_dict_and_cores( 481 get_accelerator_devices(self.master(), config_proto=config_proto)) 482 break 483 except errors.DeadlineExceededError: 484 error_message = ('Failed to connect to master. The TPU might not be ' 485 'ready (e.g. still scheduling) or the master ' 486 'address is incorrect: got (%s)' % self.master()) 487 if retry_count <= _TPU_CONN_RETRIES: 488 logging.warning(error_message) 489 logging.warning('Retrying (%d/%d)...', retry_count, _TPU_CONN_RETRIES) 490 retry_count += 1 491 else: 492 raise RuntimeError(error_message) 493 494 if device_details.total_cores: 495 return {'TPU': TPUClusterResolver._verify_and_return_same_core_count( 496 device_details.device_map)} 497 return {'TPU': 0} 498 499 @property 500 def environment(self): 501 """Returns the current environment which TensorFlow is running in.""" 502 return self._environment 503 504 def _start_local_server(self): 505 address = compat.as_text(self._requestComputeMetadata( 506 'instance/network-interfaces/0/ip')) 507 self._server = server_lib.Server( 508 { 509 'local': ['0.0.0.0:0'] 510 }, protocol='grpc', config=None, start=True) 511 # self._server.target is of the form: grpc://ipaddress:port 512 target = compat.as_bytes(self._server.target) 513 splits = target.split(compat.as_bytes(':')) 514 assert len(splits) == 3, self._server.target 515 assert splits[0] == compat.as_bytes('grpc'), self._server.target 516 self._coordinator_port = compat.as_text(splits[2]) 517 self._coordinator_address = '%s:%s' % ( 518 address, compat.as_text(self._coordinator_port)) 519 520 def __deepcopy__(self, memo): 521 # TODO(b/73668574): Remove this once RunConfig avoids performing deepcopy. 522 return self 523