1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Training helper that checkpoints models and creates session.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import time 21 22import numpy as np 23 24from tensorflow.python.client import session 25from tensorflow.python.distribute import distribution_strategy_context 26from tensorflow.python.framework import errors 27from tensorflow.python.framework import ops 28from tensorflow.python.platform import tf_logging as logging 29from tensorflow.python.training import checkpoint_management 30from tensorflow.python.util.tf_export import tf_export 31 32 33def _maybe_name(obj): 34 """Returns object name if it has one, or a message otherwise. 35 36 This is useful for names that apper in error messages. 37 Args: 38 obj: Object to get the name of. 39 Returns: 40 name, "None", or a "no name" message. 41 """ 42 if obj is None: 43 return "None" 44 elif hasattr(obj, "name"): 45 return obj.name 46 else: 47 return "<no name for %s>" % type(obj) 48 49 50def _restore_checkpoint_and_maybe_run_saved_model_initializers( 51 sess, saver, path): 52 """Restores checkpoint values and SavedModel initializers if found.""" 53 # NOTE: All references to SavedModel refer to SavedModels loaded from the 54 # load_v2 API (which does not require the `sess` argument). 55 56 # If the graph contains resources loaded from a SavedModel, they are not 57 # restored when calling `saver.restore`. Thus, the SavedModel initializer must 58 # be called with `saver.restore` to properly initialize the model. 59 60 # The SavedModel init is stored in the "saved_model_initializers" collection. 61 # This collection is part of the MetaGraph's default_init_op, so it is already 62 # called by MonitoredSession as long as the saver doesn't restore any 63 # checkpoints from the working dir. 64 saved_model_init_ops = ops.get_collection("saved_model_initializers") 65 if saved_model_init_ops: 66 sess.run(saved_model_init_ops) 67 68 # The saver must be called *after* the SavedModel init, because the SavedModel 69 # init will restore the variables from the SavedModel variables directory. 70 # Initializing/restoring twice is not ideal but there's no other way to do it. 71 saver.restore(sess, path) 72 73 74@tf_export(v1=["train.SessionManager"]) 75class SessionManager(object): 76 """Training helper that restores from checkpoint and creates session. 77 78 This class is a small wrapper that takes care of session creation and 79 checkpoint recovery. It also provides functions that to facilitate 80 coordination among multiple training threads or processes. 81 82 * Checkpointing trained variables as the training progresses. 83 * Initializing variables on startup, restoring them from the most recent 84 checkpoint after a crash, or wait for checkpoints to become available. 85 86 ### Usage: 87 88 ```python 89 with tf.Graph().as_default(): 90 ...add operations to the graph... 91 # Create a SessionManager that will checkpoint the model in '/tmp/mydir'. 92 sm = SessionManager() 93 sess = sm.prepare_session(master, init_op, saver, checkpoint_dir) 94 # Use the session to train the graph. 95 while True: 96 sess.run(<my_train_op>) 97 ``` 98 99 `prepare_session()` initializes or restores a model. It requires `init_op` 100 and `saver` as an argument. 101 102 A second process could wait for the model to be ready by doing the following: 103 104 ```python 105 with tf.Graph().as_default(): 106 ...add operations to the graph... 107 # Create a SessionManager that will wait for the model to become ready. 108 sm = SessionManager() 109 sess = sm.wait_for_session(master) 110 # Use the session to train the graph. 111 while True: 112 sess.run(<my_train_op>) 113 ``` 114 115 `wait_for_session()` waits for a model to be initialized by other processes. 116 117 """ 118 119 def __init__(self, 120 local_init_op=None, 121 ready_op=None, 122 ready_for_local_init_op=None, 123 graph=None, 124 recovery_wait_secs=30, 125 local_init_run_options=None, 126 local_init_feed_dict=None): 127 """Creates a SessionManager. 128 129 The `local_init_op` is an `Operation` that is run always after a new session 130 was created. If `None`, this step is skipped. 131 132 The `ready_op` is an `Operation` used to check if the model is ready. The 133 model is considered ready if that operation returns an empty 1D string 134 tensor. If the operation returns a non empty 1D string tensor, the elements 135 are concatenated and used to indicate to the user why the model is not 136 ready. 137 138 The `ready_for_local_init_op` is an `Operation` used to check if the model 139 is ready to run local_init_op. The model is considered ready if that 140 operation returns an empty 1D string tensor. If the operation returns a non 141 empty 1D string tensor, the elements are concatenated and used to indicate 142 to the user why the model is not ready. 143 144 If `ready_op` is `None`, the model is not checked for readiness. 145 146 `recovery_wait_secs` is the number of seconds between checks that 147 the model is ready. It is used by processes to wait for a model to 148 be initialized or restored. Defaults to 30 seconds. 149 150 Args: 151 local_init_op: An `Operation` run immediately after session creation. 152 Usually used to initialize tables and local variables. 153 ready_op: An `Operation` to check if the model is initialized. 154 ready_for_local_init_op: An `Operation` to check if the model is ready 155 to run local_init_op. 156 graph: The `Graph` that the model will use. 157 recovery_wait_secs: Seconds between checks for the model to be ready. 158 local_init_run_options: RunOptions to be passed to session.run when 159 executing the local_init_op. 160 local_init_feed_dict: Optional session feed dictionary to use when running 161 the local_init_op. 162 163 Raises: 164 ValueError: If ready_for_local_init_op is not None but local_init_op is 165 None 166 """ 167 # Sets default values of arguments. 168 if graph is None: 169 graph = ops.get_default_graph() 170 self._local_init_op = local_init_op 171 self._ready_op = ready_op 172 self._ready_for_local_init_op = ready_for_local_init_op 173 self._graph = graph 174 self._recovery_wait_secs = recovery_wait_secs 175 self._target = None 176 self._local_init_run_options = local_init_run_options 177 self._local_init_feed_dict = local_init_feed_dict 178 if ready_for_local_init_op is not None and local_init_op is None: 179 raise ValueError("If you pass a ready_for_local_init_op " 180 "you must also pass a local_init_op " 181 ", ready_for_local_init_op [%s]" % 182 ready_for_local_init_op) 183 184 def _restore_checkpoint(self, 185 master, 186 saver=None, 187 checkpoint_dir=None, 188 checkpoint_filename_with_path=None, 189 wait_for_checkpoint=False, 190 max_wait_secs=7200, 191 config=None): 192 """Creates a `Session`, and tries to restore a checkpoint. 193 194 195 Args: 196 master: `String` representation of the TensorFlow master to use. 197 saver: A `Saver` object used to restore a model. 198 checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the 199 dir will be used to restore. 200 checkpoint_filename_with_path: Full file name path to the checkpoint file. 201 wait_for_checkpoint: Whether to wait for checkpoint to become available. 202 max_wait_secs: Maximum time to wait for checkpoints to become available. 203 config: Optional `ConfigProto` proto used to configure the session. 204 205 Returns: 206 A pair (sess, is_restored) where 'is_restored' is `True` if 207 the session could be restored, `False` otherwise. 208 209 Raises: 210 ValueError: If both checkpoint_dir and checkpoint_filename_with_path are 211 set. 212 """ 213 self._target = master 214 215 # This is required to so that we initialize the TPU device before 216 # restoring from checkpoint since we'll be placing variables on the device 217 # and TPUInitialize wipes out the memory of the device. 218 strategy = distribution_strategy_context.get_strategy() 219 if strategy and hasattr(strategy.extended, 220 "_experimental_initialize_system"): 221 strategy.extended._experimental_initialize_system() # pylint: disable=protected-access 222 223 sess = session.Session(self._target, graph=self._graph, config=config) 224 if checkpoint_dir and checkpoint_filename_with_path: 225 raise ValueError("Can not provide both checkpoint_dir and " 226 "checkpoint_filename_with_path.") 227 # If either saver or checkpoint_* is not specified, cannot restore. Just 228 # return. 229 if not saver or not (checkpoint_dir or checkpoint_filename_with_path): 230 return sess, False 231 232 if checkpoint_filename_with_path: 233 _restore_checkpoint_and_maybe_run_saved_model_initializers( 234 sess, saver, checkpoint_filename_with_path) 235 return sess, True 236 237 # Waits up until max_wait_secs for checkpoint to become available. 238 wait_time = 0 239 ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir) 240 while not ckpt or not ckpt.model_checkpoint_path: 241 if wait_for_checkpoint and wait_time < max_wait_secs: 242 logging.info("Waiting for checkpoint to be available.") 243 time.sleep(self._recovery_wait_secs) 244 wait_time += self._recovery_wait_secs 245 ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir) 246 else: 247 return sess, False 248 249 # Loads the checkpoint. 250 _restore_checkpoint_and_maybe_run_saved_model_initializers( 251 sess, saver, ckpt.model_checkpoint_path) 252 saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths) 253 return sess, True 254 255 def prepare_session(self, 256 master, 257 init_op=None, 258 saver=None, 259 checkpoint_dir=None, 260 checkpoint_filename_with_path=None, 261 wait_for_checkpoint=False, 262 max_wait_secs=7200, 263 config=None, 264 init_feed_dict=None, 265 init_fn=None): 266 """Creates a `Session`. Makes sure the model is ready to be used. 267 268 Creates a `Session` on 'master'. If a `saver` object is passed in, and 269 `checkpoint_dir` points to a directory containing valid checkpoint 270 files, then it will try to recover the model from checkpoint. If 271 no checkpoint files are available, and `wait_for_checkpoint` is 272 `True`, then the process would check every `recovery_wait_secs`, 273 up to `max_wait_secs`, for recovery to succeed. 274 275 If the model cannot be recovered successfully then it is initialized by 276 running the `init_op` and calling `init_fn` if they are provided. 277 The `local_init_op` is also run after init_op and init_fn, regardless of 278 whether the model was recovered successfully, but only if 279 `ready_for_local_init_op` passes. 280 281 If the model is recovered from a checkpoint it is assumed that all 282 global variables have been initialized, in particular neither `init_op` 283 nor `init_fn` will be executed. 284 285 It is an error if the model cannot be recovered and no `init_op` 286 or `init_fn` or `local_init_op` are passed. 287 288 Args: 289 master: `String` representation of the TensorFlow master to use. 290 init_op: Optional `Operation` used to initialize the model. 291 saver: A `Saver` object used to restore a model. 292 checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the 293 dir will be used to restore. 294 checkpoint_filename_with_path: Full file name path to the checkpoint file. 295 wait_for_checkpoint: Whether to wait for checkpoint to become available. 296 max_wait_secs: Maximum time to wait for checkpoints to become available. 297 config: Optional `ConfigProto` proto used to configure the session. 298 init_feed_dict: Optional dictionary that maps `Tensor` objects to feed 299 values. This feed dictionary is passed to the session `run()` call when 300 running the init op. 301 init_fn: Optional callable used to initialize the model. Called after the 302 optional `init_op` is called. The callable must accept one argument, 303 the session being initialized. 304 305 Returns: 306 A `Session` object that can be used to drive the model. 307 308 Raises: 309 RuntimeError: If the model cannot be initialized or recovered. 310 ValueError: If both checkpoint_dir and checkpoint_filename_with_path are 311 set. 312 """ 313 314 sess, is_loaded_from_checkpoint = self._restore_checkpoint( 315 master, 316 saver, 317 checkpoint_dir=checkpoint_dir, 318 checkpoint_filename_with_path=checkpoint_filename_with_path, 319 wait_for_checkpoint=wait_for_checkpoint, 320 max_wait_secs=max_wait_secs, 321 config=config) 322 if not is_loaded_from_checkpoint: 323 if init_op is None and not init_fn and self._local_init_op is None: 324 raise RuntimeError("Model is not initialized and no init_op or " 325 "init_fn or local_init_op was given") 326 if init_op is not None: 327 sess.run(init_op, feed_dict=init_feed_dict) 328 if init_fn: 329 init_fn(sess) 330 331 local_init_success, msg = self._try_run_local_init_op(sess) 332 if not local_init_success: 333 raise RuntimeError( 334 "Init operations did not make model ready for local_init. " 335 "Init op: %s, init fn: %s, error: %s" % (_maybe_name(init_op), 336 init_fn, 337 msg)) 338 339 is_ready, msg = self._model_ready(sess) 340 if not is_ready: 341 raise RuntimeError( 342 "Init operations did not make model ready. " 343 "Init op: %s, init fn: %s, local_init_op: %s, error: %s" % 344 (_maybe_name(init_op), init_fn, self._local_init_op, msg)) 345 return sess 346 347 def recover_session(self, 348 master, 349 saver=None, 350 checkpoint_dir=None, 351 checkpoint_filename_with_path=None, 352 wait_for_checkpoint=False, 353 max_wait_secs=7200, 354 config=None): 355 """Creates a `Session`, recovering if possible. 356 357 Creates a new session on 'master'. If the session is not initialized 358 and can be recovered from a checkpoint, recover it. 359 360 Args: 361 master: `String` representation of the TensorFlow master to use. 362 saver: A `Saver` object used to restore a model. 363 checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the 364 dir will be used to restore. 365 checkpoint_filename_with_path: Full file name path to the checkpoint file. 366 wait_for_checkpoint: Whether to wait for checkpoint to become available. 367 max_wait_secs: Maximum time to wait for checkpoints to become available. 368 config: Optional `ConfigProto` proto used to configure the session. 369 370 Returns: 371 A pair (sess, initialized) where 'initialized' is `True` if 372 the session could be recovered and initialized, `False` otherwise. 373 374 Raises: 375 ValueError: If both checkpoint_dir and checkpoint_filename_with_path are 376 set. 377 """ 378 379 sess, is_loaded_from_checkpoint = self._restore_checkpoint( 380 master, 381 saver, 382 checkpoint_dir=checkpoint_dir, 383 checkpoint_filename_with_path=checkpoint_filename_with_path, 384 wait_for_checkpoint=wait_for_checkpoint, 385 max_wait_secs=max_wait_secs, 386 config=config) 387 388 # Always try to run local_init_op 389 local_init_success, msg = self._try_run_local_init_op(sess) 390 391 if not is_loaded_from_checkpoint: 392 # Do not need to run checks for readiness 393 return sess, False 394 395 restoring_file = checkpoint_dir or checkpoint_filename_with_path 396 if not local_init_success: 397 logging.info( 398 "Restoring model from %s did not make model ready for local init:" 399 " %s", restoring_file, msg) 400 return sess, False 401 402 is_ready, msg = self._model_ready(sess) 403 if not is_ready: 404 logging.info("Restoring model from %s did not make model ready: %s", 405 restoring_file, msg) 406 return sess, False 407 408 logging.info("Restored model from %s", restoring_file) 409 return sess, is_loaded_from_checkpoint 410 411 def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")): 412 """Creates a new `Session` and waits for model to be ready. 413 414 Creates a new `Session` on 'master'. Waits for the model to be 415 initialized or recovered from a checkpoint. It's expected that 416 another thread or process will make the model ready, and that this 417 is intended to be used by threads/processes that participate in a 418 distributed training configuration where a different thread/process 419 is responsible for initializing or recovering the model being trained. 420 421 NB: The amount of time this method waits for the session is bounded 422 by max_wait_secs. By default, this function will wait indefinitely. 423 424 Args: 425 master: `String` representation of the TensorFlow master to use. 426 config: Optional ConfigProto proto used to configure the session. 427 max_wait_secs: Maximum time to wait for the session to become available. 428 429 Returns: 430 A `Session`. May be None if the operation exceeds the timeout 431 specified by config.operation_timeout_in_ms. 432 433 Raises: 434 tf.DeadlineExceededError: if the session is not available after 435 max_wait_secs. 436 """ 437 self._target = master 438 439 if max_wait_secs is None: 440 max_wait_secs = float("Inf") 441 timer = _CountDownTimer(max_wait_secs) 442 443 while True: 444 sess = session.Session(self._target, graph=self._graph, config=config) 445 not_ready_msg = None 446 not_ready_local_msg = None 447 local_init_success, not_ready_local_msg = self._try_run_local_init_op( 448 sess) 449 if local_init_success: 450 # Successful if local_init_op is None, or ready_for_local_init_op passes 451 is_ready, not_ready_msg = self._model_ready(sess) 452 if is_ready: 453 return sess 454 455 self._safe_close(sess) 456 457 # Do we have enough time left to try again? 458 remaining_ms_after_wait = ( 459 timer.secs_remaining() - self._recovery_wait_secs) 460 if remaining_ms_after_wait < 0: 461 raise errors.DeadlineExceededError( 462 None, None, 463 "Session was not ready after waiting %d secs." % (max_wait_secs,)) 464 465 logging.info("Waiting for model to be ready. " 466 "Ready_for_local_init_op: %s, ready: %s", 467 not_ready_local_msg, not_ready_msg) 468 time.sleep(self._recovery_wait_secs) 469 470 def _safe_close(self, sess): 471 """Closes a session without raising an exception. 472 473 Just like sess.close() but ignores exceptions. 474 475 Args: 476 sess: A `Session`. 477 """ 478 # pylint: disable=broad-except 479 try: 480 sess.close() 481 except Exception: 482 # Intentionally not logging to avoid user complaints that 483 # they get cryptic errors. We really do not care that Close 484 # fails. 485 pass 486 # pylint: enable=broad-except 487 488 def _model_ready(self, sess): 489 """Checks if the model is ready or not. 490 491 Args: 492 sess: A `Session`. 493 494 Returns: 495 A tuple (is_ready, msg), where is_ready is True if ready and False 496 otherwise, and msg is `None` if the model is ready, a `String` with the 497 reason why it is not ready otherwise. 498 """ 499 return _ready(self._ready_op, sess, "Model not ready") 500 501 def _model_ready_for_local_init(self, sess): 502 """Checks if the model is ready to run local_init_op. 503 504 Args: 505 sess: A `Session`. 506 507 Returns: 508 A tuple (is_ready, msg), where is_ready is True if ready to run 509 local_init_op and False otherwise, and msg is `None` if the model is 510 ready to run local_init_op, a `String` with the reason why it is not ready 511 otherwise. 512 """ 513 return _ready(self._ready_for_local_init_op, sess, 514 "Model not ready for local init") 515 516 def _try_run_local_init_op(self, sess): 517 """Tries to run _local_init_op, if not None, and is ready for local init. 518 519 Args: 520 sess: A `Session`. 521 522 Returns: 523 A tuple (is_successful, msg), where is_successful is True if 524 _local_init_op is None, or we ran _local_init_op, and False otherwise; 525 and msg is a `String` with the reason why the model was not ready to run 526 local init. 527 """ 528 if self._local_init_op is not None: 529 is_ready_for_local_init, msg = self._model_ready_for_local_init(sess) 530 if is_ready_for_local_init: 531 logging.info("Running local_init_op.") 532 sess.run(self._local_init_op, feed_dict=self._local_init_feed_dict, 533 options=self._local_init_run_options) 534 logging.info("Done running local_init_op.") 535 return True, None 536 else: 537 return False, msg 538 return True, None 539 540 541def _ready(op, sess, msg): 542 """Checks if the model is ready or not, as determined by op. 543 544 Args: 545 op: An op, either _ready_op or _ready_for_local_init_op, which defines the 546 readiness of the model. 547 sess: A `Session`. 548 msg: A message to log to warning if not ready 549 550 Returns: 551 A tuple (is_ready, msg), where is_ready is True if ready and False 552 otherwise, and msg is `None` if the model is ready, a `String` with the 553 reason why it is not ready otherwise. 554 """ 555 if op is None: 556 return True, None 557 else: 558 try: 559 ready_value = sess.run(op) 560 # The model is considered ready if ready_op returns an empty 1-D tensor. 561 # Also compare to `None` and dtype being int32 for backward 562 # compatibility. 563 if (ready_value is None or ready_value.dtype == np.int32 or 564 ready_value.size == 0): 565 return True, None 566 else: 567 # TODO(sherrym): If a custom ready_op returns other types of tensor, 568 # or strings other than variable names, this message could be 569 # confusing. 570 non_initialized_varnames = ", ".join( 571 [i.decode("utf-8") for i in ready_value]) 572 return False, "Variables not initialized: " + non_initialized_varnames 573 except errors.FailedPreconditionError as e: 574 if "uninitialized" not in str(e): 575 logging.warning("%s : error [%s]", msg, str(e)) 576 raise e 577 return False, str(e) 578 579 580class _CountDownTimer(object): 581 582 __slots__ = ["_start_time_secs", "_duration_secs"] 583 584 def __init__(self, duration_secs): 585 self._start_time_secs = time.time() 586 self._duration_secs = duration_secs 587 588 def secs_remaining(self): 589 diff = self._duration_secs - (time.time() - self._start_time_secs) 590 return max(0, diff) 591