1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Implementation of the SessionRunHook for preemptible Cloud TPUs.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import logging as _logging 22import os 23import threading 24import time 25 26from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver 27from tensorflow.python.platform import tf_logging as logging 28from tensorflow.python.training import session_run_hook 29 30 31class CloudTPUPreemptedHook(session_run_hook.SessionRunHook): 32 """The SessionRunHook for preemptible Cloud TPUs. 33 34 This is an implementation of SessionRunHook for the pre-emptible Google Cloud 35 TPU service. It attempts to close the session if the TPU is preempted, and 36 exits the coordinator process if the session cannot be closed. 37 """ 38 39 def __init__(self, cluster): 40 self._cluster = cluster 41 42 def after_create_session(self, session, coord): 43 if tpu_cluster_resolver.is_running_in_gce(): 44 self._tpu_poller = _TPUPollingThread(self._cluster, session) 45 self._tpu_poller.start() 46 47 def end(self, session): 48 self._tpu_poller.stop() 49 50 51class _TPUPollingThread(threading.Thread): 52 """A thread that polls the state of a TPU node. 53 54 When the node transitions into a TERMINAL state (PREEMPTED, TERMINATED) 55 that's considered as not recoverable by the underlying infrastructure, 56 it attempts to close the session, and exits the entire process if the 57 session.close() stucks. 58 """ 59 60 def __init__(self, cluster, session): 61 super(_TPUPollingThread, self).__init__() 62 63 self.daemon = True 64 self._running = True 65 self._session_closed = False 66 self._cluster = cluster 67 self._session = session 68 self._interval = 30 69 70 # Some of the Google API libraries are quite chatty, so disable them. 71 for name in ['googleapiclient.discovery', 'oauth2client.client']: 72 _logging.getLogger(name).setLevel(_logging.WARNING) 73 74 def stop(self): 75 self._running = False 76 self._session_closed = True 77 self.join() 78 79 def run(self): 80 if not tpu_cluster_resolver.is_running_in_gce(): 81 logging.warning( 82 'TPUPollingThread is running in a non-GCE environment, exiting...') 83 self._running = False 84 return 85 86 while self._running: 87 recoverable = self._cluster._cloud_tpu_client.recoverable() # pylint: disable=protected-access 88 if not recoverable: 89 logging.warning( 90 'TPUPollingThread found TPU %s in state %s', 91 self._cluster._tpu, self._cluster._cloud_tpu_client.state()) # pylint: disable=protected-access 92 os._exit(1) # pylint: disable=protected-access 93 time.sleep(self._interval) 94