# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Iterator ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.framework import ops
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


def _convert_external_state_policy_to_enum(external_state_policy):
  """Maps a policy string ('warn', 'ignore', 'fail') to the enum value."""
  if isinstance(external_state_policy, options_lib.ExternalStatePolicy):
    return external_state_policy
  if external_state_policy == "warn":
    return options_lib.ExternalStatePolicy.WARN
  if external_state_policy == "ignore":
    return options_lib.ExternalStatePolicy.IGNORE
  if external_state_policy == "fail":
    return options_lib.ExternalStatePolicy.FAIL
  raise ValueError(
      "Failed to convert {} to an instance of ExternalStatePolicy. "
      "Supported values include: 'warn', 'ignore' and 'fail'.".format(
          external_state_policy))


@tf_export("data.experimental.make_saveable_from_iterator")
@deprecation.deprecated(
    None, "`make_saveable_from_iterator` is intended for use in TF1 with "
    "`tf.compat.v1.Saver`. In TF2, use `tf.train.Checkpoint` instead.")
def make_saveable_from_iterator(iterator, external_state_policy=None):
  """Returns a SaveableObject for saving/restoring iterator state using Saver.

  Args:
    iterator: Iterator.
    external_state_policy: A string that identifies how to handle input
      pipelines that depend on external state. Possible values are
      'ignore': The external state is silently ignored.
      'warn': The external state is ignored, logging a warning.
      'fail': The operation fails upon encountering external state.
      Defaults to 'fail'.

  Returns:
    A SaveableObject for saving/restoring iterator state using Saver.

  Raises:
    ValueError: If iterator does not support checkpointing.
    ValueError: If `external_state_policy` is not one of 'warn', 'ignore' or
      'fail'.

  For example:

  ```python
  with tf.Graph().as_default():
    ds = tf.data.Dataset.range(10)
    iterator = tf.compat.v1.data.make_initializable_iterator(ds)
    # Build the iterator SaveableObject.
    saveable_obj = tf.data.experimental.make_saveable_from_iterator(iterator)
    # Add the SaveableObject to the SAVEABLE_OBJECTS collection so
    # it can be automatically saved using Saver.
    tf.compat.v1.add_to_collection(
        tf.compat.v1.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
    saver = tf.compat.v1.train.Saver()

    while continue_training:
      ... Perform training ...
      if should_save_checkpoint:
        saver.save()
  ```

  Note: When restoring the iterator, the existing iterator state is completely
  discarded. This means that any changes you may have made to the Dataset
  graph will be discarded as well! This includes the new Dataset graph
  that you may have built during validation. So, while running validation,
  make sure to run the initializer for the validation input pipeline after
  restoring the checkpoint.

  Note: Not all iterators support checkpointing yet. Attempting to save the
  state of an unsupported iterator will throw an error.
  """
  if external_state_policy is None:
    external_state_policy = "fail"
  policy_enum = _convert_external_state_policy_to_enum(external_state_policy)
  return iterator_ops._IteratorSaveable(  # pylint: disable=protected-access
      iterator._iterator_resource,  # pylint: disable=protected-access
      iterator._iterator_resource.name,  # pylint: disable=protected-access
      external_state_policy=policy_enum)
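

# A minimal TF2-style sketch of the replacement suggested by the deprecation
# notice above: `tf.train.Checkpoint` can track an iterator directly, so no
# `SaveableObject` is involved. `_tf2_checkpoint_example` is illustrative
# only and is not part of this module's API; the checkpoint prefix is an
# assumed placeholder.
def _tf2_checkpoint_example():
  import tensorflow as tf  # assumes TF2 eager execution

  dataset = tf.data.Dataset.range(10)
  iterator = iter(dataset)
  ckpt = tf.train.Checkpoint(iterator=iterator)
  save_path = ckpt.save("/tmp/iterator_ckpt/ckpt")  # hypothetical prefix
  next(iterator)  # advance past the saved position
  ckpt.restore(save_path)  # rewinds the iterator to the saved state
  return next(iterator)  # yields the same element consumed after the save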


@tf_export("data.experimental.CheckpointInputPipelineHook")
class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
  """Checkpoints input pipeline state every N steps or seconds.

  This hook saves the state of the iterators in the `Graph` so that when
  training is resumed the input pipeline continues from where it left off.
  This could potentially avoid overfitting in certain pipelines where the
  number of training steps per eval is small compared to the dataset
  size, or if the training pipeline is preempted.

  Differences from `CheckpointSaverHook`:
  1. Saves only the input pipelines in the "iterators" collection and not the
     global variables or other saveable objects.
  2. Does not write the `GraphDef` and `MetaGraphDef` to the summary.

  Example of checkpointing the training pipeline:

  ```python
  est = tf.estimator.Estimator(model_fn)
  while True:
    est.train(
        train_input_fn,
        hooks=[tf.data.experimental.CheckpointInputPipelineHook(est)],
        steps=train_steps_per_eval)
    # Note: We do not pass the hook here.
    metrics = est.evaluate(eval_input_fn)
    if should_stop_the_training(metrics):
      break
  ```

  This hook should be used if the input pipeline state needs to be saved
  separately from the model checkpoint. Doing so may be useful for a few
  reasons:
  1. The input pipeline checkpoint may be large (for instance, if there are
     large shuffle or prefetch buffers) and may bloat the checkpoint size.
  2. If the input pipeline is shared between training and validation,
     restoring the checkpoint during validation may overwrite the validation
     input pipeline.

  To save the input pipeline checkpoint alongside the model weights, use
  `tf.data.experimental.make_saveable_from_iterator` directly to create a
  `SaveableObject` and add it to the `SAVEABLE_OBJECTS` collection. Note,
  however, that you will need to be careful not to restore the training
  iterator during eval. You can do that by not adding the iterator to the
  SAVEABLE_OBJECTS collection when building the eval graph.
  """

  def __init__(self, estimator, external_state_policy=None):
    """Initializes a `CheckpointInputPipelineHook`.

    If the input pipeline depends on external state (e.g. seeds for
    RandomUniform) outside the pipeline itself, this hook would be unable to
    serialize and deserialize that state.
    If it is acceptable to ignore that
    state, set the `external_state_policy` argument to 'warn' or 'ignore'.
    For example:

    ```python
    est = tf.estimator.Estimator(model_fn)
    while True:
      est.train(
          train_input_fn,
          hooks=[tf.data.experimental.CheckpointInputPipelineHook(
              est, external_state_policy='warn')],
          steps=train_steps_per_eval)
      # Note: We do not pass the hook here.
      metrics = est.evaluate(eval_input_fn)
      if should_stop_the_training(metrics):
        break
    ```

    Args:
      estimator: Estimator.
      external_state_policy: A string that identifies how to handle input
        pipelines that depend on external state. Possible values are
        'ignore': The external state is silently ignored.
        'warn': The external state is ignored, logging a warning.
        'fail': The operation fails upon encountering external state.
        Defaults to 'fail'.

    Raises:
      ValueError: If neither `save_steps` nor `save_secs` is set.
      ValueError: If both a `saver` and a `scaffold` are set.
      ValueError: If `external_state_policy` is not one of 'warn', 'ignore' or
        'fail'.
    """
    if external_state_policy is None:
      external_state_policy = "fail"
    self._external_state_policy = _convert_external_state_policy_to_enum(
        external_state_policy)
    # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or
    # of the form "input_<task_type>_<task_id>.ckpt" for distributed
    # pipelines.
    # Note: The default `checkpoint_basename` used by `CheckpointSaverHook`
    # is "model.ckpt". We intentionally choose the input pipeline checkpoint
    # prefix to be different to avoid conflicts with the model checkpoint.

    # pylint: disable=protected-access
    checkpoint_prefix = "input"
    if estimator._config.num_worker_replicas > 1:
      # Distributed setting.
      suffix = "_{}_{}".format(estimator._config.task_type,
                               estimator._config.task_id)
      checkpoint_prefix += suffix
    # pylint: enable=protected-access

    # We use a composition paradigm instead of inheriting from
    # `CheckpointSaverHook` because `Estimator` does an `isinstance` check
    # to check whether a `CheckpointSaverHook` is already present in the list
    # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook`
    # would thwart this behavior. This hook checkpoints *only the iterators*
    # and not the graph variables.
    self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook(
        estimator.model_dir,
        save_secs=estimator._config.save_checkpoints_secs,  # pylint: disable=protected-access
        save_steps=estimator._config.save_checkpoints_steps,  # pylint: disable=protected-access
        checkpoint_basename=checkpoint_prefix + ".ckpt")

    # Name for the protocol buffer file that will contain the list of most
    # recent checkpoints stored as a `CheckpointState` protocol buffer.
    # This file, kept in the same directory as the checkpoint files, is
    # automatically managed by the `Saver` to keep track of recent
    # checkpoints. The default name used by the `Saver` for this file is
    # "checkpoint". Here we use the name "checkpoint_<checkpoint_prefix>" so
    # that in case the `checkpoint_dir` is the same as the model checkpoint
    # directory, there are no conflicts during restore.
    self._latest_filename = "checkpoint_" + checkpoint_prefix
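
  # Illustrative sketch of the resulting checkpoint directory layout when
  # this hook shares `model_dir` with the model's `CheckpointSaverHook`
  # (assumed step numbers and shard counts; not part of the API). The
  # distinct prefixes and `CheckpointState` filenames chosen above keep the
  # two checkpoint streams from clobbering each other:
  #
  #   model.ckpt-1000.index                # model variables ("model.ckpt")
  #   model.ckpt-1000.data-00000-of-00001
  #   checkpoint                           # CheckpointState, model ckpts
  #   input.ckpt-1000.index                # iterator state ("input")
  #   input.ckpt-1000.data-00000-of-00001
  #   checkpoint_input                     # CheckpointState, input ckpts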

  def begin(self):
    # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS`
    # collection if no `Saver` or `Scaffold` is provided.
    # pylint: disable=protected-access
    if (self._checkpoint_saver_hook._saver is None and
        self._checkpoint_saver_hook._scaffold is None):
      iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
      saveables = [
          iterator_ops._IteratorSaveable(
              i, i.name, external_state_policy=self._external_state_policy)
          for i in iterators
      ]
      self._checkpoint_saver_hook._saver = _CustomSaver(
          saveables, self._latest_filename, sharded=True)
    # pylint: enable=protected-access
    self._checkpoint_saver_hook.begin()

  def after_create_session(self, session, coord):
    # If a new session was created, we set _first_run to True so that we can
    # restore if needed.
    self._first_run = True

  def _restore_or_save_initial_ckpt(self, session):
    # Ideally this should be run in after_create_session but is not for the
    # following reason:
    # Currently there is no way of enforcing an order of running the
    # `SessionRunHooks`. Hence it is possible that the
    # `_DatasetInitializerHook` is run *after* this hook. That is troublesome
    # because
    # 1. If a checkpoint exists and this hook restores it, the initializer
    #    hook will override it.
    # 2. If no checkpoint exists, this hook will try to save an uninitialized
    #    iterator which will result in an exception.
    #
    # As a temporary fix we enter the following implicit contract between
    # this hook and the _DatasetInitializerHook.
    # 1. The _DatasetInitializerHook initializes the iterator in the call to
    #    after_create_session.
    # 2. This hook saves the iterator on the first call to `before_run()`,
    #    which is guaranteed to happen after `after_create_session()` of all
    #    hooks have been run.

    # Check if there is an existing checkpoint. If so, restore from it.
    # pylint: disable=protected-access
    latest_checkpoint_path = checkpoint_management.latest_checkpoint(
        self._checkpoint_saver_hook._checkpoint_dir,
        latest_filename=self._latest_filename)
    if latest_checkpoint_path:
      self._checkpoint_saver_hook._get_saver().restore(session,
                                                       latest_checkpoint_path)
    else:
      # The checkpoint saved here is the state at step "global_step".
      # Note: We do not save the GraphDef or MetaGraphDef here.
      global_step = session.run(self._checkpoint_saver_hook._global_step_tensor)
      self._checkpoint_saver_hook._save(session, global_step)
      self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step)
    # pylint: enable=protected-access

  def before_run(self, run_context):
    if self._first_run:
      self._restore_or_save_initial_ckpt(run_context.session)
      self._first_run = False
    return self._checkpoint_saver_hook.before_run(run_context)

  def after_run(self, run_context, run_values):
    self._checkpoint_saver_hook.after_run(run_context, run_values)

  def end(self, session):
    self._checkpoint_saver_hook.end(session)


class _CustomSaver(saver_lib.Saver):
  """`Saver` with a different default `latest_filename`.

  This is used in the `CheckpointInputPipelineHook` to avoid conflicts with
  the model ckpt saved by the `CheckpointSaverHook`.
308 """ 309 310 def __init__(self, var_list, latest_filename, sharded=False): 311 super(_CustomSaver, self).__init__(var_list, sharded=sharded) 312 self._latest_filename = latest_filename 313 314 def save(self, 315 sess, 316 save_path, 317 global_step=None, 318 latest_filename=None, 319 meta_graph_suffix="meta", 320 write_meta_graph=True, 321 write_state=True, 322 strip_default_attrs=False): 323 return super(_CustomSaver, self).save( 324 sess, save_path, global_step, latest_filename or self._latest_filename, 325 meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs) 326