1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Training helper that checkpoints models and creates session.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import time 21import numpy as np 22 23from tensorflow.python.client import session 24from tensorflow.python.distribute import distribution_strategy_context 25from tensorflow.python.framework import errors 26from tensorflow.python.framework import ops 27from tensorflow.python.platform import tf_logging as logging 28from tensorflow.python.training import checkpoint_management 29from tensorflow.python.util.tf_export import tf_export 30 31 32def _maybe_name(obj): 33 """Returns object name if it has one, or a message otherwise. 34 35 This is useful for names that apper in error messages. 36 Args: 37 obj: Object to get the name of. 38 Returns: 39 name, "None", or a "no name" message. 40 """ 41 if obj is None: 42 return "None" 43 elif hasattr(obj, "name"): 44 return obj.name 45 else: 46 return "<no name for %s>" % type(obj) 47 48 49@tf_export(v1=["train.SessionManager"]) 50class SessionManager(object): 51 """Training helper that restores from checkpoint and creates session. 52 53 This class is a small wrapper that takes care of session creation and 54 checkpoint recovery. It also provides functions that to facilitate 55 coordination among multiple training threads or processes. 56 57 * Checkpointing trained variables as the training progresses. 58 * Initializing variables on startup, restoring them from the most recent 59 checkpoint after a crash, or wait for checkpoints to become available. 60 61 ### Usage: 62 63 ```python 64 with tf.Graph().as_default(): 65 ...add operations to the graph... 66 # Create a SessionManager that will checkpoint the model in '/tmp/mydir'. 67 sm = SessionManager() 68 sess = sm.prepare_session(master, init_op, saver, checkpoint_dir) 69 # Use the session to train the graph. 70 while True: 71 sess.run(<my_train_op>) 72 ``` 73 74 `prepare_session()` initializes or restores a model. It requires `init_op` 75 and `saver` as an argument. 76 77 A second process could wait for the model to be ready by doing the following: 78 79 ```python 80 with tf.Graph().as_default(): 81 ...add operations to the graph... 82 # Create a SessionManager that will wait for the model to become ready. 83 sm = SessionManager() 84 sess = sm.wait_for_session(master) 85 # Use the session to train the graph. 86 while True: 87 sess.run(<my_train_op>) 88 ``` 89 90 `wait_for_session()` waits for a model to be initialized by other processes. 91 92 """ 93 94 def __init__(self, 95 local_init_op=None, 96 ready_op=None, 97 ready_for_local_init_op=None, 98 graph=None, 99 recovery_wait_secs=30, 100 local_init_run_options=None): 101 """Creates a SessionManager. 102 103 The `local_init_op` is an `Operation` that is run always after a new session 104 was created. If `None`, this step is skipped. 105 106 The `ready_op` is an `Operation` used to check if the model is ready. The 107 model is considered ready if that operation returns an empty 1D string 108 tensor. If the operation returns a non empty 1D string tensor, the elements 109 are concatenated and used to indicate to the user why the model is not 110 ready. 111 112 The `ready_for_local_init_op` is an `Operation` used to check if the model 113 is ready to run local_init_op. The model is considered ready if that 114 operation returns an empty 1D string tensor. If the operation returns a non 115 empty 1D string tensor, the elements are concatenated and used to indicate 116 to the user why the model is not ready. 117 118 If `ready_op` is `None`, the model is not checked for readiness. 119 120 `recovery_wait_secs` is the number of seconds between checks that 121 the model is ready. It is used by processes to wait for a model to 122 be initialized or restored. Defaults to 30 seconds. 123 124 Args: 125 local_init_op: An `Operation` run immediately after session creation. 126 Usually used to initialize tables and local variables. 127 ready_op: An `Operation` to check if the model is initialized. 128 ready_for_local_init_op: An `Operation` to check if the model is ready 129 to run local_init_op. 130 graph: The `Graph` that the model will use. 131 recovery_wait_secs: Seconds between checks for the model to be ready. 132 local_init_run_options: RunOptions to be passed to session.run when 133 executing the local_init_op. 134 135 Raises: 136 ValueError: If ready_for_local_init_op is not None but local_init_op is 137 None 138 """ 139 # Sets default values of arguments. 140 if graph is None: 141 graph = ops.get_default_graph() 142 self._local_init_op = local_init_op 143 self._ready_op = ready_op 144 self._ready_for_local_init_op = ready_for_local_init_op 145 self._graph = graph 146 self._recovery_wait_secs = recovery_wait_secs 147 self._target = None 148 self._local_init_run_options = local_init_run_options 149 if ready_for_local_init_op is not None and local_init_op is None: 150 raise ValueError("If you pass a ready_for_local_init_op " 151 "you must also pass a local_init_op " 152 ", ready_for_local_init_op [%s]" % 153 ready_for_local_init_op) 154 155 def _restore_checkpoint(self, 156 master, 157 saver=None, 158 checkpoint_dir=None, 159 checkpoint_filename_with_path=None, 160 wait_for_checkpoint=False, 161 max_wait_secs=7200, 162 config=None): 163 """Creates a `Session`, and tries to restore a checkpoint. 164 165 166 Args: 167 master: `String` representation of the TensorFlow master to use. 168 saver: A `Saver` object used to restore a model. 169 checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the 170 dir will be used to restore. 171 checkpoint_filename_with_path: Full file name path to the checkpoint file. 172 wait_for_checkpoint: Whether to wait for checkpoint to become available. 173 max_wait_secs: Maximum time to wait for checkpoints to become available. 174 config: Optional `ConfigProto` proto used to configure the session. 175 176 Returns: 177 A pair (sess, is_restored) where 'is_restored' is `True` if 178 the session could be restored, `False` otherwise. 179 180 Raises: 181 ValueError: If both checkpoint_dir and checkpoint_filename_with_path are 182 set. 183 """ 184 self._target = master 185 186 # This is required to so that we initialize the TPU device before 187 # restoring from checkpoint since we'll be placing variables on the device 188 # and TPUInitialize wipes out the memory of the device. 189 strategy = distribution_strategy_context.get_strategy() 190 if strategy and hasattr(strategy.extended, 191 "_experimental_initialize_system"): 192 strategy.extended._experimental_initialize_system() # pylint: disable=protected-access 193 194 sess = session.Session(self._target, graph=self._graph, config=config) 195 if checkpoint_dir and checkpoint_filename_with_path: 196 raise ValueError("Can not provide both checkpoint_dir and " 197 "checkpoint_filename_with_path.") 198 # If either saver or checkpoint_* is not specified, cannot restore. Just 199 # return. 200 if not saver or not (checkpoint_dir or checkpoint_filename_with_path): 201 return sess, False 202 203 if checkpoint_filename_with_path: 204 saver.restore(sess, checkpoint_filename_with_path) 205 return sess, True 206 207 # Waits up until max_wait_secs for checkpoint to become available. 208 wait_time = 0 209 ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir) 210 while not ckpt or not ckpt.model_checkpoint_path: 211 if wait_for_checkpoint and wait_time < max_wait_secs: 212 logging.info("Waiting for checkpoint to be available.") 213 time.sleep(self._recovery_wait_secs) 214 wait_time += self._recovery_wait_secs 215 ckpt = checkpoint_management.get_checkpoint_state(checkpoint_dir) 216 else: 217 return sess, False 218 219 # Loads the checkpoint. 220 saver.restore(sess, ckpt.model_checkpoint_path) 221 saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths) 222 return sess, True 223 224 def prepare_session(self, 225 master, 226 init_op=None, 227 saver=None, 228 checkpoint_dir=None, 229 checkpoint_filename_with_path=None, 230 wait_for_checkpoint=False, 231 max_wait_secs=7200, 232 config=None, 233 init_feed_dict=None, 234 init_fn=None): 235 """Creates a `Session`. Makes sure the model is ready to be used. 236 237 Creates a `Session` on 'master'. If a `saver` object is passed in, and 238 `checkpoint_dir` points to a directory containing valid checkpoint 239 files, then it will try to recover the model from checkpoint. If 240 no checkpoint files are available, and `wait_for_checkpoint` is 241 `True`, then the process would check every `recovery_wait_secs`, 242 up to `max_wait_secs`, for recovery to succeed. 243 244 If the model cannot be recovered successfully then it is initialized by 245 running the `init_op` and calling `init_fn` if they are provided. 246 The `local_init_op` is also run after init_op and init_fn, regardless of 247 whether the model was recovered successfully, but only if 248 `ready_for_local_init_op` passes. 249 250 If the model is recovered from a checkpoint it is assumed that all 251 global variables have been initialized, in particular neither `init_op` 252 nor `init_fn` will be executed. 253 254 It is an error if the model cannot be recovered and no `init_op` 255 or `init_fn` or `local_init_op` are passed. 256 257 Args: 258 master: `String` representation of the TensorFlow master to use. 259 init_op: Optional `Operation` used to initialize the model. 260 saver: A `Saver` object used to restore a model. 261 checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the 262 dir will be used to restore. 263 checkpoint_filename_with_path: Full file name path to the checkpoint file. 264 wait_for_checkpoint: Whether to wait for checkpoint to become available. 265 max_wait_secs: Maximum time to wait for checkpoints to become available. 266 config: Optional `ConfigProto` proto used to configure the session. 267 init_feed_dict: Optional dictionary that maps `Tensor` objects to feed 268 values. This feed dictionary is passed to the session `run()` call when 269 running the init op. 270 init_fn: Optional callable used to initialize the model. Called after the 271 optional `init_op` is called. The callable must accept one argument, 272 the session being initialized. 273 274 Returns: 275 A `Session` object that can be used to drive the model. 276 277 Raises: 278 RuntimeError: If the model cannot be initialized or recovered. 279 ValueError: If both checkpoint_dir and checkpoint_filename_with_path are 280 set. 281 """ 282 283 sess, is_loaded_from_checkpoint = self._restore_checkpoint( 284 master, 285 saver, 286 checkpoint_dir=checkpoint_dir, 287 checkpoint_filename_with_path=checkpoint_filename_with_path, 288 wait_for_checkpoint=wait_for_checkpoint, 289 max_wait_secs=max_wait_secs, 290 config=config) 291 if not is_loaded_from_checkpoint: 292 if init_op is None and not init_fn and self._local_init_op is None: 293 raise RuntimeError("Model is not initialized and no init_op or " 294 "init_fn or local_init_op was given") 295 if init_op is not None: 296 sess.run(init_op, feed_dict=init_feed_dict) 297 if init_fn: 298 init_fn(sess) 299 300 local_init_success, msg = self._try_run_local_init_op(sess) 301 if not local_init_success: 302 raise RuntimeError( 303 "Init operations did not make model ready for local_init. " 304 "Init op: %s, init fn: %s, error: %s" % (_maybe_name(init_op), 305 init_fn, 306 msg)) 307 308 is_ready, msg = self._model_ready(sess) 309 if not is_ready: 310 raise RuntimeError( 311 "Init operations did not make model ready. " 312 "Init op: %s, init fn: %s, local_init_op: %s, error: %s" % 313 (_maybe_name(init_op), init_fn, self._local_init_op, msg)) 314 return sess 315 316 def recover_session(self, 317 master, 318 saver=None, 319 checkpoint_dir=None, 320 checkpoint_filename_with_path=None, 321 wait_for_checkpoint=False, 322 max_wait_secs=7200, 323 config=None): 324 """Creates a `Session`, recovering if possible. 325 326 Creates a new session on 'master'. If the session is not initialized 327 and can be recovered from a checkpoint, recover it. 328 329 Args: 330 master: `String` representation of the TensorFlow master to use. 331 saver: A `Saver` object used to restore a model. 332 checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the 333 dir will be used to restore. 334 checkpoint_filename_with_path: Full file name path to the checkpoint file. 335 wait_for_checkpoint: Whether to wait for checkpoint to become available. 336 max_wait_secs: Maximum time to wait for checkpoints to become available. 337 config: Optional `ConfigProto` proto used to configure the session. 338 339 Returns: 340 A pair (sess, initialized) where 'initialized' is `True` if 341 the session could be recovered and initialized, `False` otherwise. 342 343 Raises: 344 ValueError: If both checkpoint_dir and checkpoint_filename_with_path are 345 set. 346 """ 347 348 sess, is_loaded_from_checkpoint = self._restore_checkpoint( 349 master, 350 saver, 351 checkpoint_dir=checkpoint_dir, 352 checkpoint_filename_with_path=checkpoint_filename_with_path, 353 wait_for_checkpoint=wait_for_checkpoint, 354 max_wait_secs=max_wait_secs, 355 config=config) 356 357 # Always try to run local_init_op 358 local_init_success, msg = self._try_run_local_init_op(sess) 359 360 if not is_loaded_from_checkpoint: 361 # Do not need to run checks for readiness 362 return sess, False 363 364 restoring_file = checkpoint_dir or checkpoint_filename_with_path 365 if not local_init_success: 366 logging.info( 367 "Restoring model from %s did not make model ready for local init:" 368 " %s", restoring_file, msg) 369 return sess, False 370 371 is_ready, msg = self._model_ready(sess) 372 if not is_ready: 373 logging.info("Restoring model from %s did not make model ready: %s", 374 restoring_file, msg) 375 return sess, False 376 377 logging.info("Restored model from %s", restoring_file) 378 return sess, is_loaded_from_checkpoint 379 380 def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")): 381 """Creates a new `Session` and waits for model to be ready. 382 383 Creates a new `Session` on 'master'. Waits for the model to be 384 initialized or recovered from a checkpoint. It's expected that 385 another thread or process will make the model ready, and that this 386 is intended to be used by threads/processes that participate in a 387 distributed training configuration where a different thread/process 388 is responsible for initializing or recovering the model being trained. 389 390 NB: The amount of time this method waits for the session is bounded 391 by max_wait_secs. By default, this function will wait indefinitely. 392 393 Args: 394 master: `String` representation of the TensorFlow master to use. 395 config: Optional ConfigProto proto used to configure the session. 396 max_wait_secs: Maximum time to wait for the session to become available. 397 398 Returns: 399 A `Session`. May be None if the operation exceeds the timeout 400 specified by config.operation_timeout_in_ms. 401 402 Raises: 403 tf.DeadlineExceededError: if the session is not available after 404 max_wait_secs. 405 """ 406 self._target = master 407 408 if max_wait_secs is None: 409 max_wait_secs = float("Inf") 410 timer = _CountDownTimer(max_wait_secs) 411 412 while True: 413 sess = session.Session(self._target, graph=self._graph, config=config) 414 not_ready_msg = None 415 not_ready_local_msg = None 416 local_init_success, not_ready_local_msg = self._try_run_local_init_op( 417 sess) 418 if local_init_success: 419 # Successful if local_init_op is None, or ready_for_local_init_op passes 420 is_ready, not_ready_msg = self._model_ready(sess) 421 if is_ready: 422 return sess 423 424 self._safe_close(sess) 425 426 # Do we have enough time left to try again? 427 remaining_ms_after_wait = ( 428 timer.secs_remaining() - self._recovery_wait_secs) 429 if remaining_ms_after_wait < 0: 430 raise errors.DeadlineExceededError( 431 None, None, 432 "Session was not ready after waiting %d secs." % (max_wait_secs,)) 433 434 logging.info("Waiting for model to be ready. " 435 "Ready_for_local_init_op: %s, ready: %s", 436 not_ready_local_msg, not_ready_msg) 437 time.sleep(self._recovery_wait_secs) 438 439 def _safe_close(self, sess): 440 """Closes a session without raising an exception. 441 442 Just like sess.close() but ignores exceptions. 443 444 Args: 445 sess: A `Session`. 446 """ 447 # pylint: disable=broad-except 448 try: 449 sess.close() 450 except Exception: 451 # Intentionally not logging to avoid user complaints that 452 # they get cryptic errors. We really do not care that Close 453 # fails. 454 pass 455 # pylint: enable=broad-except 456 457 def _model_ready(self, sess): 458 """Checks if the model is ready or not. 459 460 Args: 461 sess: A `Session`. 462 463 Returns: 464 A tuple (is_ready, msg), where is_ready is True if ready and False 465 otherwise, and msg is `None` if the model is ready, a `String` with the 466 reason why it is not ready otherwise. 467 """ 468 return _ready(self._ready_op, sess, "Model not ready") 469 470 def _model_ready_for_local_init(self, sess): 471 """Checks if the model is ready to run local_init_op. 472 473 Args: 474 sess: A `Session`. 475 476 Returns: 477 A tuple (is_ready, msg), where is_ready is True if ready to run 478 local_init_op and False otherwise, and msg is `None` if the model is 479 ready to run local_init_op, a `String` with the reason why it is not ready 480 otherwise. 481 """ 482 return _ready(self._ready_for_local_init_op, sess, 483 "Model not ready for local init") 484 485 def _try_run_local_init_op(self, sess): 486 """Tries to run _local_init_op, if not None, and is ready for local init. 487 488 Args: 489 sess: A `Session`. 490 491 Returns: 492 A tuple (is_successful, msg), where is_successful is True if 493 _local_init_op is None, or we ran _local_init_op, and False otherwise; 494 and msg is a `String` with the reason why the model was not ready to run 495 local init. 496 """ 497 if self._local_init_op is not None: 498 is_ready_for_local_init, msg = self._model_ready_for_local_init(sess) 499 if is_ready_for_local_init: 500 logging.info("Running local_init_op.") 501 sess.run(self._local_init_op, options=self._local_init_run_options) 502 logging.info("Done running local_init_op.") 503 return True, None 504 else: 505 return False, msg 506 return True, None 507 508 509def _ready(op, sess, msg): 510 """Checks if the model is ready or not, as determined by op. 511 512 Args: 513 op: An op, either _ready_op or _ready_for_local_init_op, which defines the 514 readiness of the model. 515 sess: A `Session`. 516 msg: A message to log to warning if not ready 517 518 Returns: 519 A tuple (is_ready, msg), where is_ready is True if ready and False 520 otherwise, and msg is `None` if the model is ready, a `String` with the 521 reason why it is not ready otherwise. 522 """ 523 if op is None: 524 return True, None 525 else: 526 try: 527 ready_value = sess.run(op) 528 # The model is considered ready if ready_op returns an empty 1-D tensor. 529 # Also compare to `None` and dtype being int32 for backward 530 # compatibility. 531 if (ready_value is None or ready_value.dtype == np.int32 or 532 ready_value.size == 0): 533 return True, None 534 else: 535 # TODO(sherrym): If a custom ready_op returns other types of tensor, 536 # or strings other than variable names, this message could be 537 # confusing. 538 non_initialized_varnames = ", ".join( 539 [i.decode("utf-8") for i in ready_value]) 540 return False, "Variables not initialized: " + non_initialized_varnames 541 except errors.FailedPreconditionError as e: 542 if "uninitialized" not in str(e): 543 logging.warning("%s : error [%s]", msg, str(e)) 544 raise e 545 return False, str(e) 546 547 548class _CountDownTimer(object): 549 550 def __init__(self, duration_secs): 551 self._start_time_secs = time.time() 552 self._duration_secs = duration_secs 553 554 def secs_remaining(self): 555 diff = self._duration_secs - (time.time() - self._start_time_secs) 556 return max(0, diff) 557