# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of Cluster Resolvers for Cloud TPUs."""

import collections
import re

from tensorflow.core.protobuf.tpu import topology_pb2
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
from tensorflow.python.framework import config as framework_config
from tensorflow.python.framework import errors
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat

try:
  from cloud_tpu_client import client  # pylint: disable=g-import-not-at-top
except ImportError:
  logging.debug(
      'Falling back to TensorFlow client; we recommend you install the Cloud '
      'TPU client directly with pip install cloud-tpu-client.')
  from tensorflow.python.tpu.client import client  # pylint: disable=g-import-not-at-top


def is_running_in_gce():
  return True


class _LocalCloudTpuClient(object):
  """Dummy local Cloud TPU client."""

  def api_available(self):
    return False


_TPU_DEVICE_REGEX = re.compile(
    r'.*task:(?P<host_id>\d+)/.*device:TPU:(?P<core_id>\d+)$')
_TPU_CONN_RETRIES = 120
DeviceDetails = collections.namedtuple(
    'DeviceDetails', ['device_map', 'total_cores'])


class TPUClusterResolver(cluster_resolver.ClusterResolver):
  """Cluster Resolver for Google Cloud TPUs.

  This is an implementation of cluster resolvers for the Google Cloud TPU
  service.

  TPUClusterResolver supports the following distinct environments:
  Google Compute Engine
  Google Kubernetes Engine
  Google internal

  It can be passed into `tf.distribute.TPUStrategy` to support TF2 training on
  Cloud TPUs.
  """

  @staticmethod
  def connect(tpu=None,
              zone=None,
              project=None):
    """Initializes the TPU system and returns a TPUClusterResolver.

    This API will connect to the remote TPU cluster and initialize the TPU
    hardware. Example usage:

    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect(
    ...     tpu='')

    It can be viewed as a convenient wrapper of the following code:

    >>> resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    >>> tf.config.experimental_connect_to_cluster(resolver)
    >>> tf.tpu.experimental.initialize_tpu_system(resolver)

    Args:
      tpu: A string corresponding to the TPU to use. It can be the TPU name or
        TPU worker gRPC address. If not set, it will try to automatically
        resolve the TPU address on Cloud TPUs.
      zone: Zone where the TPUs are located. If omitted or empty, we will
        assume that the zone of the TPU is the same as the zone of the GCE VM,
        which we will try to discover from the GCE metadata service.
      project: Name of the GCP project containing Cloud TPUs. If omitted or
        empty, we will try to discover the project name of the GCE VM from the
        GCE metadata service.

    Returns:
      An instance of TPUClusterResolver object.

    Raises:
      NotFoundError: If no TPU devices are found in eager mode.
    """
    resolver = TPUClusterResolver(tpu, zone, project)
    from tensorflow.python.eager import remote  # pylint: disable=g-import-not-at-top
    remote.connect_to_cluster(resolver)
    from tensorflow.python.tpu import tpu_strategy_util  # pylint: disable=g-import-not-at-top
    tpu_strategy_util.initialize_tpu_system(resolver)
    return resolver

  @staticmethod
  def _get_device_dict_and_cores(devices):
    """Returns a dict of hosts to cores and total cores given device names.

    Returns a namedtuple with two attributes:
      device_map: A map of host_ids to a list of core_ids.
      total_cores: The total number of cores within the TPU system.

    Args:
      devices: A list of devices returned by session.list_devices()
    """
    device_map = collections.defaultdict(list)
    num_cores = 0
    for device in devices:
      match = _TPU_DEVICE_REGEX.match(device.name)
      if match:
        host_id = match.group('host_id')
        core_id = match.group('core_id')
        device_map[host_id].append(core_id)
        num_cores += 1
    return DeviceDetails(device_map, num_cores)

  @staticmethod
  def _verify_and_return_same_core_count(device_dict):
    """Verifies that every host in device_dict has the same # of cores."""
    num_cores_per_host_set = (
        {len(core_ids) for core_ids in device_dict.values()})
    if len(num_cores_per_host_set) != 1:
      raise RuntimeError('The number of TPU cores on each host is not the '
                         'same. This should never happen. Devices: '
                         '{}'.format(device_dict))
    return num_cores_per_host_set.pop()

  def __init__(self,
               tpu=None,
               zone=None,
               project=None,
               job_name='worker',
               coordinator_name=None,
               coordinator_address=None,
               credentials='default',
               service=None,
               discovery_url=None):
    """Creates a new TPUClusterResolver object.

    The ClusterResolver will then use the parameters to query the Cloud TPU
    APIs for the IP addresses and ports of each Cloud TPU listed.

    Args:
      tpu: A string corresponding to the TPU to use. It can be the TPU name or
        TPU worker gRPC address. If not set, it will try to automatically
        resolve the TPU address on Cloud TPUs. If set to "local", it will
        assume that the TPU is directly connected to the VM instead of over
        the network.
      zone: Zone where the TPUs are located. If omitted or empty, we will
        assume that the zone of the TPU is the same as the zone of the GCE VM,
        which we will try to discover from the GCE metadata service.
      project: Name of the GCP project containing Cloud TPUs. If omitted or
        empty, we will try to discover the project name of the GCE VM from the
        GCE metadata service.
      job_name: Name of the TensorFlow job the TPUs belong to.
      coordinator_name: The name to use for the coordinator. Set to None if the
        coordinator should not be included in the computed ClusterSpec.
      coordinator_address: The address of the coordinator (typically an ip:port
        pair). If set to None, a TF server will be started. If coordinator_name
        is None, a TF server will not be started even if coordinator_address is
        None.
      credentials: GCE Credentials. If None, then we use default credentials
        from the oauth2client library.
      service: The GCE API object returned by the googleapiclient.discovery
        function. If you specify a custom service object, then the credentials
        parameter will be ignored.
      discovery_url: A URL template that points to the location of the
        discovery service. It should have two parameters {api} and {apiVersion}
        that when filled in produce an absolute URL to the discovery document
        for that service. The environment variable 'TPU_API_DISCOVERY_URL' will
        override this.

    Raises:
      ImportError: If the googleapiclient is not installed.
      ValueError: If no TPUs are specified.
      RuntimeError: If an empty TPU name is specified and this is running in a
        Google Cloud environment.
    """

    if tpu != 'local':
      # Default Cloud environment
      self._cloud_tpu_client = client.Client(
          tpu=tpu,
          zone=zone,
          project=project,
          credentials=credentials,
          service=service,
          discovery_url=discovery_url)
      self._tpu = self._cloud_tpu_client.name()
    else:
      # Directly connected TPU environment
      self._cloud_tpu_client = _LocalCloudTpuClient()
      self._tpu = 'local'

    # By default the task_type is 'worker' and the task_id is 0 (which is the
    # first worker in the task).
    self.task_type = job_name
    self.task_id = 0
    self._coordinator_name = coordinator_name
    if (coordinator_name and not coordinator_address):
      self._start_local_server()
    else:
      self._coordinator_address = coordinator_address

    self._tpu_topology = None

  def __enter__(self):
    self._cloud_tpu_client.enter()

  def __exit__(self, type, value, traceback):  # pylint: disable=redefined-builtin
    self._cloud_tpu_client.exit(type, value, traceback)

  def master(self, task_type=None, task_id=None, rpc_layer=None):
    """Get the Master string to be used for the session.

    In the normal case, this returns the grpc path (grpc://1.2.3.4:8470) of
    the first instance in the ClusterSpec returned by the cluster_spec
    function.

    If a non-TPU name is used when constructing a TPUClusterResolver, that
    will be returned instead (e.g. if the tpu argument's value when
    constructing this TPUClusterResolver was 'grpc://10.240.1.2:8470',
    'grpc://10.240.1.2:8470' will be returned).

    Args:
      task_type: (Optional, string) The type of the TensorFlow task of the
        master.
      task_id: (Optional, integer) The index of the TensorFlow task of the
        master.
      rpc_layer: (Optional, string) The RPC protocol TensorFlow should use to
        communicate with TPUs.

    Returns:
      string, the connection string to use when creating a session.

    Raises:
      ValueError: If none of the TPUs specified exists.
    """

    if self._tpu != 'local':
      cluster_spec = self.cluster_spec()
      if task_type is not None and task_id is not None:
        # task_type and task_id are taken from the function parameters.
        master = cluster_spec.task_address(task_type, task_id)
      elif self.task_type is not None and self.task_id is not None:
        # task_type and task_id are taken from the object.
        master = cluster_spec.task_address(self.task_type, self.task_id)
      else:
        # By default we take the first item in the cluster with the right name.
        job_tasks = cluster_spec.job_tasks(self.task_type)
        if not job_tasks:
          raise ValueError('No TPUs with the specified names exist.')
        master = job_tasks[0]
      return cluster_resolver.format_master_url(master, 'grpc')
    else:
      return ''

  def get_master(self):
    return self.master()

  def get_job_name(self):
    return self.task_type

  def get_tpu_system_metadata(self):
    """Returns the metadata of the TPU system.

    Users can call this method to get some facts about the TPU system, such as
    the total number of cores, the number of TPU workers, and the devices.
    E.g.
    ```python

    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tpu_system_metadata = resolver.get_tpu_system_metadata()
    num_hosts = tpu_system_metadata.num_hosts
    ```

    Returns:
      A `tf.tpu.experimental.TPUSystemMetadata` object.
    """
    cluster_spec = self.cluster_spec()
    cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
    tpu_system_metadata = (
        tpu_system_metadata_lib._query_tpu_system_metadata(  # pylint: disable=protected-access
            self.master(),
            cluster_def=cluster_def,
            query_topology=False))

    return tpu_system_metadata

  def cluster_spec(self):
    """Returns a ClusterSpec object based on the latest TPU information.

    We retrieve the information from the GCE APIs every time this method is
    called.

    Returns:
      A ClusterSpec containing host information returned from Cloud TPUs,
      or None.

    Raises:
      RuntimeError: If the provided TPU is not healthy.
    """
    ############################################################################
    # There are 5 potential cases this code must handle:
    #  0. [Local case.] When a TPU is connected directly to the VM.
    #  1. [Normal case.] We should resolve the TPU name to a set of tasks, and
    #    a. Create a ClusterSpec that includes the coordinator job
    #    b. Create a ClusterSpec without the coordinator job.
    #  2. [GKE / No API Access.] We should not resolve the TPU name to a set of
    #     tasks and
    #    a. Create a ClusterSpec with the coordinator
    #    b. Create a ClusterSpec without the coordinator
    ############################################################################

    if self._tpu != 'local':
      network_endpoints = self._cloud_tpu_client.network_endpoints()
      worker_list = [
          '%s:%s' % (endpoint['ipAddress'], endpoint['port'])
          for endpoint in network_endpoints
      ]
      cluster_spec = {self.task_type: worker_list}
      if self._coordinator_address:
        # {1, 2}.a
        cluster_spec[self._coordinator_name] = [self._coordinator_address]
      return server_lib.ClusterSpec(cluster_spec)
    else:
      return server_lib.ClusterSpec({})

  def num_accelerators(self,
                       task_type=None,
                       task_id=None,
                       config_proto=None):
    """Returns the number of TPU cores per worker.

    Connects to the master, lists all the devices present on the master, and
    counts them. Also verifies that the device counts per host in the cluster
    are the same before returning the number of TPU cores per host.

    Args:
      task_type: Unused.
      task_id: Unused.
      config_proto: Used to create a connection to a TPU master in order to
        retrieve the system metadata.

    Returns:
      A dict mapping the accelerator type ('TPU') to the number of TPU cores
      per worker.

    Raises:
      RuntimeError: If we cannot talk to a TPU worker after retrying or if the
        number of TPU devices per host is different.
    """
    if self._tpu == 'local':
      return {
          'TPU':
              len([
                  d for d in framework_config.list_logical_devices()
                  if d.device_type == 'TPU'
              ])
      }

    retry_count = 1
    # TODO(b/120564445): Replace with standard library for retries.
    while True:
      try:
        device_details = TPUClusterResolver._get_device_dict_and_cores(
            cluster_resolver.get_accelerator_devices(
                self.master(), config_proto=config_proto))
        break
      except errors.DeadlineExceededError:
        error_message = ('Failed to connect to master. The TPU might not be '
                         'ready (e.g. still scheduling) or the master '
                         'address is incorrect: got (%s)' % self.master())
        if retry_count <= _TPU_CONN_RETRIES:
          logging.warning(error_message)
          logging.warning('Retrying (%d/%d)...', retry_count,
                          _TPU_CONN_RETRIES)
          retry_count += 1
        else:
          raise RuntimeError(error_message)

    if device_details.total_cores:
      return {
          'TPU':
              TPUClusterResolver._verify_and_return_same_core_count(
                  device_details.device_map)
      }
    return {'TPU': 0}

  def set_tpu_topology(self, serialized_tpu_topology):
    """Sets the TPU topology info stored in this resolver."""
    self._tpu_topology = topology_pb2.TopologyProto()
    self._tpu_topology.ParseFromString(serialized_tpu_topology)

  @property
  def tpu_hardware_feature(self):
    """Returns the TPU hardware feature from the stored topology, if any."""
    if self._tpu_topology is None:
      return self._tpu_topology
    return self._tpu_topology.tpu_hardware_feature

  @property
  def environment(self):
    """Returns the current environment in which TensorFlow is running."""
    return self._environment

  def _start_local_server(self):
    address = compat.as_text(self._cloud_tpu_client.get_local_ip())
    self._server = server_lib.Server({'local': ['0.0.0.0:0']},
                                     protocol='grpc',
                                     config=None,
                                     start=True)
    # self._server.target is of the form: grpc://ipaddress:port
    target = compat.as_bytes(self._server.target)
    splits = target.split(compat.as_bytes(':'))
    assert len(splits) == 3, self._server.target
    assert splits[0] == compat.as_bytes('grpc'), self._server.target
    self._coordinator_port = compat.as_text(splits[2])
    self._coordinator_address = '%s:%s' % (
        address, compat.as_text(self._coordinator_port))

  def __deepcopy__(self, memo):
    # TODO(b/73668574): Remove this once RunConfig avoids performing deepcopy.
    return self
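

# ------------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of this module's API). It mirrors
# the pattern documented in `TPUClusterResolver.connect` and the class
# docstring above. The TPU name 'my-tpu' is a hypothetical placeholder for a
# reachable Cloud TPU in your project; pass tpu='' to auto-resolve on a Cloud
# TPU VM, or tpu='local' for a directly attached TPU.
#
#   import tensorflow as tf
#
#   resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='my-tpu')
#   tf.config.experimental_connect_to_cluster(resolver)
#   tf.tpu.experimental.initialize_tpu_system(resolver)
#   strategy = tf.distribute.TPUStrategy(resolver)
#   with strategy.scope():
#     ...  # Build and compile the model; variables are placed on TPU cores.
#
# Equivalently, `TPUClusterResolver.connect(tpu='my-tpu')` performs the
# connect and initialize steps in a single call and returns the resolver.
# ------------------------------------------------------------------------------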