• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Implementation of the SessionRunHook for preemptible Cloud TPUs."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import logging as _logging
22import os
23import threading
24import time
25
26from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
27from tensorflow.python.platform import tf_logging as logging
28from tensorflow.python.training import session_run_hook
29
30
class CloudTPUPreemptedHook(session_run_hook.SessionRunHook):
  """The SessionRunHook for preemptible Cloud TPUs.

  This is an implementation of SessionRunHook for the pre-emptible Google Cloud
  TPU service. When running in GCE, it starts a `_TPUPollingThread` after the
  session is created; that thread watches the TPU and terminates the process
  once the TPU is reported as not recoverable. Outside GCE the hook is a no-op.
  """

  def __init__(self, cluster):
    """Creates the hook.

    Args:
      cluster: The cluster object handed to `_TPUPollingThread`, which polls
        it for the TPU state (presumably a `TPUClusterResolver` -- confirm
        against callers).
    """
    self._cluster = cluster
    # Stays None unless after_create_session() actually starts a poller, so
    # that end() is safe to call in every environment.
    self._tpu_poller = None

  def after_create_session(self, session, coord):
    """Starts the TPU polling thread when running in GCE.

    Args:
      session: The session that was just created; passed to the poller.
      coord: Unused.
    """
    if tpu_cluster_resolver.is_running_in_gce():
      self._tpu_poller = _TPUPollingThread(self._cluster, session)
      self._tpu_poller.start()

  def end(self, session):
    """Stops the polling thread, if one was started.

    Args:
      session: Unused.
    """
    poller = self._tpu_poller
    if poller is None:
      # No poller was created (not running in GCE, or no session was ever
      # created); previously this path raised AttributeError.
      return
    poller.stop()
    self._tpu_poller = None
49
50
class _TPUPollingThread(threading.Thread):
  """A daemon thread that polls the state of a TPU node.

  Every `_interval` seconds the thread asks the cluster's Cloud TPU client
  whether the node is still recoverable. When it is not (e.g. the node moved
  into a terminal state such as PREEMPTED or TERMINATED), the thread logs the
  TPU name and state and exits the entire process with status 1.
  """

  def __init__(self, cluster, session):
    """Initializes the polling thread (does not start it).

    Args:
      cluster: Object exposing `_cloud_tpu_client` (with `recoverable()` and
        `state()`) and `_tpu`; these are read while polling.
      session: The session associated with this poller. It is only stored;
        nothing in this class calls into it.
    """
    super(_TPUPollingThread, self).__init__()

    self.daemon = True
    self._running = True
    self._session_closed = False
    self._cluster = cluster
    self._session = session
    self._interval = 30
    # Set by stop() so that the wait between two polls can be interrupted
    # instead of blocking the caller for up to `_interval` seconds.
    self._stop_event = threading.Event()

    # Some of the Google API libraries are quite chatty, so only let their
    # warnings (and above) through.
    for name in ['googleapiclient.discovery', 'oauth2client.client']:
      _logging.getLogger(name).setLevel(_logging.WARNING)

  def stop(self):
    """Asks the polling loop to finish and waits for the thread to exit."""
    self._running = False
    self._session_closed = True
    # Wake the loop immediately if it is waiting between two polls.
    self._stop_event.set()
    # join() raises RuntimeError on a thread that was never started, so only
    # wait for a thread that is actually running.
    if self.is_alive():
      self.join()

  def run(self):
    """Polls the TPU until stopped; exits the process if it is unrecoverable."""
    if not tpu_cluster_resolver.is_running_in_gce():
      logging.warning(
          'TPUPollingThread is running in a non-GCE environment, exiting...')
      self._running = False
      return

    while self._running:
      recoverable = self._cluster._cloud_tpu_client.recoverable()  # pylint: disable=protected-access
      if not recoverable:
        logging.warning(
            'TPUPollingThread found TPU %s in state %s',
            self._cluster._tpu, self._cluster._cloud_tpu_client.state())  # pylint: disable=protected-access
        # Hard exit: terminates the whole process immediately, without
        # running cleanup handlers.
        os._exit(1)  # pylint: disable=protected-access
      # Sleep until the next poll, but return early once stop() is called.
      self._stop_event.wait(self._interval)
94