# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Cluster Resolvers for Cloud TPUs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re

from tensorflow.python.distribute.cluster_resolver import cluster_resolver
from tensorflow.python.framework import config as framework_config
from tensorflow.python.framework import errors
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat

try:
  from cloud_tpu_client import client  # pylint: disable=g-import-not-at-top
except ImportError:
  logging.debug(
      'Falling back to TensorFlow client; we recommend you install the Cloud '
      'TPU client directly with pip install cloud-tpu-client.')
  from tensorflow.python.tpu.client import client  # pylint: disable=g-import-not-at-top


def is_running_in_gce():
  return True


class _LocalCloudTpuClient(object):
  """Dummy local Cloud TPU client."""

  def api_available(self):
    return False


_TPU_DEVICE_REGEX = re.compile(
    r'.*task:(?P<host_id>\d+)/.*device:TPU:(?P<core_id>\d+)$')
_TPU_CONN_RETRIES = 120
DeviceDetails = collections.namedtuple(
    'DeviceDetails', ['device_map', 'total_cores'])


class TPUClusterResolver(cluster_resolver.ClusterResolver):
  """Cluster Resolver for Google Cloud TPUs.

  This is an implementation of cluster resolvers for the Google Cloud TPU
  service.

  TPUClusterResolver supports the following distinct environments:
  Google Compute Engine
  Google Kubernetes Engine
  Google internal

  It can be passed into `tf.distribute.TPUStrategy` to support TF2 training on
  Cloud TPUs.
  """

  @staticmethod
  def connect(tpu=None,
              zone=None,
              project=None):
    """Initializes TPU and returns a TPUClusterResolver.

    This API will connect to the remote TPU cluster and initialize the TPU
    hardware. Example usage:

    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect(
    ...     tpu='')

    It can be viewed as a convenient wrapper of the following code:

    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    >>> tf.config.experimental_connect_to_cluster(resolver)
    >>> tf.tpu.experimental.initialize_tpu_system(resolver)

    Args:
      tpu: A string corresponding to the TPU to use. It can be the TPU name or
        TPU worker gRPC address. If not set, it will try to automatically
        resolve the TPU address on Cloud TPUs.
      zone: Zone where the TPUs are located. If omitted or empty, we will
        assume that the zone of the TPU is the same as the zone of the GCE VM,
        which we will try to discover from the GCE metadata service.
      project: Name of the GCP project containing Cloud TPUs. If omitted or
        empty, we will try to discover the project name of the GCE VM from the
        GCE metadata service.

    Returns:
      An instance of TPUClusterResolver object.

    Raises:
      NotFoundError: If no TPU devices are found in eager mode.
    """
    resolver = TPUClusterResolver(tpu, zone, project)
    from tensorflow.python.eager import remote  # pylint: disable=g-import-not-at-top
    remote.connect_to_cluster(resolver)
    from tensorflow.python.tpu import tpu_strategy_util  # pylint: disable=g-import-not-at-top
    tpu_strategy_util.initialize_tpu_system(resolver)
    return resolver
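
  # Illustrative sketch only: the class docstring notes that the resolver can
  # be passed into `tf.distribute.TPUStrategy`. A typical TF2 flow, assuming
  # `tf` is the imported TensorFlow package and the TPU name resolves:
  #
  #   resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
  #   tf.config.experimental_connect_to_cluster(resolver)
  #   tf.tpu.experimental.initialize_tpu_system(resolver)
  #   strategy = tf.distribute.TPUStrategy(resolver)
  #   with strategy.scope():
  #     model = ...  # build and compile the model here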

  @staticmethod
  def _get_device_dict_and_cores(devices):
    """Returns a dict of hosts to cores and total cores given device names.

    Returns a namedtuple with two attributes:
      device_map: A map of host_ids to a list of core_ids.
      total_cores: The total number of cores within the TPU system.

    Args:
      devices: A list of devices returned by session.list_devices()
    """
    device_map = collections.defaultdict(list)
    num_cores = 0
    for device in devices:
      match = _TPU_DEVICE_REGEX.match(device.name)
      if match:
        host_id = match.group('host_id')
        core_id = match.group('core_id')
        device_map[host_id].append(core_id)
        num_cores += 1
    return DeviceDetails(device_map, num_cores)

  @staticmethod
  def _verify_and_return_same_core_count(device_dict):
    """Verifies that every host in device_dict has the same number of cores."""
    num_cores_per_host_set = (
        {len(core_ids) for core_ids in device_dict.values()})
    if len(num_cores_per_host_set) != 1:
      raise RuntimeError('The number of TPU cores per host is not the same. '
                         'This should never happen. Devices: {}'.format(
                             device_dict))
    return num_cores_per_host_set.pop()
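
  # A minimal sketch (with hypothetical device names and a stand-in device
  # class) of how `_get_device_dict_and_cores` groups TPU cores by host; this
  # is the mechanism `num_accelerators` relies on below:
  #
  #   class _FakeDevice(object):
  #     def __init__(self, name):
  #       self.name = name
  #
  #   details = TPUClusterResolver._get_device_dict_and_cores([
  #       _FakeDevice('/job:worker/replica:0/task:0/device:TPU:0'),
  #       _FakeDevice('/job:worker/replica:0/task:0/device:TPU:1'),
  #       _FakeDevice('/job:worker/replica:0/task:1/device:TPU:0'),
  #   ])
  #   # details.device_map   -> {'0': ['0', '1'], '1': ['0']}
  #   # details.total_cores  -> 3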

  def __init__(self,
               tpu=None,
               zone=None,
               project=None,
               job_name='worker',
               coordinator_name=None,
               coordinator_address=None,
               credentials='default',
               service=None,
               discovery_url=None):
    """Creates a new TPUClusterResolver object.

    The ClusterResolver will then use the parameters to query the Cloud TPU
    APIs for the IP addresses and ports of each Cloud TPU listed.

    Args:
      tpu: A string corresponding to the TPU to use. It can be the TPU name or
        TPU worker gRPC address. If not set, it will try to automatically
        resolve the TPU address on Cloud TPUs. If set to "local", it will
        assume that the TPU is directly connected to the VM instead of over
        the network.
      zone: Zone where the TPUs are located. If omitted or empty, we will
        assume that the zone of the TPU is the same as the zone of the GCE VM,
        which we will try to discover from the GCE metadata service.
      project: Name of the GCP project containing Cloud TPUs. If omitted or
        empty, we will try to discover the project name of the GCE VM from the
        GCE metadata service.
      job_name: Name of the TensorFlow job the TPUs belong to.
      coordinator_name: The name to use for the coordinator. Set to None if
        the coordinator should not be included in the computed ClusterSpec.
      coordinator_address: The address of the coordinator (typically an
        ip:port pair). If set to None, a TF server will be started. If
        coordinator_name is None, a TF server will not be started even if
        coordinator_address is None.
      credentials: GCE Credentials. If None, then we use default credentials
        from the oauth2client.
      service: The GCE API object returned by the googleapiclient.discovery
        function. If you specify a custom service object, then the credentials
        parameter will be ignored.
      discovery_url: A URL template that points to the location of the
        discovery service. It should have two parameters {api} and
        {apiVersion} that when filled in produce an absolute URL to the
        discovery document for that service. The environment variable
        'TPU_API_DISCOVERY_URL' will override this.

    Raises:
      ImportError: If the googleapiclient is not installed.
      ValueError: If no TPUs are specified.
      RuntimeError: If an empty TPU name is specified and this is running in a
        Google Cloud environment.
    """

    if tpu != 'local':
      # Default Cloud environment
      self._cloud_tpu_client = client.Client(
          tpu=tpu,
          zone=zone,
          project=project,
          credentials=credentials,
          service=service,
          discovery_url=discovery_url)
      self._tpu = self._cloud_tpu_client.name()
    else:
      # Directly connected TPU environment
      self._cloud_tpu_client = _LocalCloudTpuClient()
      self._tpu = 'local'

    # By default the task_type is 'worker' and the task_id is 0 (which is the
    # first worker in the task).
    self.task_type = job_name
    self.task_id = 0
    # Default environment string; returned by the `environment` property.
    self._environment = ''
    self._coordinator_name = coordinator_name
    if (coordinator_name and not coordinator_address):
      self._start_local_server()
    else:
      self._coordinator_address = coordinator_address

  def __enter__(self):
    self._cloud_tpu_client.enter()

  def __exit__(self, type, value, traceback):  # pylint: disable=redefined-builtin
    self._cloud_tpu_client.exit(type, value, traceback)

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Get the Master string to be used for the session.

    In the normal case, this returns the grpc path (grpc://1.2.3.4:8470) of
    the first instance in the ClusterSpec returned by the cluster_spec
    function.

    If a non-TPU name is used when constructing a TPUClusterResolver, that
    will be returned instead (e.g. if the tpus argument's value when
    constructing this TPUClusterResolver was 'grpc://10.240.1.2:8470',
    'grpc://10.240.1.2:8470' will be returned).

    Args:
      task_type: (Optional, string) The type of the TensorFlow task of the
        master.
      task_id: (Optional, integer) The index of the TensorFlow task of the
        master.
      rpc_layer: (Optional, string) The RPC protocol TensorFlow should use to
        communicate with TPUs.

    Returns:
      string, the connection string to use when creating a session.

    Raises:
      ValueError: If none of the specified TPUs exist.
    """

    if self._tpu != 'local':
      cluster_spec = self.cluster_spec()
      if task_type is not None and task_id is not None:
        # task_type and task_id are from the function parameters
        master = cluster_spec.task_address(task_type, task_id)
      elif self.task_type is not None and self.task_id is not None:
        # task_type and task_id are from the object
        master = cluster_spec.task_address(self.task_type, self.task_id)
      else:
        # by default we take the first item in the cluster with the right name
        job_tasks = cluster_spec.job_tasks(self.task_type)
        if not job_tasks:
          raise ValueError('No TPUs with the specified names exist.')
        master = job_tasks[0]
      return cluster_resolver.format_master_url(master, 'grpc')
    else:
      return ''
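
  # A hedged sketch (the address is hypothetical) of master() when the
  # resolver is constructed from an explicit worker gRPC address rather than
  # a TPU name, as described in the docstring above:
  #
  #   resolver = TPUClusterResolver(tpu='grpc://10.240.1.2:8470')
  #   resolver.master()  # -> 'grpc://10.240.1.2:8470'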

  def get_master(self):
    return self.master()

  def get_job_name(self):
    return self.task_type

  def get_tpu_system_metadata(self):
    """Returns the metadata of the TPU system.

    Users can call this method to get some facts of the TPU system, such as
    the total number of cores, the number of TPU workers, and the devices.
    E.g.
    ```python
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tpu_system_metadata = resolver.get_tpu_system_metadata()
    num_hosts = tpu_system_metadata.num_hosts
    ```

    Returns:
      A `tf.tpu.experimental.TPUSystemMetadata` object.
    """
    cluster_spec = self.cluster_spec()
    cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
    tpu_system_metadata = (
        tpu_system_metadata_lib._query_tpu_system_metadata(  # pylint: disable=protected-access
            self.master(),
            cluster_def=cluster_def,
            query_topology=False))

    return tpu_system_metadata

  def cluster_spec(self):
    """Returns a ClusterSpec object based on the latest TPU information.

    We retrieve the information from the GCE APIs every time this method is
    called.

    Returns:
      A ClusterSpec containing host information returned from Cloud TPUs,
      or None.

    Raises:
      RuntimeError: If the provided TPU is not healthy.
    """
    ############################################################################
    # There are 5 potential cases this code must handle:
    #  0. [Local case.] When a TPU is connected directly to the VM.
    #  1. [Normal case.] We should resolve the TPU name to a set of tasks, and
    #      a. Create a ClusterSpec that includes the coordinator job
    #      b. Create a ClusterSpec without the coordinator job.
    #  2. [GKE / No API Access.] We should not resolve the TPU name to a set
    #     of tasks and
    #      a. Create a ClusterSpec with the coordinator
    #      b. Create a ClusterSpec without the coordinator
    ############################################################################

    if self._tpu != 'local':
      network_endpoints = self._cloud_tpu_client.network_endpoints()
      worker_list = [
          '%s:%s' % (endpoint['ipAddress'], endpoint['port'])
          for endpoint in network_endpoints
      ]
      cluster_spec = {self.task_type: worker_list}
      if self._coordinator_address:
        # {1, 2}.a
        cluster_spec[self._coordinator_name] = [self._coordinator_address]
      return server_lib.ClusterSpec(cluster_spec)
    else:
      return server_lib.ClusterSpec({})
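
  # Illustration only (addresses are hypothetical): for a TPU with two hosts,
  # cluster_spec() above resolves each network endpoint into an ip:port entry
  # under the worker job, roughly:
  #
  #   ClusterSpec({'worker': ['10.2.0.2:8470', '10.2.0.3:8470']})
  #
  # and, if a coordinator was configured, an extra
  # {coordinator_name: [coordinator_address]} job is appended.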

  def num_accelerators(self,
                       task_type=None,
                       task_id=None,
                       config_proto=None):
    """Returns the number of TPU cores per worker.

    Connects to the master, lists all the devices present on the master, and
    counts them up. Also verifies that the device counts per host in the
    cluster are the same before returning the number of TPU cores per host.

    Args:
      task_type: Unused.
      task_id: Unused.
      config_proto: Used to create a connection to a TPU master in order to
        retrieve the system metadata.

    Raises:
      RuntimeError: If we cannot talk to a TPU worker after retrying or if the
        number of TPU devices per host is different.
    """
    if self._tpu == 'local':
      return {
          'TPU':
              len([
                  d for d in framework_config.list_logical_devices()
                  if d.device_type == 'TPU'
              ])
      }

    retry_count = 1
    # TODO(b/120564445): Replace with standard library for retries.
    while True:
      try:
        device_details = TPUClusterResolver._get_device_dict_and_cores(
            cluster_resolver.get_accelerator_devices(
                self.master(), config_proto=config_proto))
        break
      except errors.DeadlineExceededError:
        error_message = ('Failed to connect to master. The TPU might not be '
                         'ready (e.g. still scheduling) or the master '
                         'address is incorrect: got (%s)' % self.master())
        if retry_count <= _TPU_CONN_RETRIES:
          logging.warning(error_message)
          logging.warning('Retrying (%d/%d)...', retry_count,
                          _TPU_CONN_RETRIES)
          retry_count += 1
        else:
          raise RuntimeError(error_message)

    if device_details.total_cores:
      return {
          'TPU':
              TPUClusterResolver._verify_and_return_same_core_count(
                  device_details.device_map)
      }
    return {'TPU': 0}

  @property
  def environment(self):
    """Returns the current environment which TensorFlow is running in."""
    return self._environment

  def _start_local_server(self):
    address = compat.as_text(self._cloud_tpu_client.get_local_ip())
    self._server = server_lib.Server({'local': ['0.0.0.0:0']},
                                     protocol='grpc',
                                     config=None,
                                     start=True)
    # self._server.target is of the form: grpc://ipaddress:port
    target = compat.as_bytes(self._server.target)
    splits = target.split(compat.as_bytes(':'))
    assert len(splits) == 3, self._server.target
    assert splits[0] == compat.as_bytes('grpc'), self._server.target
    self._coordinator_port = compat.as_text(splits[2])
    self._coordinator_address = '%s:%s' % (
        address, compat.as_text(self._coordinator_port))

  def __deepcopy__(self, memo):
    # TODO(b/73668574): Remove this once RunConfig avoids performing deepcopy.
    return self
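

# A hedged end-to-end sketch of the directly-attached ("local") TPU path
# handled above; the core count is illustrative:
#
#   resolver = TPUClusterResolver(tpu='local')
#   resolver.master()            # -> '' (no remote master to dial)
#   resolver.cluster_spec()      # -> ClusterSpec({})
#   resolver.num_accelerators()  # -> e.g. {'TPU': 8}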