# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
"""Operations for handling session logging and shutdown notifications."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading
import time

from google.protobuf import text_format

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util

_WATCHDOG = None


class CoordinatorResetError(errors.AbortedError):
  """Raised when the monitored session should reset."""

  def __init__(self):
    errors.AbortedError.__init__(
        self, None, None, 'Resetting session loop due to worker shutdown.')


def _clone_session(session, graph=None):
  return session_lib.Session(
      target=session.sess_str,
      config=session._config,  # pylint: disable=protected-access
      graph=graph if graph else session.graph)


class WorkerHeartbeatManager(object):
  """Manages the status/heartbeat monitor for a set of workers."""

  def __init__(self, session, devices, heartbeat_ops, request_placeholder):
    """Construct a new WorkerHeartbeatManager.

    (Prefer using `WorkerHeartbeatManager.from_devices` when possible.)

    Args:
      session: `tf.compat.v1.Session`, session to use for heartbeat
        operations.
      devices: `list[string]` Set of devices to connect to.
      heartbeat_ops: `list[tf.Operation]` Heartbeat operations.
      request_placeholder: `tf.Placeholder[String]` Placeholder used to
        specify the WorkerHeartbeatRequest protocol buffer.
    """
    self._session = session
    self._devices = devices
    self._ops = heartbeat_ops
    self._request_placeholder = request_placeholder

  @staticmethod
  def from_devices(session, devices):
    """Construct a heartbeat manager for the given devices."""
    if not devices:
      logging.error('Trying to create heartbeat manager with no devices?')

    logging.info('Creating heartbeat manager for %s', devices)
    request_placeholder = array_ops.placeholder(
        name='worker_heartbeat_request', dtype=dtypes.string)

    heartbeat_ops = []
    for device in devices:
      with ops.device(device):
        heartbeat_ops.append(tpu_ops.worker_heartbeat(request_placeholder))

    return WorkerHeartbeatManager(session, devices, heartbeat_ops,
                                  request_placeholder)

  def num_workers(self):
    return len(self._devices)

  def configure(self, message):
    """Configure the heartbeat manager for all devices.

    Args:
      message: `event_pb2.WorkerHeartbeatRequest`

    Returns:
      `None`
    """
    logging.info('Configuring worker heartbeat: %s',
                 text_format.MessageToString(message))
    self._session.run(self._ops,
                      {self._request_placeholder: message.SerializeToString()})

  def ping(self, request=None, timeout_in_ms=5000):
    """Ping all workers, returning the parsed status results."""
    if request is None:
      request = event_pb2.WorkerHeartbeatRequest()

    options = config_pb2.RunOptions(timeout_in_ms=timeout_in_ms)
    results = self._session.run(
        self._ops,
        feed_dict={self._request_placeholder: request.SerializeToString()},
        options=options)
    parsed_results = [
        event_pb2.WorkerHeartbeatResponse.FromString(res_pb)
        for res_pb in results
    ]
    logging.debug('Ping results: %s', parsed_results)
    return parsed_results

  def lame_workers(self):
    """Ping all workers, returning a manager for the lame ones (or None)."""
    ping_results = self.ping()
    lame_workers = []

    for ping_response, device, op in zip(ping_results, self._devices,
                                         self._ops):
      if ping_response.health_status != event_pb2.OK:
        lame_workers.append((device, op))

    if not lame_workers:
      return None

    bad_devices, bad_ops = zip(*lame_workers)
    return WorkerHeartbeatManager(self._session, bad_devices, bad_ops,
                                  self._request_placeholder)

  def __repr__(self):
    return 'HeartbeatManager(%s)' % ','.join(self._devices)

  # The default timeout leaves time for other shutdown-triggered operations
  # (log flushing, etc.) to finish before the worker terminates.
  def shutdown(self, wait_time_in_ms=60000, exit_code=None):
    """Shutdown all workers after `wait_time_in_ms`."""
    logging.info('Shutting down %s.', self)
    req = event_pb2.WorkerHeartbeatRequest(
        watchdog_config=event_pb2.WatchdogConfig(timeout_ms=wait_time_in_ms),
        shutdown_mode=event_pb2.SHUTDOWN_AFTER_TIMEOUT,
        exit_code=event_pb2.RequestedExitCode(
            exit_code=exit_code) if exit_code is not None else None)
    self.configure(req)

    # Wait for workers to shut down.
    sleep_sec = 10.0 + wait_time_in_ms / 1000
    logging.info('Waiting %.2f seconds for worker shutdown.', sleep_sec)
    time.sleep(sleep_sec)
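
# Illustrative sketch (not part of this module's API) of driving a heartbeat
# manager directly. `sess` is a hypothetical tf.compat.v1.Session connected
# to the TPU workers; everything else is defined in this module.
#
#   manager = WorkerHeartbeatManager.from_devices(
#       sess, all_worker_devices(sess))
#   responses = manager.ping(timeout_in_ms=10000)
#   if any(r.health_status != event_pb2.OK for r in responses):
#     lame = manager.lame_workers()
#     if lame is not None:
#       lame.shutdown(wait_time_in_ms=60000)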


def all_worker_devices(session):
  """Return a list of devices for each worker in the system."""
  devices = session.list_devices()

  devices_that_support_heartbeats = []

  for device in devices:
    name = device.name
    # Pick devices that have a TPU but target the attached CPU.
    if ':TPU:0' in name and 'coordinator' not in name:
      devices_that_support_heartbeats.append(name.replace('TPU', 'CPU'))

  return devices_that_support_heartbeats


class WatchdogManager(threading.Thread):
  """Configures worker watchdog timer and handles periodic pings.

  Usage:
    # Ping workers every minute, shutting down workers if they haven't
    # received a ping after 1 hour.
    watchdog_manager = WatchdogManager(
        session, ping_interval=60, shutdown_timeout=3600)

    # Use as a context manager, resetting the watchdog on context exit:
    with watchdog_manager:
      session.run(...)

    # Or set up globally; the watchdog will remain active until program exit.
    watchdog_manager.configure_and_run()
  """

  def __init__(self,
               session,
               devices=None,
               ping_interval=60,
               shutdown_timeout=3600):
    """Initialize a watchdog manager.

    Args:
      session: Session connected to worker devices. A cloned session and graph
        will be created for managing worker pings.
      devices: Set of devices to monitor. If None, all workers will be
        monitored.
      ping_interval: Time, in seconds, between watchdog pings.
      shutdown_timeout: Time, in seconds, before watchdog timeout.
    """
    threading.Thread.__init__(self)
    self.ping_interval = ping_interval
    self.shutdown_timeout = shutdown_timeout
    self.daemon = True
    self._config = session._config  # pylint: disable=protected-access
    self._target = session.sess_str
    self._running = False
    self._devices = devices

    self._graph = None
    self._session = None
    self._worker_manager = None

  def _reset_manager(self, stopping=False):
    """Reset the graph, session and worker manager."""
    self._graph = ops.Graph()
    self._session = session_lib.Session(
        target=self._target,
        graph=self._graph,
        config=self._config,
    )

    if self._devices is None:
      self._devices = all_worker_devices(self._session)

    with self._graph.as_default():
      self._worker_manager = WorkerHeartbeatManager.from_devices(
          self._session, self._devices)

    if stopping:
      timeout_ms = -1
      shutdown_mode = event_pb2.NOT_CONFIGURED
    else:
      timeout_ms = self.shutdown_timeout * 1000
      shutdown_mode = event_pb2.WAIT_FOR_COORDINATOR

    self._worker_manager.configure(
        event_pb2.WorkerHeartbeatRequest(
            watchdog_config=event_pb2.WatchdogConfig(timeout_ms=timeout_ms),
            shutdown_mode=shutdown_mode))

  def configure_and_run(self):
    logging.info(
        'Enabling watchdog timer with %d second timeout '
        'and %d second ping interval.', self.shutdown_timeout,
        self.ping_interval)
    self._reset_manager()
    self._running = True
    self.start()

  def stop(self):
    logging.info('Stopping worker watchdog.')
    self._reset_manager(stopping=True)
    self._running = False
    self.join()

  def __enter__(self):
    self.configure_and_run()

  def __exit__(self, exc_type, exc_val, exc_tb):
    self.stop()

  def run(self):
    # Don't fetch logs or adjust timing: just ping the watchdog.
    #
    # If we hit an exception, reset our session as it is likely broken.
    while self._running:
      try:
        self._worker_manager.ping(request=None)
        time.sleep(self.ping_interval)
      except errors.OpError as e:
        # Catch any TF errors that occur so we don't stop sending heartbeats.
        logging.debug('Caught error while sending heartbeat: %s', e)
        self._reset_manager()


def start_worker_watchdog(session,
                          devices=None,
                          ping_interval=60,
                          shutdown_timeout=3600):
  """Start a global worker watchdog to shut down workers on coordinator exit."""
  global _WATCHDOG
  if _WATCHDOG is None:
    # Ensure we can send a few pings before we time out!
    ping_interval = min(shutdown_timeout / 10., ping_interval)
    _WATCHDOG = WatchdogManager(session, devices, ping_interval,
                                shutdown_timeout)
    _WATCHDOG.configure_and_run()


def stop_worker_watchdog():
  """Stop the global worker watchdog."""
  global _WATCHDOG
  if _WATCHDOG is not None:
    _WATCHDOG.stop()
    _WATCHDOG = None
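
# Illustrative sketch of the global watchdog wrapped around a training loop.
# `sess`, `num_steps` and `train_op` are hypothetical; workers shut themselves
# down if they go an hour without a ping (e.g. because the coordinator died).
#
#   start_worker_watchdog(sess, ping_interval=60, shutdown_timeout=3600)
#   try:
#     for _ in range(num_steps):
#       sess.run(train_op)
#   finally:
#     stop_worker_watchdog()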


class GracefulShutdownHook(session_run_hook.SessionRunHook):
  """Session hook that watches for shutdown events.

  If a shutdown is indicated, `saver.save(checkpoint_prefix)` is executed, and
  the registered `on_shutdown_hooks` are invoked (e.g. `ResetComputation`,
  which raises a `CoordinatorResetError` to terminate the main session). If
  `saver` is None, the `SAVERS` collection will be read to find a saver.

  `on_shutdown_hooks` is an optional list of functions that should be called
  after checkpointing. Each function is called with `(run_context,
  all_workers, lame_workers)`; see the example sketch following this class.

  Heartbeats are monitored on the CPU workers attached to every TPU device
  found in the system.
  """

  def __init__(self, checkpoint_prefix, saver=None, on_shutdown_hooks=None):
    self._saver = saver
    self._checkpoint_prefix = checkpoint_prefix
    self._on_shutdown_hooks = on_shutdown_hooks if on_shutdown_hooks else []

    # Worker heartbeats are managed independently of the main training graph.
    self._graph = ops.Graph()
    self._workers = None
    self._session = None
    self._heartbeat_supported = False

  def after_create_session(self, training_session, coord):  # pylint: disable=unused-argument
    # N.B. We have to pull the global step here to avoid it being unavailable
    # at checkpoint time; the graph has been frozen at that point.
    if training_util.get_global_step() is None and self.saver() is not None:
      raise ValueError(
          'Saver defined but no global step. Run `get_or_create_global_step()`'
          ' in your model definition to allow checkpointing.')

    with self._graph.as_default():
      logging.info('Installing graceful shutdown hook.')
      self._session = _clone_session(training_session, self._graph)
      self._workers = WorkerHeartbeatManager.from_devices(
          self._session, all_worker_devices(self._session))
      self._heartbeat_supported = self._workers.num_workers() > 0
      if self._heartbeat_supported:
        try:
          self._workers.configure(
              event_pb2.WorkerHeartbeatRequest(
                  shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
        except errors.InvalidArgumentError:
          logging.warn(
              'TPU device does not support heartbeats. Failure '
              'handling will be disabled.')
          self._heartbeat_supported = False
      else:
        logging.warn(
            'No workers support heartbeats. Failure handling will be disabled.')

  def saver(self):
    if self._saver:
      return self._saver

    savers = ops.get_collection(ops.GraphKeys.SAVERS)
    if not savers:
      return None

    if not isinstance(savers, list):
      return savers

    if len(savers) > 1:
      logging.error(
          'Multiple savers in the SAVERS collection. On-demand checkpointing '
          'will be disabled. Pass an explicit `saver` to the constructor to '
          'override this behavior.')
      return None

    return savers[0]

  def after_run(self, run_context, run_values):
    del run_values
    if not self._heartbeat_supported:
      return

    lame_workers = self._workers.lame_workers()

    if lame_workers:
      logging.info('ShutdownHook: lame workers found: %s', lame_workers)

      if self.saver():
        logging.info('ShutdownHook: saving checkpoint to %s',
                     self._checkpoint_prefix)
        self.saver().save(
            run_context.session,
            self._checkpoint_prefix,
            global_step=training_util.get_global_step(),
            write_state=True,
        )
      else:
        logging.info('ShutdownHook: no Saver defined.')

      for fn in self._on_shutdown_hooks:
        fn(run_context, self._workers, lame_workers)
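
# Illustrative sketch of a custom shutdown hook. Any callable accepting
# (run_context, all_workers, lame_workers) may be passed in
# `on_shutdown_hooks`; the callables below (ResetComputation,
# ShutdownLameWorkers, ShutdownAllWorkers) follow the same signature.
#
#   def log_and_reset(run_context, all_workers, lame_workers):
#     logging.info('Lame workers: %s; restarting computation.', lame_workers)
#     all_workers.shutdown()
#     raise CoordinatorResetError()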


class ResetComputation(object):
  """Hook to reset a TPUEstimator computation loop.

  This hook shuts down all workers and resets the monitored session loop by
  throwing a CoordinatorResetError.
  """

  def __init__(self):
    pass

  def __call__(self, run_context, all_workers, lame_workers):
    del run_context, lame_workers
    all_workers.shutdown()

    logging.info('Resetting coordinator.')
    raise CoordinatorResetError()


class ShutdownLameWorkers(object):
  """Shutdown lame workers.

  Processing will continue normally (typically by waiting for the down
  workers to be restarted).
  """

  def __init__(self):
    pass

  def __call__(self, run_context, all_workers, lame_workers):
    lame_workers.shutdown()


class ShutdownAllWorkers(object):
  """Shutdown all workers.

  Processing will continue normally (typically by waiting for the down
  workers to be restarted).
  """

  def __init__(self):
    pass

  def __call__(self, run_context, all_workers, lame_workers):
    all_workers.shutdown()
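
# Illustrative sketch of wiring the shutdown hook into a monitored training
# session. `master`, `train_op`, and the checkpoint path are hypothetical.
# On worker failure the hook checkpoints, then ResetComputation shuts down
# the remaining workers and raises CoordinatorResetError; because that error
# derives from errors.AbortedError, the monitored session machinery is
# expected to recreate the session and resume from the saved checkpoint.
#
#   shutdown_hook = GracefulShutdownHook(
#       checkpoint_prefix='/tmp/model/model.ckpt',
#       on_shutdown_hooks=[ResetComputation()])
#   with tf.compat.v1.train.MonitoredTrainingSession(
#       master=master, hooks=[shutdown_hook]) as sess:
#     while not sess.should_stop():
#       sess.run(train_op)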