# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Iterator ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.framework import ops
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


def _convert_external_state_policy_to_enum(external_state_policy):
  """Converts a string (or enum) policy to an `ExternalStatePolicy` enum."""
  if isinstance(external_state_policy, options_lib.ExternalStatePolicy):
    return external_state_policy
  if external_state_policy == "warn":
    return options_lib.ExternalStatePolicy.WARN
  if external_state_policy == "ignore":
    return options_lib.ExternalStatePolicy.IGNORE
  if external_state_policy == "fail":
    return options_lib.ExternalStatePolicy.FAIL
  raise ValueError(
      "Failed to convert {} to an instance of ExternalStatePolicy. "
      "Supported values include: 'warn', 'ignore' and 'fail'.".format(
          external_state_policy))


@tf_export("data.experimental.make_saveable_from_iterator")
@deprecation.deprecated(
    None, "`make_saveable_from_iterator` is intended for use in TF1 with "
    "`tf.compat.v1.Saver`. In TF2, use `tf.train.Checkpoint` instead.")
def make_saveable_from_iterator(iterator, external_state_policy=None):
52  """Returns a SaveableObject for saving/restoring iterator state using Saver.
53
54  Args:
55    iterator: Iterator.
56    external_state_policy: A string that identifies how to handle input
57      pipelines that depend on external state. Possible values are
58      'ignore': The external state is silently ignored.
59      'warn': The external state is ignored, logging a warning.
60      'fail': The operation fails upon encountering external state.
61      By default we set it to 'fail'.
62
63  Returns:
64    A SaveableObject for saving/restoring iterator state using Saver.
65
66  Raises:
67    ValueError: If iterator does not support checkpointing.
68    ValueError: If `external_state_policy` is not one of 'warn', 'ignore' or
69      'fail'.
70
71  For example:
72
73  ```python
74  with tf.Graph().as_default():
75    ds = tf.data.Dataset.range(10)
76    iterator = ds.make_initializable_iterator()
77    # Build the iterator SaveableObject.
78    saveable_obj = tf.data.experimental.make_saveable_from_iterator(iterator)
79    # Add the SaveableObject to the SAVEABLE_OBJECTS collection so
80    # it can be automatically saved using Saver.
81    tf.compat.v1.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
82    saver = tf.compat.v1.train.Saver()
83
84    while continue_training:
85      ... Perform training ...
86      if should_save_checkpoint:
87        saver.save()
88  ```
89
90  Note: When restoring the iterator, the existing iterator state is completely
91  discarded. This means that any changes you may have made to the Dataset
92  graph will be discarded as well! This includes the new Dataset graph
93  that you may have built during validation. So, while running validation,
94  make sure to run the initializer for the validation input pipeline after
95  restoring the checkpoint.
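
  A minimal sketch of that restore-then-initialize order (`sess`,
  `val_iterator` and `checkpoint_path` are hypothetical names):

  ```python
  saver.restore(sess, checkpoint_path)
  # Initialize the validation pipeline only *after* the restore; restoring
  # discards whatever iterator state the session held before.
  sess.run(val_iterator.initializer)
  ```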

  Note: Not all iterators support checkpointing yet. Attempting to save the
  state of an unsupported iterator will raise an error.
  """
  if external_state_policy is None:
    external_state_policy = "fail"
  policy_enum = _convert_external_state_policy_to_enum(external_state_policy)
  return iterator_ops._IteratorSaveable(  # pylint: disable=protected-access
      iterator._iterator_resource,  # pylint: disable=protected-access
      iterator._iterator_resource.name,  # pylint: disable=protected-access
      external_state_policy=policy_enum)


@tf_export("data.experimental.CheckpointInputPipelineHook")
class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
111  """Checkpoints input pipeline state every N steps or seconds.
112
113  This hook saves the state of the iterators in the `Graph` so that when
114  training is resumed the input pipeline continues from where it left off.
115  This could potentially avoid overfitting in certain pipelines where the
116  number of training steps per eval are small compared to the dataset
117  size or if the training pipeline is pre-empted.
118
119  Differences from `CheckpointSaverHook`:
120  1. Saves only the input pipelines in the "iterators" collection and not the
121     global variables or other saveable objects.
122  2. Does not write the `GraphDef` and `MetaGraphDef` to the summary.
123
124  Example of checkpointing the training pipeline:
125
126  ```python
127  est = tf.estimator.Estimator(model_fn)
128  while True:
129    est.train(
130        train_input_fn,
131        hooks=[tf.data.experimental.CheckpointInputPipelineHook(est)],
132        steps=train_steps_per_eval)
133    # Note: We do not pass the hook here.
134    metrics = est.evaluate(eval_input_fn)
135    if should_stop_the_training(metrics):
136      break
137  ```
138
139  This hook should be used if the input pipeline state needs to be saved
140  separate from the model checkpoint. Doing so may be useful for a few reasons:
141  1. The input pipeline checkpoint may be large, if there are large shuffle
142     or prefetch buffers for instance, and may bloat the checkpoint size.
143  2. If the input pipeline is shared between training and validation, restoring
144     the checkpoint during validation may override the validation input
145     pipeline.
146
147  For saving the input pipeline checkpoint alongside the model weights use
148  `tf.data.experimental.make_saveable_from_iterator` directly to create a
149  `SaveableObject` and add to the `SAVEABLE_OBJECTS` collection. Note, however,
150  that you will need to be careful not to restore the training iterator during
151  eval. You can do that by not adding the iterator to the SAVEABLE_OBJECTS
152  collector when building the eval graph.
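
  A minimal sketch of that pattern (`train_iterator` and `is_training` are
  hypothetical names standing in for however you distinguish the two graphs):

  ```python
  saveable = tf.data.experimental.make_saveable_from_iterator(train_iterator)
  if is_training:
    # Only the training graph registers the iterator for saving; the eval
    # graph leaves it out so a restore cannot clobber the eval pipeline.
    tf.compat.v1.add_to_collection(
        tf.compat.v1.GraphKeys.SAVEABLE_OBJECTS, saveable)
  ```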
153  """
154
155  def __init__(self, estimator, external_state_policy=None):
156    """Initializes a `CheckpointInputPipelineHook`.
157
158    If the input pipeline depends on external state (e.g. seeds for
159    RandomUniform) beyond the input pipeline, this hook would be unable to
160    serialize and deserialize that state. If its acceptable to ignore that state
161    change the external_state_policy argument to 'warn' or 'ignore'. For e.g.
162
163    ```python
164    est = tf.estimator.Estimator(model_fn)
165    while True:
166      est.train(
167          train_input_fn,
168          hooks=[tf.data.experimental.CheckpointInputPipelineHook(
169              est, external_state_policy='warn')],
170          steps=train_steps_per_eval)
171      # Note: We do not pass the hook here.
172      metrics = est.evaluate(eval_input_fn)
173      if should_stop_the_training(metrics):
174        break
175    ```
176
177    Args:
178      estimator: Estimator.
179      external_state_policy: A string that identifies how to handle input
180        pipelines that depend on external state. Possible values are
181        'ignore': The external state is silently ignored.
182        'warn': The external state is ignored, logging a warning.
183        'fail': The operation fails upon encountering external state.
184        By default we set it to 'fail'.
185
186    Raises:
187      ValueError: One of `save_steps` or `save_secs` should be set.
188      ValueError: At most one of saver or scaffold should be set.
189      ValueError: If `external_state_policy` is not one of 'warn', 'ignore' or
190        'fail'.
191    """
    if external_state_policy is None:
      external_state_policy = "fail"
    self._external_state_policy = _convert_external_state_policy_to_enum(
        external_state_policy)
    # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or
    # of the form "input_<task_type>_<task_id>.ckpt" for distributed pipelines.
    # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is
    # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix
    # to be different to avoid conflicts with the model checkpoint.
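    # For instance, worker 0 of a distributed job ends up writing checkpoint
    # files named "input_worker_0.ckpt-<global_step>".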

    # pylint: disable=protected-access
    checkpoint_prefix = "input"
    if estimator._config.num_worker_replicas > 1:
      # Distributed setting.
      suffix = "_{}_{}".format(estimator._config.task_type,
                               estimator._config.task_id)
      checkpoint_prefix += suffix
    # pylint: enable=protected-access

    # We use a composition paradigm instead of inheriting from
    # `CheckpointSaverHook` because `Estimator` does an `isinstance` check
    # to determine whether a `CheckpointSaverHook` is already present in the
    # list of hooks and, if not, adds one. Inheriting from
    # `CheckpointSaverHook` would thwart this behavior. This hook checkpoints
    # *only the iterators* and not the graph variables.
    self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook(
        estimator.model_dir,
        save_secs=estimator._config.save_checkpoints_secs,  # pylint: disable=protected-access
        save_steps=estimator._config.save_checkpoints_steps,  # pylint: disable=protected-access
        checkpoint_basename=checkpoint_prefix + ".ckpt")

    # Name for the protocol buffer file that will contain the list of most
    # recent checkpoints stored as a `CheckpointState` protocol buffer.
    # This file, kept in the same directory as the checkpoint files, is
    # automatically managed by the `Saver` to keep track of recent checkpoints.
    # The default name used by the `Saver` for this file is "checkpoint". Here
    # we use the name "checkpoint_<checkpoint_prefix>" so that in case the
    # `checkpoint_dir` is the same as the model checkpoint directory, there are
    # no conflicts during restore.
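    # For instance, the state file is "checkpoint_input" for a non-distributed
    # pipeline, or "checkpoint_input_worker_0" for worker 0 of a distributed
    # one.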
    self._latest_filename = "checkpoint_" + checkpoint_prefix

  def begin(self):
    # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS`
    # collection if no `Saver` or `Scaffold` is provided.
    # pylint: disable=protected-access
    if (self._checkpoint_saver_hook._saver is None and
        self._checkpoint_saver_hook._scaffold is None):
      iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
      saveables = [
          iterator_ops._IteratorSaveable(
              i, i.name, external_state_policy=self._external_state_policy)
          for i in iterators
      ]
      self._checkpoint_saver_hook._saver = _CustomSaver(
          saveables, self._latest_filename, sharded=True)
    # pylint: enable=protected-access
    self._checkpoint_saver_hook.begin()

  def after_create_session(self, session, coord):
    # If a new session was created, we set _first_run to True so that we can
    # restore if needed.
    self._first_run = True
  def _restore_or_save_initial_ckpt(self, session):
    # Ideally this should be run in after_create_session but is not for the
    # following reason:
    # Currently there is no way of enforcing an order of running the
    # `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook`
    # is run *after* this hook. That is troublesome because
    # 1. If a checkpoint exists and this hook restores it, the initializer hook
    #    will override it.
    # 2. If no checkpoint exists, this hook will try to save an uninitialized
    #    iterator, which will result in an exception.
    #
    # As a temporary fix we establish the following implicit contract between
    # this hook and the _DatasetInitializerHook.
    # 1. The _DatasetInitializerHook initializes the iterator in the call to
    #    after_create_session.
    # 2. This hook saves the iterator on the first call to `before_run()`,
    #    which is guaranteed to happen after the `after_create_session()` calls
    #    of all hooks have run.

    # Check if there is an existing checkpoint. If so, restore from it.
    # pylint: disable=protected-access
    latest_checkpoint_path = checkpoint_management.latest_checkpoint(
        self._checkpoint_saver_hook._checkpoint_dir,
        latest_filename=self._latest_filename)
    if latest_checkpoint_path:
      self._checkpoint_saver_hook._get_saver().restore(session,
                                                       latest_checkpoint_path)
    else:
      # The checkpoint saved here is the state at step "global_step".
      # Note: We do not save the GraphDef or MetaGraphDef here.
      global_step = session.run(self._checkpoint_saver_hook._global_step_tensor)
      self._checkpoint_saver_hook._save(session, global_step)
      self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step)
    # pylint: enable=protected-access

  def before_run(self, run_context):
    if self._first_run:
      self._restore_or_save_initial_ckpt(run_context.session)
      self._first_run = False
    return self._checkpoint_saver_hook.before_run(run_context)

  def after_run(self, run_context, run_values):
    self._checkpoint_saver_hook.after_run(run_context, run_values)

  def end(self, session):
    self._checkpoint_saver_hook.end(session)


class _CustomSaver(saver_lib.Saver):
  """`Saver` with a different default `latest_filename`.

  This is used in the `CheckpointInputPipelineHook` to avoid conflicts with
  the model checkpoint saved by the `CheckpointSaverHook`.
  """

  def __init__(self, var_list, latest_filename, sharded=False):
    super(_CustomSaver, self).__init__(var_list, sharded=sharded)
    self._latest_filename = latest_filename

  def save(self,
           sess,
           save_path,
           global_step=None,
           latest_filename=None,
           meta_graph_suffix="meta",
           write_meta_graph=True,
           write_state=True,
           strip_default_attrs=False):
    # Delegate to the base `Saver.save`, substituting our custom
    # `latest_filename` when the caller does not supply one.
    return super(_CustomSaver, self).save(
        sess, save_path, global_step, latest_filename or self._latest_filename,
        meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs)