1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Training helper that checkpoints models and creates session.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import time 21import numpy as np 22 23from tensorflow.python.client import session 24from tensorflow.python.framework import errors 25from tensorflow.python.framework import ops 26from tensorflow.python.platform import tf_logging as logging 27from tensorflow.python.training import saver as saver_mod 28from tensorflow.python.util.tf_export import tf_export 29 30 31def _maybe_name(obj): 32 """Returns object name if it has one, or a message otherwise. 33 34 This is useful for names that apper in error messages. 35 Args: 36 obj: Object to get the name of. 37 Returns: 38 name, "None", or a "no name" message. 39 """ 40 if obj is None: 41 return "None" 42 elif hasattr(obj, "name"): 43 return obj.name 44 else: 45 return "<no name for %s>" % type(obj) 46 47 48@tf_export("train.SessionManager") 49class SessionManager(object): 50 """Training helper that restores from checkpoint and creates session. 51 52 This class is a small wrapper that takes care of session creation and 53 checkpoint recovery. It also provides functions that to facilitate 54 coordination among multiple training threads or processes. 55 56 * Checkpointing trained variables as the training progresses. 57 * Initializing variables on startup, restoring them from the most recent 58 checkpoint after a crash, or wait for checkpoints to become available. 59 60 ### Usage: 61 62 ```python 63 with tf.Graph().as_default(): 64 ...add operations to the graph... 65 # Create a SessionManager that will checkpoint the model in '/tmp/mydir'. 66 sm = SessionManager() 67 sess = sm.prepare_session(master, init_op, saver, checkpoint_dir) 68 # Use the session to train the graph. 69 while True: 70 sess.run(<my_train_op>) 71 ``` 72 73 `prepare_session()` initializes or restores a model. It requires `init_op` 74 and `saver` as an argument. 75 76 A second process could wait for the model to be ready by doing the following: 77 78 ```python 79 with tf.Graph().as_default(): 80 ...add operations to the graph... 81 # Create a SessionManager that will wait for the model to become ready. 82 sm = SessionManager() 83 sess = sm.wait_for_session(master) 84 # Use the session to train the graph. 85 while True: 86 sess.run(<my_train_op>) 87 ``` 88 89 `wait_for_session()` waits for a model to be initialized by other processes. 90 91 """ 92 93 def __init__(self, 94 local_init_op=None, 95 ready_op=None, 96 ready_for_local_init_op=None, 97 graph=None, 98 recovery_wait_secs=30): 99 """Creates a SessionManager. 100 101 The `local_init_op` is an `Operation` that is run always after a new session 102 was created. If `None`, this step is skipped. 103 104 The `ready_op` is an `Operation` used to check if the model is ready. The 105 model is considered ready if that operation returns an empty 1D string 106 tensor. If the operation returns a non empty 1D string tensor, the elements 107 are concatenated and used to indicate to the user why the model is not 108 ready. 109 110 The `ready_for_local_init_op` is an `Operation` used to check if the model 111 is ready to run local_init_op. The model is considered ready if that 112 operation returns an empty 1D string tensor. If the operation returns a non 113 empty 1D string tensor, the elements are concatenated and used to indicate 114 to the user why the model is not ready. 115 116 If `ready_op` is `None`, the model is not checked for readiness. 117 118 `recovery_wait_secs` is the number of seconds between checks that 119 the model is ready. It is used by processes to wait for a model to 120 be initialized or restored. Defaults to 30 seconds. 121 122 Args: 123 local_init_op: An `Operation` run immediately after session creation. 124 Usually used to initialize tables and local variables. 125 ready_op: An `Operation` to check if the model is initialized. 126 ready_for_local_init_op: An `Operation` to check if the model is ready 127 to run local_init_op. 128 graph: The `Graph` that the model will use. 129 recovery_wait_secs: Seconds between checks for the model to be ready. 130 131 Raises: 132 ValueError: If ready_for_local_init_op is not None but local_init_op is 133 None 134 """ 135 # Sets default values of arguments. 136 if graph is None: 137 graph = ops.get_default_graph() 138 self._local_init_op = local_init_op 139 self._ready_op = ready_op 140 self._ready_for_local_init_op = ready_for_local_init_op 141 self._graph = graph 142 self._recovery_wait_secs = recovery_wait_secs 143 self._target = None 144 if ready_for_local_init_op is not None and local_init_op is None: 145 raise ValueError("If you pass a ready_for_local_init_op " 146 "you must also pass a local_init_op " 147 ", ready_for_local_init_op [%s]" % 148 ready_for_local_init_op) 149 150 def _restore_checkpoint(self, 151 master, 152 saver=None, 153 checkpoint_dir=None, 154 checkpoint_filename_with_path=None, 155 wait_for_checkpoint=False, 156 max_wait_secs=7200, 157 config=None): 158 """Creates a `Session`, and tries to restore a checkpoint. 159 160 161 Args: 162 master: `String` representation of the TensorFlow master to use. 163 saver: A `Saver` object used to restore a model. 164 checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the 165 dir will be used to restore. 166 checkpoint_filename_with_path: Full file name path to the checkpoint file. 167 wait_for_checkpoint: Whether to wait for checkpoint to become available. 168 max_wait_secs: Maximum time to wait for checkpoints to become available. 169 config: Optional `ConfigProto` proto used to configure the session. 170 171 Returns: 172 A pair (sess, is_restored) where 'is_restored' is `True` if 173 the session could be restored, `False` otherwise. 174 175 Raises: 176 ValueError: If both checkpoint_dir and checkpoint_filename_with_path are 177 set. 178 """ 179 self._target = master 180 sess = session.Session(self._target, graph=self._graph, config=config) 181 182 if checkpoint_dir and checkpoint_filename_with_path: 183 raise ValueError("Can not provide both checkpoint_dir and " 184 "checkpoint_filename_with_path.") 185 # If either saver or checkpoint_* is not specified, cannot restore. Just 186 # return. 187 if not saver or not (checkpoint_dir or checkpoint_filename_with_path): 188 return sess, False 189 190 if checkpoint_filename_with_path: 191 saver.restore(sess, checkpoint_filename_with_path) 192 return sess, True 193 194 # Waits up until max_wait_secs for checkpoint to become available. 195 wait_time = 0 196 ckpt = saver_mod.get_checkpoint_state(checkpoint_dir) 197 while not ckpt or not ckpt.model_checkpoint_path: 198 if wait_for_checkpoint and wait_time < max_wait_secs: 199 logging.info("Waiting for checkpoint to be available.") 200 time.sleep(self._recovery_wait_secs) 201 wait_time += self._recovery_wait_secs 202 ckpt = saver_mod.get_checkpoint_state(checkpoint_dir) 203 else: 204 return sess, False 205 206 # Loads the checkpoint. 207 saver.restore(sess, ckpt.model_checkpoint_path) 208 saver.recover_last_checkpoints(ckpt.all_model_checkpoint_paths) 209 return sess, True 210 211 def prepare_session(self, 212 master, 213 init_op=None, 214 saver=None, 215 checkpoint_dir=None, 216 checkpoint_filename_with_path=None, 217 wait_for_checkpoint=False, 218 max_wait_secs=7200, 219 config=None, 220 init_feed_dict=None, 221 init_fn=None): 222 """Creates a `Session`. Makes sure the model is ready to be used. 223 224 Creates a `Session` on 'master'. If a `saver` object is passed in, and 225 `checkpoint_dir` points to a directory containing valid checkpoint 226 files, then it will try to recover the model from checkpoint. If 227 no checkpoint files are available, and `wait_for_checkpoint` is 228 `True`, then the process would check every `recovery_wait_secs`, 229 up to `max_wait_secs`, for recovery to succeed. 230 231 If the model cannot be recovered successfully then it is initialized by 232 either running the provided `init_op`, or calling the provided `init_fn`. 233 The local_init_op is also run after init_op and init_fn, regardless of 234 whether the model was recovered successfully, but only if 235 ready_for_local_init_op passes. 236 237 It is an error if the model cannot be recovered and no `init_op` 238 or `init_fn` or `local_init_op` are passed. 239 240 Args: 241 master: `String` representation of the TensorFlow master to use. 242 init_op: Optional `Operation` used to initialize the model. 243 saver: A `Saver` object used to restore a model. 244 checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the 245 dir will be used to restore. 246 checkpoint_filename_with_path: Full file name path to the checkpoint file. 247 wait_for_checkpoint: Whether to wait for checkpoint to become available. 248 max_wait_secs: Maximum time to wait for checkpoints to become available. 249 config: Optional `ConfigProto` proto used to configure the session. 250 init_feed_dict: Optional dictionary that maps `Tensor` objects to feed 251 values. This feed dictionary is passed to the session `run()` call when 252 running the init op. 253 init_fn: Optional callable used to initialize the model. Called after the 254 optional `init_op` is called. The callable must accept one argument, 255 the session being initialized. 256 257 Returns: 258 A `Session` object that can be used to drive the model. 259 260 Raises: 261 RuntimeError: If the model cannot be initialized or recovered. 262 263 Raises: 264 ValueError: If both checkpoint_dir and checkpoint_filename_with_path are 265 set. 266 """ 267 268 sess, is_loaded_from_checkpoint = self._restore_checkpoint( 269 master, 270 saver, 271 checkpoint_dir=checkpoint_dir, 272 checkpoint_filename_with_path=checkpoint_filename_with_path, 273 wait_for_checkpoint=wait_for_checkpoint, 274 max_wait_secs=max_wait_secs, 275 config=config) 276 if not is_loaded_from_checkpoint: 277 if init_op is None and not init_fn and self._local_init_op is None: 278 raise RuntimeError("Model is not initialized and no init_op or " 279 "init_fn or local_init_op was given") 280 if init_op is not None: 281 sess.run(init_op, feed_dict=init_feed_dict) 282 if init_fn: 283 init_fn(sess) 284 285 local_init_success, msg = self._try_run_local_init_op(sess) 286 if not local_init_success: 287 raise RuntimeError( 288 "Init operations did not make model ready for local_init. " 289 "Init op: %s, init fn: %s, error: %s" % (_maybe_name(init_op), 290 init_fn, 291 msg)) 292 293 is_ready, msg = self._model_ready(sess) 294 if not is_ready: 295 raise RuntimeError( 296 "Init operations did not make model ready. " 297 "Init op: %s, init fn: %s, local_init_op: %s, error: %s" % 298 (_maybe_name(init_op), init_fn, self._local_init_op, msg)) 299 return sess 300 301 def recover_session(self, 302 master, 303 saver=None, 304 checkpoint_dir=None, 305 checkpoint_filename_with_path=None, 306 wait_for_checkpoint=False, 307 max_wait_secs=7200, 308 config=None): 309 """Creates a `Session`, recovering if possible. 310 311 Creates a new session on 'master'. If the session is not initialized 312 and can be recovered from a checkpoint, recover it. 313 314 Args: 315 master: `String` representation of the TensorFlow master to use. 316 saver: A `Saver` object used to restore a model. 317 checkpoint_dir: Path to the checkpoint files. The latest checkpoint in the 318 dir will be used to restore. 319 checkpoint_filename_with_path: Full file name path to the checkpoint file. 320 wait_for_checkpoint: Whether to wait for checkpoint to become available. 321 max_wait_secs: Maximum time to wait for checkpoints to become available. 322 config: Optional `ConfigProto` proto used to configure the session. 323 324 Returns: 325 A pair (sess, initialized) where 'initialized' is `True` if 326 the session could be recovered and initialized, `False` otherwise. 327 328 Raises: 329 ValueError: If both checkpoint_dir and checkpoint_filename_with_path are 330 set. 331 """ 332 333 sess, is_loaded_from_checkpoint = self._restore_checkpoint( 334 master, 335 saver, 336 checkpoint_dir=checkpoint_dir, 337 checkpoint_filename_with_path=checkpoint_filename_with_path, 338 wait_for_checkpoint=wait_for_checkpoint, 339 max_wait_secs=max_wait_secs, 340 config=config) 341 342 # Always try to run local_init_op 343 local_init_success, msg = self._try_run_local_init_op(sess) 344 345 if not is_loaded_from_checkpoint: 346 # Do not need to run checks for readiness 347 return sess, False 348 349 restoring_file = checkpoint_dir or checkpoint_filename_with_path 350 if not local_init_success: 351 logging.info( 352 "Restoring model from %s did not make model ready for local init:" 353 " %s", restoring_file, msg) 354 return sess, False 355 356 is_ready, msg = self._model_ready(sess) 357 if not is_ready: 358 logging.info("Restoring model from %s did not make model ready: %s", 359 restoring_file, msg) 360 return sess, False 361 362 logging.info("Restored model from %s", restoring_file) 363 return sess, is_loaded_from_checkpoint 364 365 def wait_for_session(self, master, config=None, max_wait_secs=float("Inf")): 366 """Creates a new `Session` and waits for model to be ready. 367 368 Creates a new `Session` on 'master'. Waits for the model to be 369 initialized or recovered from a checkpoint. It's expected that 370 another thread or process will make the model ready, and that this 371 is intended to be used by threads/processes that participate in a 372 distributed training configuration where a different thread/process 373 is responsible for initializing or recovering the model being trained. 374 375 NB: The amount of time this method waits for the session is bounded 376 by max_wait_secs. By default, this function will wait indefinitely. 377 378 Args: 379 master: `String` representation of the TensorFlow master to use. 380 config: Optional ConfigProto proto used to configure the session. 381 max_wait_secs: Maximum time to wait for the session to become available. 382 383 Returns: 384 A `Session`. May be None if the operation exceeds the timeout 385 specified by config.operation_timeout_in_ms. 386 387 Raises: 388 tf.DeadlineExceededError: if the session is not available after 389 max_wait_secs. 390 """ 391 self._target = master 392 393 if max_wait_secs is None: 394 max_wait_secs = float("Inf") 395 timer = _CountDownTimer(max_wait_secs) 396 397 while True: 398 sess = session.Session(self._target, graph=self._graph, config=config) 399 not_ready_msg = None 400 not_ready_local_msg = None 401 local_init_success, not_ready_local_msg = self._try_run_local_init_op( 402 sess) 403 if local_init_success: 404 # Successful if local_init_op is None, or ready_for_local_init_op passes 405 is_ready, not_ready_msg = self._model_ready(sess) 406 if is_ready: 407 return sess 408 409 self._safe_close(sess) 410 411 # Do we have enough time left to try again? 412 remaining_ms_after_wait = ( 413 timer.secs_remaining() - self._recovery_wait_secs) 414 if remaining_ms_after_wait < 0: 415 raise errors.DeadlineExceededError( 416 None, None, 417 "Session was not ready after waiting %d secs." % (max_wait_secs,)) 418 419 logging.info("Waiting for model to be ready. " 420 "Ready_for_local_init_op: %s, ready: %s", 421 not_ready_local_msg, not_ready_msg) 422 time.sleep(self._recovery_wait_secs) 423 424 def _safe_close(self, sess): 425 """Closes a session without raising an exception. 426 427 Just like sess.close() but ignores exceptions. 428 429 Args: 430 sess: A `Session`. 431 """ 432 # pylint: disable=broad-except 433 try: 434 sess.close() 435 except Exception: 436 # Intentionally not logging to avoid user complaints that 437 # they get cryptic errors. We really do not care that Close 438 # fails. 439 pass 440 # pylint: enable=broad-except 441 442 def _model_ready(self, sess): 443 """Checks if the model is ready or not. 444 445 Args: 446 sess: A `Session`. 447 448 Returns: 449 A tuple (is_ready, msg), where is_ready is True if ready and False 450 otherwise, and msg is `None` if the model is ready, a `String` with the 451 reason why it is not ready otherwise. 452 """ 453 return _ready(self._ready_op, sess, "Model not ready") 454 455 def _model_ready_for_local_init(self, sess): 456 """Checks if the model is ready to run local_init_op. 457 458 Args: 459 sess: A `Session`. 460 461 Returns: 462 A tuple (is_ready, msg), where is_ready is True if ready to run 463 local_init_op and False otherwise, and msg is `None` if the model is 464 ready to run local_init_op, a `String` with the reason why it is not ready 465 otherwise. 466 """ 467 return _ready(self._ready_for_local_init_op, sess, 468 "Model not ready for local init") 469 470 def _try_run_local_init_op(self, sess): 471 """Tries to run _local_init_op, if not None, and is ready for local init. 472 473 Args: 474 sess: A `Session`. 475 476 Returns: 477 A tuple (is_successful, msg), where is_successful is True if 478 _local_init_op is None, or we ran _local_init_op, and False otherwise; 479 and msg is a `String` with the reason why the model was not ready to run 480 local init. 481 """ 482 if self._local_init_op is not None: 483 is_ready_for_local_init, msg = self._model_ready_for_local_init(sess) 484 if is_ready_for_local_init: 485 logging.info("Running local_init_op.") 486 sess.run(self._local_init_op) 487 logging.info("Done running local_init_op.") 488 return True, None 489 else: 490 return False, msg 491 return True, None 492 493 494def _ready(op, sess, msg): 495 """Checks if the model is ready or not, as determined by op. 496 497 Args: 498 op: An op, either _ready_op or _ready_for_local_init_op, which defines the 499 readiness of the model. 500 sess: A `Session`. 501 msg: A message to log to warning if not ready 502 503 Returns: 504 A tuple (is_ready, msg), where is_ready is True if ready and False 505 otherwise, and msg is `None` if the model is ready, a `String` with the 506 reason why it is not ready otherwise. 507 """ 508 if op is None: 509 return True, None 510 else: 511 try: 512 ready_value = sess.run(op) 513 # The model is considered ready if ready_op returns an empty 1-D tensor. 514 # Also compare to `None` and dtype being int32 for backward 515 # compatibility. 516 if (ready_value is None or ready_value.dtype == np.int32 or 517 ready_value.size == 0): 518 return True, None 519 else: 520 # TODO(sherrym): If a custom ready_op returns other types of tensor, 521 # or strings other than variable names, this message could be 522 # confusing. 523 non_initialized_varnames = ", ".join( 524 [i.decode("utf-8") for i in ready_value]) 525 return False, "Variables not initialized: " + non_initialized_varnames 526 except errors.FailedPreconditionError as e: 527 if "uninitialized" not in str(e): 528 logging.warning("%s : error [%s]", msg, str(e)) 529 raise e 530 return False, str(e) 531 532 533class _CountDownTimer(object): 534 535 def __init__(self, duration_secs): 536 self._start_time_secs = time.time() 537 self._duration_secs = duration_secs 538 539 def secs_remaining(self): 540 diff = self._duration_secs - (time.time() - self._start_time_secs) 541 return max(0, diff) 542