# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Hook for asynchronous checkpointing.
16
17This hook dispatches checkpoint writing operations in a separate thread to
18allow execution to continue on the main thread.
19"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import threading
import time

from tensorflow.core.util import event_pb2
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.training.summary_io import SummaryWriterCache


class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
  """Saves checkpoints every N steps or seconds."""

  def __init__(self,
               checkpoint_dir,
               save_secs=None,
               save_steps=None,
               saver=None,
               checkpoint_basename="model.ckpt",
               scaffold=None,
               listeners=None):
    """Initializes an `AsyncCheckpointSaverHook`.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N secs.
      save_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, used to get the saver object.
      listeners: List of `CheckpointSaverListener` subclass instances. Used for
        callbacks that run immediately before or after this hook saves the
        checkpoint.

    Raises:
      ValueError: Exactly one of `save_steps` or `save_secs` should be set.
      ValueError: At most one of `saver` or `scaffold` should be set.
    """
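    # Listener sketch (hypothetical `LoggingListener`; `before_save` and
    # `after_save` are the standard `tf.train.CheckpointSaverListener`
    # callbacks that this hook invokes around every save):
    #
    #   class LoggingListener(tf.train.CheckpointSaverListener):
    #     def after_save(self, session, global_step_value):
    #       logging.info("Checkpoint written at step %d", global_step_value)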
    save_path = os.path.join(checkpoint_dir, checkpoint_basename)
    logging.info("Create AsyncCheckpointSaverHook saving to path\n%s",
                 save_path)
    if listeners:
      logging.info(" with %d listener(s).", len(listeners))
    if saver is not None and scaffold is not None:
      raise ValueError("You cannot provide both saver and scaffold.")
    self._saver = saver
    self._save_thread = None
    self._write_graph_thread = None
    self._checkpoint_dir = checkpoint_dir
    self._save_path = save_path
    self._scaffold = scaffold
    self._timer = basic_session_run_hooks.SecondOrStepTimer(
        every_secs=save_secs, every_steps=save_steps)
    self._listeners = listeners or []
    self._steps_per_run = 1
    self._summary_writer = None
    self._global_step_tensor = None

    self._last_checkpoint_step = None

  def _set_steps_per_run(self, steps_per_run):
    self._steps_per_run = steps_per_run

  def begin(self):
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    for l in self._listeners:
      l.begin()

  def after_create_session(self, session, coord):
    global_step = session.run(self._global_step_tensor)

    # We write the graph and saver_def as soon as the session is created. We
    # cannot do this in begin(), because other hooks may still change the
    # graph and add variables there; the graph is finalized only after all
    # begin() calls.
    def _write_graph_fn(self):
      training_util.write_graph(
          ops.get_default_graph().as_graph_def(add_shapes=True),
          self._checkpoint_dir, "graph.pbtxt")
    self._write_graph_thread = threading.Thread(target=_write_graph_fn,
                                                args=[self])
    self._write_graph_thread.start()

    saver_def = self._get_saver().saver_def if self._get_saver() else None
    graph = ops.get_default_graph()
    meta_graph_def = meta_graph.create_meta_graph_def(
        graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
    self._summary_writer.add_graph(graph)
    self._summary_writer.add_meta_graph(meta_graph_def)
    # The checkpoint saved here is the state at step "global_step".
    self._save(session, global_step)
    self._timer.update_last_triggered_step(global_step)

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
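    # Re-read the global step from the session so the value reflects any
    # increment applied during this run, rather than the possibly stale
    # fetch requested in before_run.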
    global_step = run_context.session.run(self._global_step_tensor)
    if self._timer.should_trigger_for_step(global_step):
      self._timer.update_last_triggered_step(global_step)
      logging.info("Triggering checkpoint for step %d.", global_step)
      if self._save(run_context.session, global_step):
        run_context.request_stop()

  def end(self, session):
    if self._save_thread:
      logging.info("Waiting for any pending checkpoints to finish.")
      self._save_thread.join()
    if self._write_graph_thread:
      logging.info("Waiting for any pending write_graph to finish.")
      self._write_graph_thread.join()

    last_step = session.run(self._global_step_tensor)
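
    # If the timer did not fire exactly on the final step, write one last
    # synchronous checkpoint so the terminal training state is not lost.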
    if self._last_checkpoint_step != last_step:
      self._save(session, last_step, asynchronous=False)

    for l in self._listeners:
      l.end(session, last_step)

  def _save(self, session, step, asynchronous=True):
    """Saves the latest checkpoint; returns should_stop."""

    def _save_fn():
      """Run the saver process."""
      logging.info("Saving checkpoints for %d into %s.", step, self._save_path)

      start_time = time.time()
      for l in self._listeners:
        l.before_save(session, step)

      self._get_saver().save(session, self._save_path, global_step=step)
      self._summary_writer.add_session_log(
          event_pb2.SessionLog(
              status=event_pb2.SessionLog.CHECKPOINT,
              checkpoint_path=self._save_path), step)

      for l in self._listeners:
        l.after_save(session, step)

      end_time = time.time()
      logging.info("Checkpoint actual writing time: (%.3f sec)",
                   end_time - start_time)
      logging.info("Checkpoint finished for %d into %s.", step, self._save_path)

    if not asynchronous:
      self._last_checkpoint_step = step
      _save_fn()
      return
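
    # Throttle asynchronous saves: give any in-flight save thread a short
    # grace period to finish; if it is still running, skip this checkpoint
    # rather than stacking up concurrent writer threads.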
    if self._save_thread is not None:
      self._save_thread.join(timeout=0.1)
      if self._save_thread.is_alive():
        logging.info("Saver thread still in progress, skipping checkpoint.")
        return

    self._last_checkpoint_step = step
    self._save_thread = threading.Thread(target=_save_fn)
    self._save_thread.start()

  def _get_saver(self):
    if self._saver is not None:
      return self._saver
    elif self._scaffold is not None:
      return self._scaffold.saver

    # Get saver from the SAVERS collection if present.
    collection_key = ops.GraphKeys.SAVERS
    savers = ops.get_collection(collection_key)
    if not savers:
      raise RuntimeError(
          "No items in collection {}. Please add a saver to the collection "
          "or provide a saver or scaffold.".format(collection_key))
    elif len(savers) > 1:
      raise RuntimeError(
          "More than one item in collection {}. "
          "Please indicate which one to use by passing it to the constructor."
          .format(collection_key))

    self._saver = savers[0]
    return savers[0]