1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Training helper that checkpoints models and creates session.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import time 21 22import numpy as np 23 24from tensorflow.python.client import session 25from tensorflow.python.distribute import distribution_strategy_context 26from tensorflow.python.framework import errors 27from tensorflow.python.framework import ops 28from tensorflow.python.platform import tf_logging as logging 29from tensorflow.python.training import checkpoint_management 30from tensorflow.python.util.tf_export import tf_export 31 32 33def _maybe_name(obj): 34 """Returns object name if it has one, or a message otherwise. 35 36 This is useful for names that apper in error messages. 37 Args: 38 obj: Object to get the name of. 39 Returns: 40 name, "None", or a "no name" message. 41 """ 42 if obj is None: 43 return "None" 44 elif hasattr(obj, "name"): 45 return obj.name 46 else: 47 return "<no name for %s>" % type(obj) 48 49 50@tf_export(v1=["train.SessionManager"]) 51class SessionManager(object): 52 """Training helper that restores from checkpoint and creates session. 53 54 This class is a small wrapper that takes care of session creation and 55 checkpoint recovery. It also provides functions that to facilitate 56 coordination among multiple training threads or processes. 57 58 * Checkpointing trained variables as the training progresses. 59 * Initializing variables on startup, restoring them from the most recent 60 checkpoint after a crash, or wait for checkpoints to become available. 61 62 ### Usage: 63 64 ```python 65 with tf.Graph().as_default(): 66 ...add operations to the graph... 67 # Create a SessionManager that will checkpoint the model in '/tmp/mydir'. 68 sm = SessionManager() 69 sess = sm.prepare_session(master, init_op, saver, checkpoint_dir) 70 # Use the session to train the graph. 71 while True: 72 sess.run(<my_train_op>) 73 ``` 74 75 `prepare_session()` initializes or restores a model. It requires `init_op` 76 and `saver` as an argument. 77 78 A second process could wait for the model to be ready by doing the following: 79 80 ```python 81 with tf.Graph().as_default(): 82 ...add operations to the graph... 83 # Create a SessionManager that will wait for the model to become ready. 84 sm = SessionManager() 85 sess = sm.wait_for_session(master) 86 # Use the session to train the graph. 87 while True: 88 sess.run(<my_train_op>) 89 ``` 90 91 `wait_for_session()` waits for a model to be initialized by other processes. 92 93 """ 94 95 def __init__(self, 96 local_init_op=None, 97 ready_op=None, 98 ready_for_local_init_op=None, 99 graph=None, 100 recovery_wait_secs=30, 101 local_init_run_options=None, 102 local_init_feed_dict=None): 103 """Creates a SessionManager. 104 105 The `local_init_op` is an `Operation` that is run always after a new session 106 was created. If `None`, this step is skipped. 107 108 The `ready_op` is an `Operation` used to check if the model is ready. The 109 model is considered ready if that operation returns an empty 1D string 110 tensor. If the operation returns a non empty 1D string tensor, the elements 111 are concatenated and used to indicate to the user why the model is not 112 ready. 113 114 The `ready_for_local_init_op` is an `Operation` used to check if the model 115 is ready to run local_init_op. The model is considered ready if that 116 operation returns an empty 1D string tensor. If the operation returns a non 117 empty 1D string tensor, the elements are concatenated and used to indicate 118 to the user why the model is not ready. 119 120 If `ready_op` is `None`, the model is not checked for readiness. 121 122 `recovery_wait_secs` is the number of seconds between checks that 123 the model is ready. It is used by processes to wait for a model to 124 be initialized or restored. Defaults to 30 seconds. 125 126 Args: 127 local_init_op: An `Operation` run immediately after session creation. 128 Usually used to initialize tables and local variables. 129 ready_op: An `Operation` to check if the model is initialized. 130 ready_for_local_init_op: An `Operation` to check if the model is ready 131 to run local_init_op. 132 graph: The `Graph` that the model will use. 133 recovery_wait_secs: Seconds between checks for the model to be ready. 134 local_init_run_options: RunOptions to be passed to session.run when 135 executing the local_init_op. 136 local_init_feed_dict: Optional session feed dictionary to use when running 137 the local_init_op. 138 139 Raises: 140 ValueError: If ready_for_local_init_op is not None but local_init_op is 141 None 142 """ 143 # Sets default values of arguments. 144 if graph is None: 145 graph = ops.get_default_graph() 146 self._local_init_op = local_init_op 147 self._ready_op = ready_op 148 self._ready_for_local_init_op = ready_for_local_init_op 149 self._graph = graph 150 self._recovery_wait_secs = recovery_wait_secs 151 self._target = None 152 self._local_init_run_options = local_init_run_options 153 self._local_init_feed_dict = local_init_feed_dict 154 if ready_for_local_init_op is not None and local_init_op is None: 155 raise ValueError("If you pass a ready_for_local_init_op " 156 "you must also pass a local_init_op " 157 ", ready_for_local_init_op [%s]" % 158 ready_for_local_init_op) 159 160 def _restore_checkpoint(self, 161 master, 162 saver=None, 163 checkpoint_dir=None, 164 checkpoint_filename_with_path=None, 165 wait_for_checkpoint=False, 166 max_wait_secs=7200, 167 config=None): 168 """Creates a `Session`, and tries to restore a checkpoint. 169 170 171 Args: 172 master: `String` representation of the TensorFlow master to use. 173 saver: A `Saver` object used to restore a model. 174 checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the 175 dir will be used to restore. 176 checkpoint_filename_with_path: Full file name path to the checkpoint file. 177 wait_for_checkpoint: Whether to wait for checkpoint to become available. 178 max_wait_secs: Maximum time to wait for checkpoints to become available. 179 config: Optional `ConfigProto` proto used to configure the session. 180 181 Returns: 182 A pair (sess, is_restored) where 'is_restored' is `True` if 183 the session could be restored, `False` otherwise. 184 185 Raises: 186 ValueError: If both checkpoint_dir and checkpoint_filename_with_path are 187 set. 188 """ 189 self._target = master 190 191 # This is required to so that we initialize the TPU device before 192 # restoring from checkpoint since we'll be placing variables on the device 193 # and TPUInitialize wipes out the memory of the device. 194 strategy = distribution_strategy_context.get_strategy() 195 if strategy and hasattr(strategy.extended, 196 "_experimental_initialize_system"): 197 strategy.extended._experimental_initialize_system() # pylint: disable=protected-access 198 199 sess = session.Session(self._target, graph=self._graph, config=config) 200 if checkpoint_dir and checkpoint_filename_with_path: 201 raise ValueError("Can not provide both checkpoint_dir and " 202 "checkpoint_filename_with_path.") 203 # If either saver or checkpoint_* is not specified, cannot restore. Just 204 # return. 205 if not saver or not (checkpoint_dir or checkpoint_filename_with_path): 206 return sess, False 207 208 if checkpoint_filename_with_path: 209 saver.restore(sess, checkpoint_filename_with_path) 210 return sess, True 211 212 # Waits up until max_wait_secs for checkpoint to become available. 213 wait_time = 0 214 ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir) 215 while not ckpt or not ckpt.model_checkpoint_path: 216 if wait_for_checkpoint and wait_time < max_wait_secs: 217 logging.info("Waiting for checkpoint to be available.") 218 time.sleep(self._recovery_wait_secs) 219 wait_time += self._recovery_wait_secs 220 ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir) 221 else: 222 return sess, False 223 224 # Loads the checkpoint. 225 saver.restore(sess, ckpt.model_checkpoint_path) 226 saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths) 227 return sess, True 228 229 def prepare_session(self, 230 master, 231 init_op=None, 232 saver=None, 233 checkpoint_dir=None, 234 checkpoint_filename_with_path=None, 235 wait_for_checkpoint=False, 236 max_wait_secs=7200, 237 config=None, 238 init_feed_dict=None, 239 init_fn=None): 240 """Creates a `Session`. Makes sure the model is ready to be used. 241 242 Creates a `Session` on 'master'. If a `saver` object is passed in, and 243 `checkpoint_dir` points to a directory containing valid checkpoint 244 files, then it will try to recover the model from checkpoint. If 245 no checkpoint files are available, and `wait_for_checkpoint` is 246 `True`, then the process would check every `recovery_wait_secs`, 247 up to `max_wait_secs`, for recovery to succeed. 248 249 If the model cannot be recovered successfully then it is initialized by 250 running the `init_op` and calling `init_fn` if they are provided. 251 The `local_init_op` is also run after init_op and init_fn, regardless of 252 whether the model was recovered successfully, but only if 253 `ready_for_local_init_op` passes. 254 255 If the model is recovered from a checkpoint it is assumed that all 256 global variables have been initialized, in particular neither `init_op` 257 nor `init_fn` will be executed. 258 259 It is an error if the model cannot be recovered and no `init_op` 260 or `init_fn` or `local_init_op` are passed. 261 262 Args: 263 master: `String` representation of the TensorFlow master to use. 264 init_op: Optional `Operation` used to initialize the model. 265 saver: A `Saver` object used to restore a model. 266 checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the 267 dir will be used to restore. 268 checkpoint_filename_with_path: Full file name path to the checkpoint file. 269 wait_for_checkpoint: Whether to wait for checkpoint to become available. 270 max_wait_secs: Maximum time to wait for checkpoints to become available. 271 config: Optional `ConfigProto` proto used to configure the session. 272 init_feed_dict: Optional dictionary that maps `Tensor` objects to feed 273 values. This feed dictionary is passed to the session `run()` call when 274 running the init op. 275 init_fn: Optional callable used to initialize the model. Called after the 276 optional `init_op` is called. The callable must accept one argument, 277 the session being initialized. 278 279 Returns: 280 A `Session` object that can be used to drive the model. 281 282 Raises: 283 RuntimeError: If the model cannot be initialized or recovered. 284 ValueError: If both checkpoint_dir and checkpoint_filename_with_path are 285 set. 286 """ 287 288 sess, is_loaded_from_checkpoint = self._restore_checkpoint( 289 master, 290 saver, 291 checkpoint_dir=checkpoint_dir, 292 checkpoint_filename_with_path=checkpoint_filename_with_path, 293 wait_for_checkpoint=wait_for_checkpoint, 294 max_wait_secs=max_wait_secs, 295 config=config) 296 if not is_loaded_from_checkpoint: 297 if init_op is None and not init_fn and self._local_init_op is None: 298 raise RuntimeError("Model is not initialized and no init_op or " 299 "init_fn or local_init_op was given") 300 if init_op is not None: 301 sess.run(init_op, feed_dict=init_feed_dict) 302 if init_fn: 303 init_fn(sess) 304 305 local_init_success, msg = self._try_run_local_init_op(sess) 306 if not local_init_success: 307 raise RuntimeError( 308 "Init operations did not make model ready for local_init. " 309 "Init op: %s, init fn: %s, error: %s" % (_maybe_name(init_op), 310 init_fn, 311 msg)) 312 313 is_ready, msg = self._model_ready(sess) 314 if not is_ready: 315 raise RuntimeError( 316 "Init operations did not make model ready. " 317 "Init op: %s, init fn: %s, local_init_op: %s, error: %s" % 318 (_maybe_name(init_op), init_fn, self._local_init_op, msg)) 319 return sess 320 321 def recover_session(self, 322 master, 323 saver=None, 324 checkpoint_dir=None, 325 checkpoint_filename_with_path=None, 326 wait_for_checkpoint=False, 327 max_wait_secs=7200, 328 config=None): 329 """Creates a `Session`, recovering if possible. 330 331 Creates a new session on 'master'. If the session is not initialized 332 and can be recovered from a checkpoint, recover it. 333 334 Args: 335 master: `String` representation of the TensorFlow master to use. 336 saver: A `Saver` object used to restore a model. 337 checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the 338 dir will be used to restore. 339 checkpoint_filename_with_path: Full file name path to the checkpoint file. 340 wait_for_checkpoint: Whether to wait for checkpoint to become available. 341 max_wait_secs: Maximum time to wait for checkpoints to become available. 342 config: Optional `ConfigProto` proto used to configure the session. 343 344 Returns: 345 A pair (sess, initialized) where 'initialized' is `True` if 346 the session could be recovered and initialized, `False` otherwise. 347 348 Raises: 349 ValueError: If both checkpoint_dir and checkpoint_filename_with_path are 350 set. 351 """ 352 353 sess, is_loaded_from_checkpoint = self._restore_checkpoint( 354 master, 355 saver, 356 checkpoint_dir=checkpoint_dir, 357 checkpoint_filename_with_path=checkpoint_filename_with_path, 358 wait_for_checkpoint=wait_for_checkpoint, 359 max_wait_secs=max_wait_secs, 360 config=config) 361 362 # Always try to run local_init_op 363 local_init_success, msg = self._try_run_local_init_op(sess) 364 365 if not is_loaded_from_checkpoint: 366 # Do not need to run checks for readiness 367 return sess, False 368 369 restoring_file = checkpoint_dir or checkpoint_filename_with_path 370 if not local_init_success: 371 logging.info( 372 "Restoring model from %s did not make model ready for local init:" 373 " %s", restoring_file, msg) 374 return sess, False 375 376 is_ready, msg = self._model_ready(sess) 377 if not is_ready: 378 logging.info("Restoring model from %s did not make model ready: %s", 379 restoring_file, msg) 380 return sess, False 381 382 logging.info("Restored model from %s", restoring_file) 383 return sess, is_loaded_from_checkpoint 384 385 def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")): 386 """Creates a new `Session` and waits for model to be ready. 387 388 Creates a new `Session` on 'master'. Waits for the model to be 389 initialized or recovered from a checkpoint. It's expected that 390 another thread or process will make the model ready, and that this 391 is intended to be used by threads/processes that participate in a 392 distributed training configuration where a different thread/process 393 is responsible for initializing or recovering the model being trained. 394 395 NB: The amount of time this method waits for the session is bounded 396 by max_wait_secs. By default, this function will wait indefinitely. 397 398 Args: 399 master: `String` representation of the TensorFlow master to use. 400 config: Optional ConfigProto proto used to configure the session. 401 max_wait_secs: Maximum time to wait for the session to become available. 402 403 Returns: 404 A `Session`. May be None if the operation exceeds the timeout 405 specified by config.operation_timeout_in_ms. 406 407 Raises: 408 tf.DeadlineExceededError: if the session is not available after 409 max_wait_secs. 410 """ 411 self._target = master 412 413 if max_wait_secs is None: 414 max_wait_secs = float("Inf") 415 timer = _CountDownTimer(max_wait_secs) 416 417 while True: 418 sess = session.Session(self._target, graph=self._graph, config=config) 419 not_ready_msg = None 420 not_ready_local_msg = None 421 local_init_success, not_ready_local_msg = self._try_run_local_init_op( 422 sess) 423 if local_init_success: 424 # Successful if local_init_op is None, or ready_for_local_init_op passes 425 is_ready, not_ready_msg = self._model_ready(sess) 426 if is_ready: 427 return sess 428 429 self._safe_close(sess) 430 431 # Do we have enough time left to try again? 432 remaining_ms_after_wait = ( 433 timer.secs_remaining() - self._recovery_wait_secs) 434 if remaining_ms_after_wait < 0: 435 raise errors.DeadlineExceededError( 436 None, None, 437 "Session was not ready after waiting %d secs." % (max_wait_secs,)) 438 439 logging.info("Waiting for model to be ready. " 440 "Ready_for_local_init_op: %s, ready: %s", 441 not_ready_local_msg, not_ready_msg) 442 time.sleep(self._recovery_wait_secs) 443 444 def _safe_close(self, sess): 445 """Closes a session without raising an exception. 446 447 Just like sess.close() but ignores exceptions. 448 449 Args: 450 sess: A `Session`. 451 """ 452 # pylint: disable=broad-except 453 try: 454 sess.close() 455 except Exception: 456 # Intentionally not logging to avoid user complaints that 457 # they get cryptic errors. We really do not care that Close 458 # fails. 459 pass 460 # pylint: enable=broad-except 461 462 def _model_ready(self, sess): 463 """Checks if the model is ready or not. 464 465 Args: 466 sess: A `Session`. 467 468 Returns: 469 A tuple (is_ready, msg), where is_ready is True if ready and False 470 otherwise, and msg is `None` if the model is ready, a `String` with the 471 reason why it is not ready otherwise. 472 """ 473 return _ready(self._ready_op, sess, "Model not ready") 474 475 def _model_ready_for_local_init(self, sess): 476 """Checks if the model is ready to run local_init_op. 477 478 Args: 479 sess: A `Session`. 480 481 Returns: 482 A tuple (is_ready, msg), where is_ready is True if ready to run 483 local_init_op and False otherwise, and msg is `None` if the model is 484 ready to run local_init_op, a `String` with the reason why it is not ready 485 otherwise. 486 """ 487 return _ready(self._ready_for_local_init_op, sess, 488 "Model not ready for local init") 489 490 def _try_run_local_init_op(self, sess): 491 """Tries to run _local_init_op, if not None, and is ready for local init. 492 493 Args: 494 sess: A `Session`. 495 496 Returns: 497 A tuple (is_successful, msg), where is_successful is True if 498 _local_init_op is None, or we ran _local_init_op, and False otherwise; 499 and msg is a `String` with the reason why the model was not ready to run 500 local init. 501 """ 502 if self._local_init_op is not None: 503 is_ready_for_local_init, msg = self._model_ready_for_local_init(sess) 504 if is_ready_for_local_init: 505 logging.info("Running local_init_op.") 506 sess.run(self._local_init_op, feed_dict=self._local_init_feed_dict, 507 options=self._local_init_run_options) 508 logging.info("Done running local_init_op.") 509 return True, None 510 else: 511 return False, msg 512 return True, None 513 514 515def _ready(op, sess, msg): 516 """Checks if the model is ready or not, as determined by op. 517 518 Args: 519 op: An op, either _ready_op or _ready_for_local_init_op, which defines the 520 readiness of the model. 521 sess: A `Session`. 522 msg: A message to log to warning if not ready 523 524 Returns: 525 A tuple (is_ready, msg), where is_ready is True if ready and False 526 otherwise, and msg is `None` if the model is ready, a `String` with the 527 reason why it is not ready otherwise. 528 """ 529 if op is None: 530 return True, None 531 else: 532 try: 533 ready_value = sess.run(op) 534 # The model is considered ready if ready_op returns an empty 1-D tensor. 535 # Also compare to `None` and dtype being int32 for backward 536 # compatibility. 537 if (ready_value is None or ready_value.dtype == np.int32 or 538 ready_value.size == 0): 539 return True, None 540 else: 541 # TODO(sherrym): If a custom ready_op returns other types of tensor, 542 # or strings other than variable names, this message could be 543 # confusing. 544 non_initialized_varnames = ", ".join( 545 [i.decode("utf-8") for i in ready_value]) 546 return False, "Variables not initialized: " + non_initialized_varnames 547 except errors.FailedPreconditionError as e: 548 if "uninitialized" not in str(e): 549 logging.warning("%s : error [%s]", msg, str(e)) 550 raise e 551 return False, str(e) 552 553 554class _CountDownTimer(object): 555 556 __slots__ = ["_start_time_secs", "_duration_secs"] 557 558 def __init__(self, duration_secs): 559 self._start_time_secs = time.time() 560 self._duration_secs = duration_secs 561 562 def secs_remaining(self): 563 diff = self._duration_secs - (time.time() - self._start_time_secs) 564 return max(0, diff) 565