• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Training state management."""
16
17import os
18
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import errors
22from tensorflow.python.keras import backend
23from tensorflow.python.keras.distribute import distributed_file_utils
24from tensorflow.python.keras.utils import mode_keys
25from tensorflow.python.lib.io import file_io
26from tensorflow.python.ops import variables
27from tensorflow.python.training import checkpoint_management
28from tensorflow.python.training.tracking import util as trackable_util
29
# Name of the `tf.keras.Model` attribute that holds the epoch at which the
# most recent checkpoint was written.
CKPT_SAVED_EPOCH = '_ckpt_saved_epoch'

# Sentinel epoch value meaning "no checkpoint has been saved yet".
CKPT_SAVED_EPOCH_UNUSED_VALUE = -1
35
36
class WorkerTrainingState(object):
  """Training state management class.

  This class provides APIs for backing up and restoring the training state.
  This allows model and epoch information to be saved periodically and
  restored for fault-tolerance, also known as preemption-recovery purpose.
  """

  def __init__(self, model, checkpoint_dir):
    """Creates the checkpoint managers used to back up and restore training.

    Args:
      model: The model being trained. Its `distribute_strategy` decides
        whether this worker writes to the shared backup directory or to a
        temporary one.
      checkpoint_dir: Directory under which backups are stored. All workers
        restore from its `chief` subdirectory.
    """
    self._model = model

    # The epoch at which the checkpoint is saved. Used for fault-tolerance.
    # GPU device only has int64 dtype registered VarHandleOp.
    self._ckpt_saved_epoch = variables.Variable(
        initial_value=constant_op.constant(
            CKPT_SAVED_EPOCH_UNUSED_VALUE, dtype=dtypes.int64),
        name='ckpt_saved_epoch')

    # Variable initialization.
    backend.set_value(self._ckpt_saved_epoch, CKPT_SAVED_EPOCH_UNUSED_VALUE)

    # _ckpt_saved_epoch gets tracked and is included in the checkpoint file
    # when backing up.
    checkpoint = trackable_util.Checkpoint(
        model=self._model, ckpt_saved_epoch=self._ckpt_saved_epoch)

    # All workers restore from the same `chief` subdirectory of
    # `checkpoint_dir` through `read_checkpoint_manager`.
    #
    # A worker that should checkpoint (e.g. in single-worker training) also
    # writes through that same manager. In multi-worker training, a worker
    # that should not checkpoint still has to call `save()`, because the
    # SyncOnReadVariable needs to be synced across all workers in order to be
    # read. So it writes through a separate manager pointed at a temporary
    # directory that `back_up()` asks to have removed after each save.
    self.read_checkpoint_manager = checkpoint_management.CheckpointManager(
        checkpoint,
        directory=os.path.join(checkpoint_dir, 'chief'),
        max_to_keep=1)
    write_checkpoint_dir = distributed_file_utils.write_dirpath(
        checkpoint_dir, self._model.distribute_strategy)
    if self._model.distribute_strategy.extended.should_checkpoint:
      self.write_checkpoint_manager = self.read_checkpoint_manager
    else:
      self.write_checkpoint_manager = checkpoint_management.CheckpointManager(
          checkpoint, directory=write_checkpoint_dir, max_to_keep=1)

  def back_up(self, epoch):
    """Back up the current state of training into a checkpoint file.

    Stores `epoch` in the tracked epoch variable, then saves the model together
    with that variable through `write_checkpoint_manager`. When a checkpoint
    was written, the write directory is handed to
    `distributed_file_utils.remove_temp_dirpath`, which is meant to clean up
    the temporary directory used by a worker that should not checkpoint.

    Args:
      epoch: The current epoch information to be saved.
    """
    backend.set_value(self._ckpt_saved_epoch, epoch)
    # Save the model plus CKPT_SAVED_EPOCH variable.
    if self.write_checkpoint_manager.save():
      distributed_file_utils.remove_temp_dirpath(
          self.write_checkpoint_manager.directory,
          self._model.distribute_strategy)

  def restore(self):
    """Restore the training state from the backed up checkpoint file.

    Returns:
      True if a backed-up checkpoint was found and restored. False if there was
      no checkpoint to restore from, i.e. the training state doesn't need to be
      restored. Errors raised while restoring propagate to the caller.
    """
    # `restore_or_initialize()` returns the path of the checkpoint it restored,
    # or None when the read directory holds no checkpoint.
    restored_path = self.read_checkpoint_manager.restore_or_initialize()
    return restored_path is not None

  def delete_backup(self):
    """Delete the backup directory.

    The backup should not exist after `fit()` successfully finishes. Only the
    worker whose write manager is the shared read manager (the one writing to
    the `chief` directory) deletes it. A non-checkpointing worker's temporary
    write directory is handled in `back_up()` instead. A backup directory that
    does not exist is ignored.
    """
    if self.write_checkpoint_manager is not self.read_checkpoint_manager:
      return
    try:
      file_io.delete_recursively_v2(self.write_checkpoint_manager.directory)
    except errors.NotFoundError:
      # Nothing was backed up, or it was already removed; best-effort delete.
      pass

  def maybe_load_initial_epoch_from_ckpt(self, initial_epoch, mode):
    """Maybe load initial epoch from ckpt considering possible worker recovery.

    When the restored `_ckpt_saved_epoch` variable holds a non-negative value
    (i.e. not `CKPT_SAVED_EPOCH_UNUSED_VALUE`), the worker is recovering from a
    previous failure. In this case, `initial_epoch` is inferred from
    `self._ckpt_saved_epoch` so that the previously unfinished training
    continues from the right epoch.

    Args:
      initial_epoch: The original initial_epoch the user passes in to `fit()`.
      mode: The mode for running `model.fit()`. Recovery only applies to
        `ModeKeys.TRAIN`.

    Returns:
      If training is recovering from a previous failure, the epoch at which
      training is supposed to continue. Otherwise, the `initial_epoch` the
      user passed in.
    """

    epoch = backend.eval(self._ckpt_saved_epoch)
    if mode == mode_keys.ModeKeys.TRAIN and epoch >= 0:
      # The most recently saved epoch is one epoch prior to the epoch it
      # failed at, so return the value of 'self._ckpt_saved_epoch' plus one.
      return epoch + 1
    return initial_epoch
145