# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations for handling session logging and shutdown notifications."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading
import time

from google.protobuf import text_format

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
_WATCHDOG = None


class CoordinatorResetError(errors.AbortedError):
  """Raised when the monitored session should reset."""

  def __init__(self):
    errors.AbortedError.__init__(
        self, None, None, 'Resetting session loop due to worker shutdown.')


def _clone_session(session, graph=None):
  return session_lib.Session(
      target=session.sess_str,
      config=session._config,  # pylint: disable=protected-access
      graph=graph if graph else session.graph)


class WorkerHeartbeatManager(object):
  """Manages the status/heartbeat monitor for a set of workers."""

  def __init__(self, session, devices, heartbeat_ops, request_placeholder):
    """Construct a new WorkerHeartbeatManager.

    (Prefer using `WorkerHeartbeatManager.from_devices` when possible.)

    Args:
      session: `tf.compat.v1.Session`, session to use for heartbeat operations.
      devices: `list[string]` Set of devices to connect to.
      heartbeat_ops: `list[tf.Operation]` Heartbeat operations.
      request_placeholder: string placeholder used to feed the serialized
        `WorkerHeartbeatRequest` protocol buffer.
    """
    self._session = session
    self._devices = devices
    self._ops = heartbeat_ops
    self._request_placeholder = request_placeholder

  @staticmethod
  def from_devices(session, devices):
    """Construct a heartbeat manager for the given devices."""
    if not devices:
      logging.error('Trying to create heartbeat manager with no devices?')

    logging.info('Creating heartbeat manager for %s', devices)
    request_placeholder = array_ops.placeholder(
        name='worker_heartbeat_request', dtype=dtypes.string)

    heartbeat_ops = []
    for device in devices:
      with ops.device(device):
        heartbeat_ops.append(tpu_ops.worker_heartbeat(request_placeholder))

    return WorkerHeartbeatManager(session, devices, heartbeat_ops,
                                  request_placeholder)

  def num_workers(self):
    return len(self._devices)

  def configure(self, message):
    """Configure heartbeat manager for all devices.

    Args:
      message: `event_pb2.WorkerHeartbeatRequest`

    Returns:
      `None`
    """
    logging.info('Configuring worker heartbeat: %s',
                 text_format.MessageToString(message))
    self._session.run(self._ops,
                      {self._request_placeholder: message.SerializeToString()})

  def ping(self, request=None, timeout_in_ms=5000):
    """Ping all workers, returning the parsed status results."""
    if request is None:
      request = event_pb2.WorkerHeartbeatRequest()

    options = config_pb2.RunOptions(timeout_in_ms=timeout_in_ms)
    results = self._session.run(
        self._ops,
        feed_dict={self._request_placeholder: request.SerializeToString()},
        options=options)
    parsed_results = [
        event_pb2.WorkerHeartbeatResponse.FromString(res_pb)
        for res_pb in results
    ]
    logging.debug('Ping results: %s', parsed_results)
    return parsed_results

  def lame_workers(self):
    """Ping all workers; return a manager of any lame ones, or None."""
    ping_results = self.ping()
    lame_workers = []

    for ping_response, device, op in zip(ping_results, self._devices,
                                         self._ops):
      if ping_response.health_status != event_pb2.OK:
        lame_workers.append((device, op))

    if not lame_workers:
      return None

    bad_devices, bad_ops = zip(*lame_workers)
    return WorkerHeartbeatManager(self._session, bad_devices, bad_ops,
                                  self._request_placeholder)

  def __repr__(self):
    return 'HeartbeatManager(%s)' % ','.join(self._devices)

  # Default timeout is set to allow other shutdown-triggered operations (log
  # flushing, etc.) to finish before terminating the worker.
  def shutdown(self, wait_time_in_ms=60000, exit_code=None):
    """Shut down all workers after waiting `wait_time_in_ms` milliseconds."""
    logging.info('Shutting down %s.', self)
    req = event_pb2.WorkerHeartbeatRequest(
        watchdog_config=event_pb2.WatchdogConfig(timeout_ms=wait_time_in_ms),
        shutdown_mode=event_pb2.SHUTDOWN_AFTER_TIMEOUT,
        exit_code=event_pb2.RequestedExitCode(
            exit_code=exit_code) if exit_code is not None else None)
    self.configure(req)

    # Wait for workers to shut down.
    sleep_sec = 10.0 + wait_time_in_ms / 1000
    logging.info('Waiting %.2f seconds for worker shutdown.', sleep_sec)
    time.sleep(sleep_sec)

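# Example usage (an illustrative sketch, not part of this module): ping a set
# of workers and shut down any that report an unhealthy status.  `sess` and
# the device name below are hypothetical; real heartbeat-capable device names
# can be discovered with `all_worker_devices`, defined next.
#
#   manager = WorkerHeartbeatManager.from_devices(
#       sess, ['/job:worker/replica:0/task:0/device:CPU:0'])
#   responses = manager.ping(timeout_in_ms=10000)
#   lame = manager.lame_workers()
#   if lame is not None:
#     lame.shutdown(wait_time_in_ms=30000)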

def all_worker_devices(session):
  """Return a list of devices for each worker in the system."""
  devices = session.list_devices()

  devices_that_support_heartbeats = []

  for device in devices:
    name = device.name
    # Pick devices that have a TPU but target the attached CPU.
    if ':TPU:0' in name and 'coordinator' not in name:
      devices_that_support_heartbeats.append(name.replace('TPU', 'CPU'))

  return devices_that_support_heartbeats

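# Example usage (sketch): discover heartbeat-capable devices and arm their
# watchdogs.  Assumes `sess` is a `tf.compat.v1.Session` connected to a TPU
# cluster; each worker's heartbeat op runs on the CPU device attached to its
# TPU.
#
#   devices = all_worker_devices(sess)
#   manager = WorkerHeartbeatManager.from_devices(sess, devices)
#   manager.configure(
#       event_pb2.WorkerHeartbeatRequest(
#           watchdog_config=event_pb2.WatchdogConfig(timeout_ms=600000),
#           shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))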

class WatchdogManager(threading.Thread):
  """Configures worker watchdog timer and handles periodic pings.

  Usage:
    # Ping workers every minute, shutting down workers if they haven't received
    # a ping after 1 hour.
    watchdog_manager = WatchdogManager(
        session, ping_interval=60, shutdown_timeout=3600
    )

    # Use as a context manager, resetting the watchdog on context exit:
    with watchdog_manager:
      session.run(...)

    # Or set up globally; the watchdog will remain active until program exit.
    watchdog_manager.configure_and_run()
  """

  def __init__(self,
               session,
               devices=None,
               ping_interval=60,
               shutdown_timeout=3600):
    """Initialize a watchdog manager.

    Args:
      session: Session connected to worker devices.  A cloned session and graph
        will be created for managing worker pings.
      devices: Set of devices to monitor.  If None, all workers will be
        monitored.
      ping_interval: Time, in seconds, between watchdog pings.
      shutdown_timeout: Time, in seconds, before watchdog timeout.
    """
    threading.Thread.__init__(self)
    self.ping_interval = ping_interval
    self.shutdown_timeout = shutdown_timeout
    self.daemon = True
    self._config = session._config  # pylint: disable=protected-access
    self._target = session.sess_str
    self._running = False
    self._devices = devices

    self._graph = None
    self._session = None
    self._worker_manager = None

  def _reset_manager(self, stopping=False):
    """Reset the graph, session and worker manager."""
    self._graph = ops.Graph()
    self._session = session_lib.Session(
        target=self._target,
        graph=self._graph,
        config=self._config,
    )

    if self._devices is None:
      self._devices = all_worker_devices(self._session)

    with self._graph.as_default():
      self._worker_manager = WorkerHeartbeatManager.from_devices(
          self._session, self._devices)

    if stopping:
      timeout_ms = -1
      shutdown_mode = event_pb2.NOT_CONFIGURED
    else:
      timeout_ms = self.shutdown_timeout * 1000
      shutdown_mode = event_pb2.WAIT_FOR_COORDINATOR

    self._worker_manager.configure(
        event_pb2.WorkerHeartbeatRequest(
            watchdog_config=event_pb2.WatchdogConfig(timeout_ms=timeout_ms),
            shutdown_mode=shutdown_mode))

  def configure_and_run(self):
    logging.info(
        'Enabling watchdog timer with %d second timeout '
        'and %d second ping interval.', self.shutdown_timeout,
        self.ping_interval)
    self._reset_manager()
    self._running = True
    self.start()

  def stop(self):
    logging.info('Stopping worker watchdog.')
    self._reset_manager(stopping=True)
    self._running = False
    self.join()

  def __enter__(self):
    self.configure_and_run()

  def __exit__(self, exc_type, exc_val, exc_tb):
    self.stop()

  def run(self):
    # Don't fetch logs or adjust timing: just ping the watchdog.
    #
    # If we hit an exception, reset our session as it is likely broken.
    while self._running:
      try:
        self._worker_manager.ping(request=None)
        time.sleep(self.ping_interval)
      except errors.OpError as e:
        # Catch any TF errors that occur so we don't stop sending heartbeats.
        logging.debug('Caught error while sending heartbeat: %s', e)
        self._reset_manager()

def start_worker_watchdog(session,
                          devices=None,
                          ping_interval=60,
                          shutdown_timeout=3600):
  """Start global worker watchdog to shut down workers on coordinator exit."""
  global _WATCHDOG
  if _WATCHDOG is None:
    # Ensure we can send a few pings before we time out!
    ping_interval = min(shutdown_timeout / 10., ping_interval)
    _WATCHDOG = WatchdogManager(session, devices, ping_interval,
                                shutdown_timeout)
    _WATCHDOG.configure_and_run()

def stop_worker_watchdog():
  """Stop global worker watchdog."""
  global _WATCHDOG
  if _WATCHDOG is not None:
    _WATCHDOG.stop()
    _WATCHDOG = None

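# Example usage (sketch): the module-level watchdog lifecycle.  Assumes `sess`
# is a session connected to the worker devices.  The watchdog thread pings the
# workers until `stop_worker_watchdog()` disarms it; if the coordinator dies
# without stopping it, workers shut themselves down after `shutdown_timeout`
# seconds without a ping.
#
#   start_worker_watchdog(sess, ping_interval=60, shutdown_timeout=3600)
#   try:
#     ...  # run training steps
#   finally:
#     stop_worker_watchdog()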

class GracefulShutdownHook(session_run_hook.SessionRunHook):
  """Session hook that watches for shutdown events.

  If a shutdown is indicated, `saver.save(checkpoint_prefix)` is executed and
  the registered `on_shutdown_hooks` run; hooks such as `ResetComputation`
  raise a `CoordinatorResetError` to terminate the main session.  If `saver`
  is None, the `SAVERS` collection will be read to find a saver.

  `on_shutdown_hooks` is an optional list of functions that should be called
  after checkpointing.  Each function is called with (`run_context`,
  `all_workers`, `lame_workers`).
  """

  def __init__(self, checkpoint_prefix, saver=None, on_shutdown_hooks=None):
    self._saver = saver
    self._checkpoint_prefix = checkpoint_prefix
    self._on_shutdown_hooks = on_shutdown_hooks if on_shutdown_hooks else []

    # Worker heartbeats are managed independently of the main training graph.
    self._graph = ops.Graph()
    self._workers = None
    self._session = None
    self._heartbeat_supported = False

  def after_create_session(self, training_session, coord):  # pylint: disable=unused-argument
    # N.B. We have to pull the global step here to avoid it being unavailable
    # at checkpoint time; the graph has been frozen at that point.
    if training_util.get_global_step() is None and self.saver() is not None:
      raise ValueError(
          'Saver defined but no global step.  Run `get_or_create_global_step()`'
          ' in your model definition to allow checkpointing.')

    with self._graph.as_default():
      logging.info('Installing graceful shutdown hook.')
      self._session = _clone_session(training_session, self._graph)
      self._workers = WorkerHeartbeatManager.from_devices(
          self._session, all_worker_devices(self._session))
      self._heartbeat_supported = self._workers.num_workers() > 0
      if self._heartbeat_supported:
        try:
          self._workers.configure(
              event_pb2.WorkerHeartbeatRequest(
                  shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
        except errors.InvalidArgumentError:
          logging.warn(
              'TPU device does not support heartbeats. Failure '
              'handling will be disabled.')
          self._heartbeat_supported = False
      else:
        logging.warn(
            'No workers support heartbeats. Failure handling will be disabled.')

  def saver(self):
    if self._saver:
      return self._saver

    savers = ops.get_collection(ops.GraphKeys.SAVERS)
    if not savers:
      return None

    if not isinstance(savers, list):
      return savers

    if len(savers) > 1:
      logging.error(
          'Multiple savers in the SAVERS collection.  On-demand checkpointing '
          'will be disabled. Pass an explicit `saver` to the constructor to '
          'override this behavior.')
      return None

    return savers[0]

  def after_run(self, run_context, run_values):
    del run_values
    if not self._heartbeat_supported:
      return

    lame_workers = self._workers.lame_workers()

    if lame_workers:
      logging.info('ShutdownHook: lame workers found: %s', lame_workers)

      if self.saver():
        logging.info('ShutdownHook: saving checkpoint to %s',
                     self._checkpoint_prefix)
        self.saver().save(
            run_context.session,
            self._checkpoint_prefix,
            global_step=training_util.get_global_step(),
            write_state=True,
        )
      else:
        logging.info('ShutdownHook: no Saver defined.')

      for fn in self._on_shutdown_hooks:
        fn(run_context, self._workers, lame_workers)

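# Example usage (sketch): install the hook in a MonitoredTrainingSession.  The
# checkpoint prefix, `master`, `train_op`, and the `log_failures` hook below
# are hypothetical placeholders.
#
#   def log_failures(run_context, all_workers, lame_workers):
#     logging.info('Lame workers detected: %s', lame_workers)
#
#   shutdown_hook = GracefulShutdownHook(
#       checkpoint_prefix='/tmp/model/model.ckpt',
#       on_shutdown_hooks=[log_failures])
#   with tf.compat.v1.train.MonitoredTrainingSession(
#       master=master, hooks=[shutdown_hook]) as sess:
#     while not sess.should_stop():
#       sess.run(train_op)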

class ResetComputation(object):
  """Hook to reset a TPUEstimator computation loop.

  This hook shuts down all workers and resets the monitored session loop by
  throwing a CoordinatorResetError.
  """

  def __init__(self):
    pass

  def __call__(self, run_context, all_workers, lame_workers):
    del run_context, lame_workers
    all_workers.shutdown()

    logging.info('Resetting coordinator.')
    raise CoordinatorResetError()

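# Example usage (sketch): pair ResetComputation with GracefulShutdownHook so
# that a worker failure checkpoints, restarts the workers, and resets the
# training loop.  `run_training` is a hypothetical function that builds the
# monitored session and runs steps until completion.
#
#   hook = GracefulShutdownHook(
#       checkpoint_prefix='/tmp/model/model.ckpt',
#       on_shutdown_hooks=[ResetComputation()])
#   while True:
#     try:
#       run_training(hooks=[hook])
#       break
#     except CoordinatorResetError:
#       logging.info('Worker shutdown detected; restarting from checkpoint.')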

class ShutdownLameWorkers(object):
  """Shut down lame workers.

  Processing will continue normally (typically by waiting for the down
  workers to be restarted).
  """

  def __init__(self):
    pass

  def __call__(self, run_context, all_workers, lame_workers):
    lame_workers.shutdown()


class ShutdownAllWorkers(object):
  """Shut down all workers.

  Processing will continue normally (typically by waiting for the down
  workers to be restarted).
  """

  def __init__(self):
    pass

  def __call__(self, run_context, all_workers, lame_workers):
    all_workers.shutdown()

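# Example usage (sketch): the three shutdown policies above differ in scope.
# ShutdownLameWorkers restarts only the unhealthy workers, ShutdownAllWorkers
# restarts every worker, and ResetComputation additionally resets the
# coordinator loop.  A hypothetical wiring:
#
#   hook = GracefulShutdownHook(
#       checkpoint_prefix='/tmp/model/model.ckpt',
#       on_shutdown_hooks=[ShutdownLameWorkers()])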