# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
"""Operations for handling session logging and shutdown notifications."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading
import time

from google.protobuf import text_format

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util

_WATCHDOG = None


class CoordinatorShutdownException(Exception):
  """Raised when the coordinator needs to shut down."""
  pass


def _clone_session(session, graph=None):
  return session_lib.Session(
      target=session.sess_str,
      config=session._config,  # pylint: disable=protected-access
      graph=graph if graph else session.graph)


class WorkerHeartbeatManager(object):
  """Manages the status/heartbeat monitor for a set of workers."""

  def __init__(self, session, devices, heartbeat_ops, request_placeholder):
    """Construct a new WorkerHeartbeatManager.

    (Prefer using `WorkerHeartbeatManager.from_devices` when possible.)

    Args:
      session: `tf.Session`, session to use for heartbeat operations.
      devices: `list[string]` Set of devices to connect to.
      heartbeat_ops: `list[tf.Operation]` Heartbeat operations.
      request_placeholder: `tf.Placeholder[String]` Placeholder used to specify
        the WorkerHeartbeatRequest protocol buffer.
    """
    self._session = session
    self._devices = devices
    self._ops = heartbeat_ops
    self._request_placeholder = request_placeholder

  @staticmethod
  def from_devices(session, devices):
    """Construct a heartbeat manager for the given devices."""
    if not devices:
      logging.error('Trying to create heartbeat manager with no devices?')

    logging.info('Creating heartbeat manager for %s', devices)
    request_placeholder = array_ops.placeholder(
        name='worker_heartbeat_request', dtype=dtypes.string)

    heartbeat_ops = []
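    # Build one heartbeat op per worker, pinned to that worker's device so
    # each request/response round-trips through the worker it targets.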
    for device in devices:
      with ops.device(device):
        heartbeat_ops.append(tpu_ops.worker_heartbeat(request_placeholder))

    return WorkerHeartbeatManager(session, devices, heartbeat_ops,
                                  request_placeholder)

  def num_workers(self):
    return len(self._devices)

  def configure(self, message):
    """Configure heartbeat manager for all devices.

    Args:
      message: `event_pb2.WorkerHeartbeatRequest`
    Returns: `None`
    """
    logging.info('Configuring worker heartbeat: %s',
                 text_format.MessageToString(message))
    self._session.run(self._ops,
                      {self._request_placeholder: message.SerializeToString()})

  def ping(self, request=None, timeout_in_ms=5000):
    """Ping all workers, returning the parsed status results."""
    if request is None:
      request = event_pb2.WorkerHeartbeatRequest()

    options = config_pb2.RunOptions(timeout_in_ms=timeout_in_ms)
    results = self._session.run(
        self._ops,
        feed_dict={self._request_placeholder: request.SerializeToString()},
        options=options)
    parsed_results = [
        event_pb2.WorkerHeartbeatResponse.FromString(res_pb)
        for res_pb in results
    ]
    logging.debug('Ping results: %s', parsed_results)
    return parsed_results

  def lame_workers(self):
    """Ping all workers, returning a manager of lame workers (or None)."""
    ping_results = self.ping()
    lame_workers = []

    for ping_response, device, op in zip(ping_results, self._devices,
                                         self._ops):
      if ping_response.health_status != event_pb2.OK:
        lame_workers.append((device, op))

    if not lame_workers:
      return None

    bad_devices, bad_ops = zip(*lame_workers)
    return WorkerHeartbeatManager(self._session, bad_devices, bad_ops,
                                  self._request_placeholder)

  def __repr__(self):
    return 'HeartbeatManager(%s)' % ','.join(self._devices)

  def shutdown(self, timeout_ms=10000):
    """Shut down all workers after `timeout_ms` milliseconds."""
    logging.info('Shutting down %s.', self)
    req = event_pb2.WorkerHeartbeatRequest(
        watchdog_config=event_pb2.WatchdogConfig(timeout_ms=timeout_ms),
        shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR)
    self.configure(req)

    # Wait for workers to shut down.  This isn't strictly required
    # but it avoids triggering multiple checkpoints with the same lame worker.
    logging.info('Waiting %dms for worker shutdown.', timeout_ms)
    time.sleep(timeout_ms / 1000)


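# Example usage of WorkerHeartbeatManager (a sketch; assumes `sess` is a
# `tf.Session` connected to a TPU worker cluster):
#
#   manager = WorkerHeartbeatManager.from_devices(
#       sess, all_worker_devices(sess))
#   manager.ping()                    # list of parsed WorkerHeartbeatResponses
#   lame = manager.lame_workers()     # None, or a manager of unhealthy workers
#   if lame is not None:
#     lame.shutdown(timeout_ms=10000)

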
def all_worker_devices(session):
  """Return a list of devices for each worker in the system."""
  devices = session.list_devices()

  devices_that_support_heartbeats = []

  for device in devices:
    name = device.name
    # Pick devices that have a TPU but target the attached CPU
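    # e.g. '/job:worker/replica:0/task:0/device:TPU:0' is mapped to
    # '/job:worker/replica:0/task:0/device:CPU:0'.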
    if ':TPU:0' in name and 'coordinator' not in name:
      devices_that_support_heartbeats.append(name.replace('TPU', 'CPU'))

  return devices_that_support_heartbeats


class WatchdogManager(threading.Thread):
  """Configures the worker watchdog timer and handles periodic pings.

  Usage:
    # Ping workers every minute, shutting down workers if they haven't received
    # a ping after 1 hour.
    watchdog_manager = WatchdogManager(
        session, ping_interval=60, shutdown_timeout=3600)

    # Use as a context manager, resetting the watchdog on context exit:
    with watchdog_manager:
      session.run(...)

    # Or set up globally; the watchdog will remain active until program exit.
    watchdog_manager.configure_and_run()
  """

  def __init__(self,
               session,
               devices=None,
               ping_interval=60,
               shutdown_timeout=3600):
    """Initialize a watchdog manager.

    Args:
      session: Session connected to worker devices.  A cloned session and graph
        will be created for managing worker pings.
      devices: Set of devices to monitor.  If None, all workers will be
        monitored.
      ping_interval: Time, in seconds, between watchdog pings.
      shutdown_timeout: Time, in seconds, before watchdog timeout.
    """
    threading.Thread.__init__(self)
    self.ping_interval = ping_interval
    self.shutdown_timeout = shutdown_timeout
    self.daemon = True
    self._config = session._config  # pylint: disable=protected-access
    self._target = session.sess_str
    self._running = False
    self._devices = devices

    self._graph = None
    self._session = None
    self._worker_manager = None

  def _reset_manager(self):
    """Reset the graph, session and worker manager."""
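    # A fresh graph and session are built here (rather than reusing the
    # caller's) so watchdog pings stay independent of the training graph and
    # can be rebuilt by run() if the previous session breaks.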
    self._graph = ops.Graph()
    self._session = session_lib.Session(
        target=self._target,
        graph=self._graph,
        config=self._config,
    )

    if self._devices is None:
      self._devices = all_worker_devices(self._session)

    with self._graph.as_default():
      self._worker_manager = WorkerHeartbeatManager.from_devices(
          self._session, self._devices)

    self._worker_manager.configure(
        event_pb2.WorkerHeartbeatRequest(
            watchdog_config=event_pb2.WatchdogConfig(
                timeout_ms=self.shutdown_timeout * 1000,),
            shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))

  def configure_and_run(self):
    logging.info(
        'Enabling watchdog timer with %d second timeout '
        'and %d second ping interval.', self.shutdown_timeout,
        self.ping_interval)
    self._reset_manager()
    self._running = True
    self.start()

  def stop(self):
    logging.info('Stopping worker watchdog.')
    self._worker_manager.configure(
        event_pb2.WorkerHeartbeatRequest(
            watchdog_config=event_pb2.WatchdogConfig(timeout_ms=-1,),
            shutdown_mode=event_pb2.NOT_CONFIGURED))
    self._running = False
    self.join()

  def __enter__(self):
    self.configure_and_run()

  def __exit__(self, exc_type, exc_val, exc_tb):
    self.stop()

  def run(self):
    # Don't fetch logs or adjust timing: just ping the watchdog.
    #
    # If we hit an exception, reset our session as it is likely broken.
    while self._running:
      try:
        self._worker_manager.ping(request=None)
        time.sleep(self.ping_interval)
      except errors.OpError as e:
        # Catch any TF errors that occur so we don't stop sending heartbeats
        logging.debug('Caught error while sending heartbeat: %s', e)
        self._reset_manager()


def start_worker_watchdog(session,
                          devices=None,
                          ping_interval=60,
                          shutdown_timeout=3600):
  """Start a global worker watchdog to shut down workers on coordinator exit."""
  global _WATCHDOG
  if _WATCHDOG is None:
    # Ensure we can send a few pings before we time out!
    ping_interval = min(shutdown_timeout / 10., ping_interval)
    _WATCHDOG = WatchdogManager(session, devices, ping_interval,
                                shutdown_timeout)
    _WATCHDOG.configure_and_run()


class GracefulShutdownHook(session_run_hook.SessionRunHook):
  """Session hook that watches for shutdown events.

  If a shutdown is indicated, `saver.save(checkpoint_prefix)` is executed and
  the registered `on_shutdown_hooks` are run; for example, `RestartComputation`
  raises a `CoordinatorShutdownException` to terminate the main session.  If
  `saver` is None the `SAVERS` collection will be read to find a saver.

  `on_shutdown_hooks` is an optional list of functions that should be called
  after checkpointing.  Each function is called with (`run_context`,
  `all_workers`, `lame_workers`).

  Heartbeats are monitored on the CPU device attached to each TPU worker in
  the system.
  """

  def __init__(self, checkpoint_prefix, saver=None, on_shutdown_hooks=None):
    self._saver = saver
    self._checkpoint_prefix = checkpoint_prefix
    self._on_shutdown_hooks = on_shutdown_hooks if on_shutdown_hooks else []

    # Worker heartbeats are managed independently of the main training graph.
    self._graph = ops.Graph()
    self._workers = None
    self._session = None
    self._heartbeat_supported = False

  def after_create_session(self, training_session, coord):  # pylint: disable=unused-argument
    # N.B. We have to pull the global step here to avoid it being unavailable
    # at checkpoint time; the graph has been frozen at that point.
    if training_util.get_global_step() is None and self.saver() is not None:
      raise ValueError(
          'Saver defined but no global step.  Run `get_or_create_global_step()`'
          ' in your model definition to allow checkpointing.')

    with self._graph.as_default():
      logging.info('Installing graceful shutdown hook.')
      self._session = _clone_session(training_session, self._graph)
      self._workers = WorkerHeartbeatManager.from_devices(
          self._session, all_worker_devices(self._session))
      self._heartbeat_supported = self._workers.num_workers() > 0
      if self._heartbeat_supported:
        try:
          self._workers.configure(
              event_pb2.WorkerHeartbeatRequest(
                  shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
        except errors.InvalidArgumentError:
          logging.warn(
              'TPU device does not support heartbeats. Failure '
              'handling will be disabled.')
          self._heartbeat_supported = False
      else:
        logging.warn(
            'No workers support heartbeats. Failure handling will be disabled.')

  def saver(self):
    if self._saver:
      return self._saver

    savers = ops.get_collection(ops.GraphKeys.SAVERS)
    if not savers:
      return None

    if not isinstance(savers, list):
      return savers

    if len(savers) > 1:
      logging.error(
          'Multiple savers in the SAVERS collection.  On-demand checkpointing '
          'will be disabled. Pass an explicit `saver` to the constructor to '
          'override this behavior.')
      return None

    return savers[0]

  def after_run(self, run_context, run_values):
    del run_values

    if not self._heartbeat_supported:
      return

    lame_workers = self._workers.lame_workers()
    if lame_workers:
      logging.info('ShutdownHook: lame workers found: %s', lame_workers)

      if self.saver():
        logging.info('ShutdownHook: saving checkpoint to %s',
                     self._checkpoint_prefix)
        self.saver().save(
            run_context.session,
            self._checkpoint_prefix,
            global_step=training_util.get_global_step(),
            write_state=True,
        )
      else:
        logging.info('ShutdownHook: no Saver defined.')

      for fn in self._on_shutdown_hooks:
        fn(run_context, self._workers, lame_workers)


class RestartComputation(object):
  """Restart the entire computation.

  This hook shuts down all workers and returns control to the top level by
  raising a CoordinatorShutdownException.
  """

  def __init__(self, timeout_ms=10000):
    self.timeout_ms = timeout_ms

  def __call__(self, run_context, all_workers, lame_workers):
    del run_context, lame_workers
    all_workers.shutdown(timeout_ms=self.timeout_ms)

    logging.info('Terminating coordinator.')
    raise CoordinatorShutdownException()


class ShutdownLameWorkers(object):
  """Shut down lame workers.

  Processing will continue normally (typically by waiting for the down
  workers to be restarted).
  """

  def __init__(self, timeout_ms=10000):
    self.timeout_in_ms = timeout_ms

  def __call__(self, run_context, all_workers, lame_workers):
    lame_workers.shutdown(timeout_ms=self.timeout_in_ms)
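

# Example wiring (a sketch, not part of this module's API;
# `tf.train.MonitoredTrainingSession`, `train_op`, and the checkpoint path are
# assumptions made for illustration):
#
#   hook = GracefulShutdownHook(
#       checkpoint_prefix='/tmp/model/ckpt',
#       on_shutdown_hooks=[RestartComputation(timeout_ms=10000)])
#   with tf.train.MonitoredTrainingSession(hooks=[hook]) as sess:
#     while not sess.should_stop():
#       sess.run(train_op)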