1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the 'License'); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an 'AS IS' BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ====================================== 15"""Hook for asynchronous checkpointing. 16 17This hook dispatches checkpoint writing operations in a separate thread to 18allow execution to continue on the main thread. 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25import os 26import threading 27import time 28 29from tensorflow.core.util import event_pb2 30from tensorflow.python.framework import meta_graph 31from tensorflow.python.framework import ops 32from tensorflow.python.platform import tf_logging as logging 33from tensorflow.python.training import basic_session_run_hooks 34from tensorflow.python.training import training_util 35from tensorflow.python.training.session_run_hook import SessionRunArgs 36from tensorflow.python.training.summary_io import SummaryWriterCache 37 38 39class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook): 40 """Saves checkpoints every N steps or seconds.""" 41 42 def __init__(self, 43 checkpoint_dir, 44 save_secs=None, 45 save_steps=None, 46 saver=None, 47 checkpoint_basename="model.ckpt", 48 scaffold=None, 49 listeners=None): 50 """Initializes a `CheckpointSaverHook`. 51 52 Args: 53 checkpoint_dir: `str`, base directory for the checkpoint files. 54 save_secs: `int`, save every N secs. 55 save_steps: `int`, save every N steps. 56 saver: `Saver` object, used for saving. 57 checkpoint_basename: `str`, base name for the checkpoint files. 58 scaffold: `Scaffold`, use to get saver object. 59 listeners: List of `CheckpointSaverListener` subclass instances. Used for 60 callbacks that run immediately before or after this hook saves the 61 checkpoint. 62 63 Raises: 64 ValueError: One of `save_steps` or `save_secs` should be set. 65 ValueError: At most one of `saver` or `scaffold` should be set. 66 """ 67 save_path = os.path.join(checkpoint_dir, checkpoint_basename) 68 logging.info("Create AsyncCheckpointSaverHook saving to path\n%s", 69 save_path) 70 if listeners: 71 logging.info(" with %d listener(s).", len(listeners)) 72 if saver is not None and scaffold is not None: 73 raise ValueError("You cannot provide both saver and scaffold.") 74 self._saver = saver 75 self._save_thread = None 76 self._write_graph_thread = None 77 self._checkpoint_dir = checkpoint_dir 78 self._save_path = save_path 79 self._scaffold = scaffold 80 self._timer = basic_session_run_hooks.SecondOrStepTimer( 81 every_secs=save_secs, every_steps=save_steps) 82 self._listeners = listeners or [] 83 self._steps_per_run = 1 84 self._summary_writer = None 85 self._global_step_tensor = None 86 87 self._last_checkpoint_step = None 88 89 def _set_steps_per_run(self, steps_per_run): 90 self._steps_per_run = steps_per_run 91 92 def begin(self): 93 self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir) 94 self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access 95 if self._global_step_tensor is None: 96 raise RuntimeError( 97 "Global step should be created to use CheckpointSaverHook.") 98 for l in self._listeners: 99 l.begin() 100 101 def after_create_session(self, session, coord): 102 global_step = session.run(self._global_step_tensor) 103 104 # We do write graph and saver_def at the first call of before_run. 105 # We cannot do this in begin, since we let other hooks to change graph and 106 # add variables in begin. Graph is finalized after all begin calls. 107 def _write_graph_fn(self): 108 training_util.write_graph( 109 ops.get_default_graph().as_graph_def(add_shapes=True), 110 self._checkpoint_dir, "graph.pbtxt") 111 self._write_graph_thread = threading.Thread(target=_write_graph_fn, 112 args=[self]) 113 self._write_graph_thread.start() 114 115 saver_def = self._get_saver().saver_def if self._get_saver() else None 116 graph = ops.get_default_graph() 117 meta_graph_def = meta_graph.create_meta_graph_def( 118 graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def) 119 self._summary_writer.add_graph(graph) 120 self._summary_writer.add_meta_graph(meta_graph_def) 121 # The checkpoint saved here is the state at step "global_step". 122 self._save(session, global_step) 123 self._timer.update_last_triggered_step(global_step) 124 125 def before_run(self, run_context): # pylint: disable=unused-argument 126 return SessionRunArgs(self._global_step_tensor) 127 128 def after_run(self, run_context, run_values): 129 global_step = run_context.session.run(self._global_step_tensor) 130 if self._timer.should_trigger_for_step(global_step): 131 self._timer.update_last_triggered_step(global_step) 132 logging.info("Triggering checkpoint. %s", global_step) 133 if self._save(run_context.session, global_step): 134 run_context.request_stop() 135 136 def end(self, session): 137 if self._save_thread: 138 logging.info("Waiting for any pending checkpoints to finish.") 139 self._save_thread.join() 140 if self._write_graph_thread: 141 logging.info("Waiting for any pending write_graph to finish.") 142 self._write_graph_thread.join() 143 144 last_step = session.run(self._global_step_tensor) 145 146 if self._last_checkpoint_step != last_step: 147 self._save(session, last_step, asynchronous=False) 148 149 for l in self._listeners: 150 l.end(session, last_step) 151 152 def _save(self, session, step, asynchronous=True): 153 """Saves the latest checkpoint, returns should_stop.""" 154 155 def _save_fn(): 156 """Run the saver process.""" 157 logging.info("Saving checkpoints for %d into %s.", step, self._save_path) 158 159 start_time = time.time() 160 for l in self._listeners: 161 l.before_save(session, step) 162 163 self._get_saver().save(session, self._save_path, global_step=step) 164 self._summary_writer.add_session_log( 165 event_pb2.SessionLog( 166 status=event_pb2.SessionLog.CHECKPOINT, 167 checkpoint_path=self._save_path), step) 168 169 for l in self._listeners: 170 l.after_save(session, step) 171 172 end_time = time.time() 173 logging.info("Checkpoint actual writing time: (%.3f sec)", 174 end_time - start_time) 175 logging.info("Checkpoint finished for %d into %s.", step, self._save_path) 176 177 if not asynchronous: 178 self._last_checkpoint_step = step 179 _save_fn() 180 return 181 182 if self._save_thread is not None: 183 self._save_thread.join(timeout=0.1) 184 if self._save_thread.is_alive(): 185 logging.info("Saver thread still in progress, skipping checkpoint.") 186 return 187 188 self._last_checkpoint_step = step 189 self._save_thread = threading.Thread(target=_save_fn) 190 self._save_thread.start() 191 192 def _get_saver(self): 193 if self._saver is not None: 194 return self._saver 195 elif self._scaffold is not None: 196 return self._scaffold.saver 197 198 # Get saver from the SAVERS collection if present. 199 collection_key = ops.GraphKeys.SAVERS 200 savers = ops.get_collection(collection_key) 201 if not savers: 202 raise RuntimeError( 203 "No items in collection {}. Please add a saver to the collection " 204 "or provide a saver or scaffold.".format(collection_key)) 205 elif len(savers) > 1: 206 raise RuntimeError( 207 "More than one item in collection {}. " 208 "Please indicate which one to use by passing it to the constructor." 209 .format(collection_key)) 210 211 self._saver = savers[0] 212 return savers[0] 213