# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
16"""Python module for evaluation loop."""
17
18# pylint: disable=g-direct-tensorflow-import
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import errors_impl
21from tensorflow.python.ops import variables
22from tensorflow.python.platform import tf_logging as logging
23from tensorflow.python.training import checkpoint_utils
24from tensorflow.python.training.tracking import util as tracking_util
25
26_PRINT_EVAL_STEP_EVERY_SEC = 60.0
27_ITERATIONS_UNINITIALIZED = -1
28
29
def list_checkpoint_attributes(ckpt_dir_or_file):
  """Lists all the attributes in a checkpoint.

  Checkpoint keys are paths in a checkpoint graph, and the attribute is the
  first element in the path. For example, with the checkpoint key
  "optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE", "optimizer" is the attribute.
  The attribute is also used to save/restore a variable in a checkpoint,
  e.g. `tf.train.Checkpoint(optimizer=optimizer, model=model)`.

  Args:
    ckpt_dir_or_file: Directory with checkpoint file(s), or path to a
      checkpoint.

  Returns:
    Set of attributes in the checkpoint.
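
  Example (a minimal sketch; the checkpoint path and its contents are
  hypothetical, assuming the training program saved with
  `tf.train.Checkpoint(model=model, optimizer=optimizer)`):

  ```python
  attrs = list_checkpoint_attributes('/tmp/ckpt_dir/ckpt-1')
  # attrs would be something like {'model', 'optimizer', 'save_counter'};
  # 'save_counter' is tracked by `tf.train.Checkpoint` itself.
  ```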
44  """
45  reader = checkpoint_utils.load_checkpoint(ckpt_dir_or_file)
46  variable_map = reader.get_variable_to_shape_map()
47  return {name.split('/')[0] for name in variable_map.keys()}
48
49
class SidecarEvaluator(object):
  """A class designed for a dedicated evaluator task.

  `SidecarEvaluator` is expected to be run in a process on a separate machine
  from the training cluster. It serves as a dedicated evaluator, evaluating
  the metric results of a training cluster that has one or more workers
  performing the training and saving checkpoints.

  The `SidecarEvaluator` API is compatible with both Custom Training Loops
  (CTL) and Keras `Model.fit` being used in the training cluster. Using the
  model (with compiled metrics) provided at `__init__`, `SidecarEvaluator`
  repeatedly performs evaluation "epochs" whenever it finds a checkpoint that
  has not yet been used. Depending on the `steps` argument, an eval epoch is
  evaluation over all eval data, or up to a certain number of steps (batches).
  See the examples below for how the training program should save the
  checkpoints in order to be recognized by `SidecarEvaluator`.

  Since, under the hood, `SidecarEvaluator` uses `model.evaluate` for
  evaluation, it also supports arbitrary Keras callbacks. That is, if one or
  more callbacks are provided, their `on_test_batch_begin` and
  `on_test_batch_end` methods are called at the start and end of a batch, and
  their `on_test_begin` and `on_test_end` are called at the start and end of
  an evaluation epoch. Note that `SidecarEvaluator` may skip some checkpoints
  because it always picks up the latest checkpoint available, and during an
  evaluation epoch, multiple checkpoints can be produced from the training
  side.

  Example:
  ```python
  model = tf.keras.models.Sequential(...)
  model.compile(metrics=tf.keras.metrics.SparseCategoricalAccuracy(
      name="eval_metrics"))
  data = tf.data.Dataset.from_tensor_slices(...)

  SidecarEvaluator(
      model=model,
      data=data,
      checkpoint_dir='/tmp/checkpoint_dir',  # dir for training-saved checkpoint
      steps=None,  # Eval until dataset is exhausted
      max_evaluations=None,  # The evaluation needs to be stopped manually
      callbacks=[tf.keras.callbacks.TensorBoard(log_dir='/tmp/log_dir')]
  ).start()
  ```

  `SidecarEvaluator.start` writes a series of summary files that can be
  visualized with TensorBoard (which prints a webpage link):

  ```bash
  $ tensorboard --logdir=/tmp/log_dir/validation
  ...
  TensorBoard 2.4.0a0 at http://host:port (Press CTRL+C to quit)
  ```

  If the training cluster uses a CTL, the `checkpoint_dir` should contain
  checkpoints that track both `model` and `optimizer`, to fulfill
  `SidecarEvaluator`'s expectation. This can be done with a
  `tf.train.Checkpoint` and a `tf.train.CheckpointManager`:

  ```python
  checkpoint_dir = ...  # Same `checkpoint_dir` supplied to `SidecarEvaluator`.
  checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint, directory=..., max_to_keep=...)
  checkpoint_manager.save()
  ```

  If the training cluster uses the Keras `Model.fit` API, a
  `tf.keras.callbacks.ModelCheckpoint` should be used, with
  `save_weights_only=True`, and the `filepath` should have 'ckpt-{epoch}'
  appended:

  ```python
  checkpoint_dir = ...  # Same `checkpoint_dir` supplied to `SidecarEvaluator`.
  model_checkpoint = tf.keras.callbacks.ModelCheckpoint(
      filepath=os.path.join(checkpoint_dir, 'ckpt-{epoch}'),
      save_weights_only=True)
  model.fit(dataset, epochs=epochs, callbacks=[model_checkpoint])
  ```
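
  With the `ModelCheckpoint` callback above, `max_evaluations` tells the
  evaluator when to stop. A minimal sketch, reusing `model`, `data`, and
  `checkpoint_dir` from the snippets above and assuming training runs for 10
  epochs so that 'ckpt-10' is the last checkpoint saved:

  ```python
  SidecarEvaluator(
      model=model,
      data=data,
      checkpoint_dir=checkpoint_dir,
      max_evaluations=10,  # Exit after evaluating 'ckpt-10'.
  ).start()
  ```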
127  """
128
  def __init__(self,
               model,
               data,
               checkpoint_dir,
               steps=None,
               max_evaluations=None,
               callbacks=None):
136    """Initializes an `SidecarEvaluator` object.
137
138    Args:
139      model: Model to use for evaluation. The model object used here should be a
140        `tf.keras.Model`, and should be the same as the one that is used in
141        training, where `tf.keras.Model`s are checkpointed. The model should
142        have one or more metrics compiled before using `SidecarEvaluator`.
143      data: The input data for evaluation. `SidecarEvaluator` supports all data
144        types that Keras `model.evaluate` supports as the input data `x`, such
145        as a `tf.data.Dataset`.
146      checkpoint_dir: Directory where checkpoint files are saved.
147      steps: Number of steps to perform evaluation for, when evaluating a single
148        checkpoint file. If `None`, evaluation continues until the dataset is
149        exhausted. For repeated evaluation dataset, user must specify `steps` to
150        avoid infinite evaluation loop.
151      max_evaluations: Maximum number of the checkpoint file to be evaluated,
152        for `SidecarEvaluator` to know when to stop. The evaluator will stop
153        after it evaluates a checkpoint filepath ending with
154        '<ckpt_name>-<max_evaluations>'. If using
155        `tf.train.CheckpointManager.save` for saving checkpoints, the kth saved
156        checkpoint has the filepath suffix '<ckpt_name>-<k>' (k=1 for the first
157        saved), and if checkpoints are saved every epoch after training, the
158        filepath saved at the kth epoch would end with '<ckpt_name>-<k>. Thus,
159        if training runs for n epochs, and the evaluator should end after the
160        training finishes, use n for this parameter. Note that this is not
161        necessarily equal to the number of total evaluations, since some
162        checkpoints may be skipped if evaluation is slower than checkpoint
163        creation. If `None`, `SidecarEvaluator` will evaluate indefinitely, and
164        the user must terminate evaluator program themselves.
165      callbacks: List of `keras.callbacks.Callback` instances to apply during
166        evaluation. See [callbacks](/api_docs/python/tf/keras/callbacks).
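
    Example (a minimal sketch; `model` and `dataset` are assumed to be
    defined as in the class-level example):

    ```python
    # A repeated dataset never exhausts, so `steps` must bound each
    # evaluation epoch to avoid an infinite loop.
    SidecarEvaluator(
        model, data=dataset.repeat(), checkpoint_dir='/tmp/checkpoint_dir',
        steps=100).start()
    ```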
167    """
    self.model = model
    self.data = data
    self.checkpoint_dir = checkpoint_dir
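    # `_iterations` holds the training iteration count restored from the
    # checkpoint; it keeps the sentinel value until a restore succeeds.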
    self._iterations = variables.Variable(
        name='iterations',
        initial_value=_ITERATIONS_UNINITIALIZED,
        dtype=dtypes.int64)
    self.max_evaluations = max_evaluations
    self.steps = steps
    self.callbacks = callbacks or []

  def start(self):
    """Starts the evaluation loop."""
    optimizer_checkpoint = tracking_util.Checkpoint(iter=self._iterations)
    checkpoint = tracking_util.Checkpoint(
        model=self.model, optimizer=optimizer_checkpoint)

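    # `checkpoints_iterator` blocks until a checkpoint newer than the last
    # one yielded appears under `checkpoint_dir`, then yields its path.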
    for latest_checkpoint in checkpoint_utils.checkpoints_iterator(
        self.checkpoint_dir):
      try:
        # `expect_partial` because the checkpoint can have other `Trackable`s
        # such as `optimizer`.
        checkpoint.restore(latest_checkpoint).expect_partial()
        checkpoint_attributes = list_checkpoint_attributes(latest_checkpoint)
        # The checkpoint should contain model and optimizer for
        # SidecarEvaluator to work. But the model weights saved by the
        # ModelCheckpoint callback do not contain model as an attribute. To
        # make SidecarEvaluator work in this case as well, use
        # model.load_weights to load the model's weights, while
        # self._iterations is still restored by the checkpoint variable.
        if 'model' not in checkpoint_attributes:
          self.model.load_weights(latest_checkpoint)
        # The model checkpoint might not include the optimizer in some cases,
        # e.g. when using a custom training loop. Directly assign the
        # iterations property to be used in callbacks.
        if self.model.optimizer:
          self.model.optimizer.iterations.assign(self._iterations)
      except (errors_impl.OpError,) as e:
        # A couple of errors can happen here with the coordinator racing to
        # write the checkpoint:
        # 1) OpError: open failed for <file path>: No such file or directory
        # 2) NotFoundError (subclass of OpError): Unsuccessful
        # TensorSliceReader constructor.
        # TODO(rchao): Remove this except block once b/150954027 is resolved.
        logging.info(
            'SidecarEvaluator encountered an error loading '
            'checkpoint: %s. Retrying. Error: %s: %s', latest_checkpoint,
            e.__class__.__name__, e)
        continue

      if self._iterations.numpy() == _ITERATIONS_UNINITIALIZED:
        raise RuntimeError(
            '`iterations` cannot be loaded from the '
            'checkpoint file. Please ensure `iterations` is '
            'tracked in the `checkpoint` saved by the coordinator.')

      logging.info(
          'Evaluation starts: Model weights loaded from latest '
          'checkpoint file: %s.', latest_checkpoint)

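      # One evaluation "epoch" over the data with the restored weights,
      # invoking any user-provided callbacks; `steps=None` runs until the
      # dataset is exhausted.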
      self.model.evaluate(
          self.data, steps=self.steps, callbacks=self.callbacks, verbose=2)

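      # Collect results from the model's compiled metrics; a metric's
      # `result()` may return a dict of scalars, which is flattened into
      # `return_metrics`.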
      return_metrics = {}
      for metric in self.model.metrics:
        result = metric.result()
        if isinstance(result, dict):
          return_metrics.update(result)
        else:
          return_metrics[metric.name] = result

      logging.info(
          'End of evaluation. Metrics: %s', ' '.join([
              '{}={}'.format(name, value.numpy())
              for name, value in return_metrics.items()
          ]))

      # TODO(rchao): Make the max evaluation robust in case users save the
      # checkpoints with epoch format {epoch:03d}.
      if (self.max_evaluations and
          latest_checkpoint.endswith('-{}'.format(self.max_evaluations))):
        # Exit the loop because we have evaluated the final checkpoint file.
        logging.info('Last checkpoint evaluated. SidecarEvaluator stops.')
        return
251