1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# Lint as: python3 16"""Cloud TPU Client.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import datetime 23import json 24import logging 25import os 26import time 27 28from absl import flags 29from concurrent import futures 30from six.moves.urllib import request 31from six.moves.urllib.error import HTTPError 32 33_GOOGLE_API_CLIENT_INSTALLED = True 34try: 35 from googleapiclient import discovery # pylint: disable=g-import-not-at-top 36 from oauth2client import client # pylint: disable=g-import-not-at-top 37except ImportError: 38 _GOOGLE_API_CLIENT_INSTALLED = False 39 40FLAGS = flags.FLAGS 41 42flags.DEFINE_bool('runtime_oom_exit', True, 43 'Exit the script when the TPU runtime is OOM.') 44flags.DEFINE_bool('hbm_oom_exit', True, 45 'Exit the script when the TPU HBM is OOM.') 46 47_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' 48_ENDPOINTS_SEPARATOR = ',' 49_DEFAULT_ENV_VARIABLE = 'TPU_NAME' 50_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL' 51_GCE_METADATA_URL_ENV_VARIABLE = 'GCE_METADATA_IP' 52_DEFAULT_ENDPOINT_PORT = '8470' 53_OOM_EVENT_COOL_TIME_SEC = 90 54_VERSION_SWITCHER_ENDPOINT = 'http://{}:8475/requestversion' 55 56 57def _utcnow(): 58 """A wrapper function around datetime.datetime.utcnow. 59 60 This function is created for unit testing purpose. It's not easy to do 61 StubOutWithMock with datetime.datetime package. 62 63 Returns: 64 datetime.datetime 65 """ 66 return datetime.datetime.utcnow() 67 68 69def _environment_discovery_url(): 70 return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE) 71 72 73def _gce_metadata_endpoint(): 74 return 'http://' + os.environ.get(_GCE_METADATA_URL_ENV_VARIABLE, 75 'metadata.google.internal') 76 77 78def _request_compute_metadata(path): 79 req = request.Request( 80 '%s/computeMetadata/v1/%s' % (_gce_metadata_endpoint(), path), 81 headers={'Metadata-Flavor': 'Google'}) 82 resp = request.urlopen(req) 83 return _as_text(resp.read()) 84 85 86def _environment_var_to_network_endpoints(endpoints): 87 """Yields a dict with ip address and port.""" 88 for endpoint in endpoints.split(','): 89 grpc_prefix = 'grpc://' 90 if endpoint.startswith(grpc_prefix): 91 endpoint = endpoint.split(grpc_prefix)[1] 92 parts = endpoint.split(':') 93 ip_address = parts[0] 94 port = _DEFAULT_ENDPOINT_PORT 95 if len(parts) > 1: 96 port = parts[1] 97 yield { 98 'ipAddress': ip_address, 99 'port': port 100 } 101 102 103def _get_tpu_name(tpu): 104 if tpu: 105 return tpu 106 107 for e in [_GKE_ENV_VARIABLE, _DEFAULT_ENV_VARIABLE]: 108 if e in os.environ: 109 return os.environ[e] 110 return None 111 112 113def _as_text(s): 114 if isinstance(s, bytes): 115 return s.decode('utf-8') 116 return s 117 118 119class Client(object): 120 """Client for working with the Cloud TPU API. 121 122 This client is intended to be used for resolving tpu name to ip addresses. 123 124 It's recommended to use this library as a contextlib to utilize all 125 functionality. 126 """ 127 128 def __init__(self, 129 tpu=None, 130 zone=None, 131 project=None, 132 credentials='default', 133 service=None, 134 discovery_url=None): 135 if isinstance(tpu, list): 136 if not tpu: 137 raise ValueError('At least one TPU must be specified.') 138 if len(tpu) != 1: 139 raise NotImplementedError( 140 'Using multiple TPUs in a single session is not yet implemented') 141 tpu = tpu[0] 142 143 tpu = _get_tpu_name(tpu) 144 145 if tpu is None: 146 raise ValueError('Please provide a TPU Name to connect to.') 147 148 self._tpu = _as_text(tpu) 149 150 self._use_api = not self._tpu.startswith('grpc://') 151 self._service = service 152 153 self._credentials = None 154 self._project = None 155 self._zone = None 156 self._discovery_url = None 157 if self._use_api: 158 if credentials != 'default': 159 self._credentials = credentials 160 # Automatically detect project and zone if unspecified. 161 if project: 162 self._project = _as_text(project) 163 else: 164 self._project = _request_compute_metadata('project/project-id') 165 if zone: 166 self._zone = _as_text(zone) 167 else: 168 zone_path = _request_compute_metadata('instance/zone') 169 self._zone = zone_path.split('/')[-1] 170 self._discovery_url = _environment_discovery_url() or discovery_url 171 172 def _symptom_msg(self, msg): 173 """Return the structured Symptom message.""" 174 return 'Symptom: ' + msg 175 176 def _oom_event(self, symptoms): 177 """Check if a runtime OOM event is reported.""" 178 if not symptoms: 179 return False 180 for symptom in reversed(symptoms): 181 if symptom['symptomType'] != 'OUT_OF_MEMORY': 182 continue 183 oom_datetime_str = symptom['createTime'].split('.')[0] 184 oom_datetime = datetime.datetime.strptime(oom_datetime_str, 185 '%Y-%m-%dT%H:%M:%S') 186 time_diff = _utcnow() - oom_datetime 187 if time_diff < datetime.timedelta(seconds=_OOM_EVENT_COOL_TIME_SEC): 188 logging.warning(self._symptom_msg( 189 'a recent runtime OOM has occured ~{} seconds ago. The model ' 190 'script will terminate automatically. To prevent future OOM ' 191 'events, please consider reducing the model size. To disable this ' 192 'behavior, set flag --runtime_oom_exit=false when starting the ' 193 'script.'.format(time_diff.seconds))) 194 return True 195 return False 196 197 def _hbm_oom_event(self, symptoms): 198 """Check if a HBM OOM event is reported.""" 199 if not symptoms: 200 return False 201 for symptom in reversed(symptoms): 202 if symptom['symptomType'] != 'HBM_OUT_OF_MEMORY': 203 continue 204 oom_datetime_str = symptom['createTime'].split('.')[0] 205 oom_datetime = datetime.datetime.strptime(oom_datetime_str, 206 '%Y-%m-%dT%H:%M:%S') 207 time_diff = _utcnow() - oom_datetime 208 if time_diff < datetime.timedelta(seconds=_OOM_EVENT_COOL_TIME_SEC): 209 logging.warning(self._symptom_msg( 210 'a recent HBM OOM has occured ~{} seconds ago. The model ' 211 'script will terminate automatically. To prevent future HBM OOM ' 212 'events, please consider reducing the model size. To disable this ' 213 'behavior, set flag --hbm_oom_exit=false when starting the ' 214 'script.'.format(time_diff.seconds))) 215 return True 216 return False 217 218 def _tpu_service(self): 219 """Creates a new Cloud TPU API object. 220 221 This works around an issue where the underlying HTTP connection sometimes 222 times out when the script has been running for too long. Other methods in 223 this object call this method to get a new API object whenever they need 224 to communicate with the Cloud API. 225 226 Raises: 227 RuntimeError: If the dependent Python packages are missing. 228 229 Returns: 230 A Google Cloud TPU API object. 231 """ 232 if self._service: 233 return self._service 234 235 if not _GOOGLE_API_CLIENT_INSTALLED: 236 raise RuntimeError('Missing runtime dependency on the Google API client. ' 237 'Run `pip install cloud-tpu-client` to fix.') 238 239 credentials = self._credentials 240 if credentials is None or credentials == 'default': 241 credentials = client.GoogleCredentials.get_application_default() 242 243 if self._discovery_url: 244 return discovery.build( 245 'tpu', 246 'v1', 247 credentials=credentials, 248 discoveryServiceUrl=self._discovery_url, 249 cache_discovery=False) 250 else: 251 return discovery.build( 252 'tpu', 'v1', credentials=credentials, cache_discovery=False) 253 254 def _full_name(self): 255 """Returns the full Cloud name for this TPU.""" 256 return 'projects/%s/locations/%s/nodes/%s' % ( 257 self._project, self._zone, self._tpu) 258 259 def _fetch_cloud_tpu_metadata(self): 260 """Returns the TPU metadata object from the TPU Get API call.""" 261 service = self._tpu_service() 262 try: 263 r = service.projects().locations().nodes().get(name=self._full_name()) 264 return r.execute() 265 except Exception as e: 266 raise ValueError("Could not lookup TPU metadata from name '%s'. Please " 267 'doublecheck the tpu argument in the TPUClusterResolver ' 268 'constructor. Exception: %s' % (self._tpu, e)) 269 270 def _get_tpu_property(self, key): 271 if self._use_api: 272 metadata = self._fetch_cloud_tpu_metadata() 273 return metadata.get(key) 274 275 return None 276 277 def __enter__(self): 278 self._open = True 279 280 def __exit__(self, type, value, traceback): # pylint: disable=redefined-builtin 281 del type, value, traceback 282 283 def recoverable(self): 284 """Returns true if the TPU is in a state where training should eventually resume. 285 286 If false the TPU is in a unrecoverable state and should be recreated. 287 """ 288 state = self.state() 289 symptoms = self.symptoms() 290 if state and state in ['TERMINATED', 'PREEMPTED']: 291 return False 292 elif FLAGS.runtime_oom_exit and self._oom_event(symptoms): 293 return False 294 elif FLAGS.hbm_oom_exit and self._hbm_oom_event(symptoms): 295 return False 296 return True 297 298 def symptoms(self): 299 """Return Cloud TPU Symptoms of the TPU.""" 300 return self._get_tpu_property('symptoms') 301 302 def state(self): 303 """Return state of the TPU.""" 304 return self._get_tpu_property('state') 305 306 def health(self): 307 """Return health of the TPU.""" 308 return self._get_tpu_property('health') 309 310 def runtime_version(self): 311 """Return runtime version of the TPU.""" 312 313 if not self._use_api: 314 # Fallback on getting version directly from TPU. 315 url = _VERSION_SWITCHER_ENDPOINT.format( 316 self.network_endpoints()[0]['ipAddress']) 317 try: 318 req = request.Request(url) 319 resp = request.urlopen(req) 320 version_details = json.loads(resp.read()) 321 return version_details.get('currentVersion') 322 except HTTPError as e: 323 status_code = e.code 324 if status_code == 404: 325 return None 326 else: 327 raise e 328 return self._get_tpu_property('tensorflowVersion') 329 330 def accelerator_type(self): 331 """Return accelerator type of the TPU.""" 332 return self._get_tpu_property('acceleratorType') 333 334 def api_available(self): 335 """Return if the Cloud TPU API is available, if not certain features will not work.""" 336 return self._use_api 337 338 def name(self): 339 """Return the name of the tpu, or the ip address if name is not provided.""" 340 return self._tpu 341 342 def get_local_ip(self): 343 """Return the local ip address of the Google Cloud VM the workload is running on.""" 344 return _request_compute_metadata('instance/network-interfaces/0/ip') 345 346 def network_endpoints(self): 347 """Return a list of tpu endpoints.""" 348 if not self._use_api: 349 return list(_environment_var_to_network_endpoints(self._tpu)) 350 response = self._fetch_cloud_tpu_metadata() 351 352 if response.get('state') != 'READY': 353 raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' % 354 (self._tpu, response.get('state'))) 355 if 'networkEndpoints' in response: 356 return response['networkEndpoints'] 357 else: 358 return [{'ipAddress': response['ipAddress'], 'port': response['port']}] 359 360 def wait_for_healthy(self, timeout_s=1200, interval=30): 361 """Wait for TPU to become healthy or raise error if timeout reached. 362 363 Args: 364 timeout_s (int): The timeout in seconds for waiting TPU to become healthy. 365 interval (int): The interval in seconds to poll the TPU for health. 366 367 Raises: 368 RuntimeError: If the TPU doesn't become healthy by the timeout. 369 """ 370 timeout = time.time() + timeout_s 371 while self.health() != 'HEALTHY': 372 logging.warning( 373 ('Waiting for TPU "%s" with state "%s" ' 374 'and health "%s" to become healthy'), 375 self.name(), self.state(), self.health()) 376 if time.time() + interval > timeout: 377 raise RuntimeError( 378 'Timed out waiting for TPU "%s" to become healthy' % self.name()) 379 time.sleep(interval) 380 381 logging.warning('TPU "%s" is healthy.', self.name()) 382 383 def configure_tpu_version(self, version, restart_type='always'): 384 """Configure TPU software version. 385 386 Args: 387 version (string): Version of software to configure the TPU with. 388 restart_type (string): Restart behaviour when switching versions, 389 defaults to always restart. Options are 'always', 'ifNeeded'. 390 391 """ 392 393 def configure_worker(worker): 394 """Configure individual TPU worker. 395 396 Args: 397 worker: A dict with the field ipAddress where the configure request will 398 be sent. 399 """ 400 ip_address = worker['ipAddress'] 401 url = (_VERSION_SWITCHER_ENDPOINT + '/{}?restartType={}').format( 402 ip_address, version, restart_type) 403 req = request.Request(url, data=b'') 404 try: 405 request.urlopen(req) 406 except HTTPError as e: 407 status_code = e.code 408 if status_code == 404: 409 raise Exception( 410 'Tensorflow version {} is not available on Cloud TPU, ' 411 'try a previous nightly version or refer to ' 412 'https://cloud.google.com/tpu/docs/release-notes for ' 413 'the latest official version.'.format(version)) 414 else: 415 raise Exception('Failed to configure worker {}'.format(ip_address)) 416 417 workers = self.network_endpoints() 418 419 with futures.ThreadPoolExecutor(max_workers=len(workers)) as executor: 420 results = executor.map(configure_worker, workers) 421 for result in results: 422 if result: 423 result.result() 424