# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
16"""Test covering sidecar_evaluator.py."""
17
import enum
import os

from absl import logging
from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.distribute import sidecar_evaluator as sidecar_evaluator_lib
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import test
from tensorflow.python.summary import summary_iterator
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.tracking import util as tracking_util

_BATCH_SIZE = 32
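
# For orientation: the SidecarEvaluator workflow these tests exercise, as a
# minimal sketch mirroring the calls below (placeholder names, not itself a
# test):
#
#   evaluator = sidecar_evaluator_lib.SidecarEvaluator(
#       eval_model, data=eval_dataset, checkpoint_dir=checkpoint_dir)
#   evaluator.start()  # Evaluates checkpoints as the trainer writes them.
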
class TestModel(keras.Model):
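  """A trivial single-Dense-layer model used as a test fixture."""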

  def __init__(self):
    super().__init__(name='test_model')
    self.dense = keras.layers.Dense(10)

  def call(self, inputs):
    return self.dense(inputs)

class DictMetric(keras.metrics.MeanSquaredError):
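  """An MSE metric whose result() returns a dict of two named results."""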

  def result(self):
    res = super().result()
    return {'mean_squared_error_1': res, 'mean_squared_error_2': res}

class ModelType(enum.Enum):
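  """Model architectures covered by these tests."""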
  SEQUENTIAL = 'sequential'
  SUBCLASS = 'subclass'

def _test_model_builder(model_type: ModelType, compile_model, build_model):
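  """Builds a test model, optionally compiling and/or building it.

  Args:
    model_type: A `ModelType` selecting a Sequential or subclassed model.
    compile_model: Whether to compile the model with an SGD optimizer, MSE
      loss, and the test metrics.
    build_model: Whether to build the model with input shape (None, 32).

  Returns:
    The constructed `keras.Model`.
  """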
  if model_type == ModelType.SEQUENTIAL:
    model = keras.Sequential([keras.layers.Dense(10)])
  elif model_type == ModelType.SUBCLASS:
    model = TestModel()

  if compile_model:
    model.compile(
        gradient_descent.SGD(),
        loss='mse',
        metrics=[keras.metrics.CategoricalAccuracy(),
                 DictMetric()])
  if build_model:
    model.build((None, 32))

  return model

class SidecarEvaluatorTest(test.TestCase, parameterized.TestCase):
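  """Tests for SidecarEvaluator."""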

  def assertSummaryEventsWritten(self, log_dir):
    # Asserts summary files do get written when log_dir is provided.
    summary_files = file_io.list_directory_v2(log_dir)
    self.assertNotEmpty(
        summary_files, 'Summary should have been written and '
        'log_dir should not be empty.')

    # Asserts the content of the summary files.
    event_pb_written = False
    event_tags = []
    for summary_file in summary_files:
      for event_pb in summary_iterator.summary_iterator(
          os.path.join(log_dir, summary_file)):
        if event_pb.step > 0:
          self.assertEqual(event_pb.step, 32)
          event_tags.append(event_pb.summary.value[0].tag)
          event_pb_written = True
    self.assertCountEqual(event_tags, [
        'evaluation_categorical_accuracy_vs_iterations',
        'evaluation_loss_vs_iterations',
        'evaluation_mean_squared_error_1_vs_iterations',
        'evaluation_mean_squared_error_2_vs_iterations',
    ])

    # Verify that at least one non-zeroth step was written to the summary.
    self.assertTrue(event_pb_written)

  def assertModelsSameVariables(self, model_a, model_b):
    # Check that both models have the same number of variables.
    self.assertEqual(len(model_a.variables), len(model_b.variables))

    # Check that corresponding variable values are equal.
    for var_a, var_b in zip(model_a.variables, model_b.variables):
      self.assertAllEqual(var_a.numpy(), var_b.numpy())

  @ds_combinations.generate(
      combinations.combine(
          mode=['eager'], model_type=[ModelType.SEQUENTIAL,
                                      ModelType.SUBCLASS]))
  def testIterationsNotSavedWillRaiseError(self, model_type):
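    """start() raises if `iterations` was not saved in the checkpoint."""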
    model = _test_model_builder(
        model_type=model_type, compile_model=False, build_model=True)

    checkpoint_dir = self.get_temp_dir()
    checkpoint = tracking_util.Checkpoint(model=model)
    checkpoint_manager = checkpoint_management.CheckpointManager(
        checkpoint, checkpoint_dir, max_to_keep=2)
    checkpoint_manager.save()

    sidecar_evaluator = sidecar_evaluator_lib.SidecarEvaluator(
        model, data=None, checkpoint_dir=checkpoint_dir)
    with self.assertRaisesRegex(
        RuntimeError, '`iterations` cannot be loaded '
        'from the checkpoint file.'):
      sidecar_evaluator.start()

  @ds_combinations.generate(
      combinations.combine(
          mode=['eager'], model_type=[ModelType.SEQUENTIAL,
                                      ModelType.SUBCLASS]))
  def testModelNotBuiltRaiseError(self, model_type):
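    """start() raises if the evaluation model has not been built."""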
    model = _test_model_builder(
        model_type=model_type, compile_model=False, build_model=False)

    checkpoint_dir = self.get_temp_dir()
    checkpoint = tracking_util.Checkpoint(model=model)
    checkpoint_manager = checkpoint_management.CheckpointManager(
        checkpoint, checkpoint_dir, max_to_keep=2)
    checkpoint_manager.save()

    sidecar_evaluator = sidecar_evaluator_lib.SidecarEvaluator(
        model, data=None, checkpoint_dir=checkpoint_dir)
    with self.assertRaisesRegex(AssertionError, 'Nothing to load.'):
      sidecar_evaluator.start()

  @ds_combinations.generate(
      combinations.combine(
          mode=['eager'],
          model_type=[ModelType.SEQUENTIAL, ModelType.SUBCLASS],
          build_model=[True, False]))
  def testSidecarEvaluatorOutputsSummary(self, model_type, build_model):
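    """SidecarEvaluator restores the model and writes eval summaries."""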
    # Create a model with synthetic data, and fit for one epoch.
    model = _test_model_builder(
        model_type=model_type, compile_model=True, build_model=False)
    data = np.random.random((1000, 32))
    labels = np.random.random((1000, 10))
    dataset = dataset_ops.Dataset.from_tensor_slices((data, labels))
    dataset = dataset.batch(_BATCH_SIZE)
    model.fit(dataset, epochs=1)

    # Save a checkpoint.
    checkpoint_dir = os.path.join(self.get_temp_dir(), 'ckpt')
    log_dir = os.path.join(self.get_temp_dir(), 'summary')
    logging.info('checkpoint_dir = %s, log_dir = %s', checkpoint_dir, log_dir)
    checkpoint = tracking_util.Checkpoint(
        model=model, optimizer=model.optimizer)
    checkpoint_manager = checkpoint_management.CheckpointManager(
        checkpoint, checkpoint_dir, max_to_keep=2)
    logging.info('Checkpoint manager saved to: %s', checkpoint_manager.save())
    self.assertNotEmpty(
        file_io.list_directory_v2(checkpoint_dir),
        'Checkpoint should have been written and '
        'checkpoint_dir should not be empty.')

    # Create a new model used for evaluation.
    eval_model = _test_model_builder(
        model_type=model_type, compile_model=True, build_model=build_model)
    # Have a sidecar_evaluator evaluate once.
    sidecar_evaluator = sidecar_evaluator_lib.SidecarEvaluator(
        eval_model,
        data=dataset,
        checkpoint_dir=checkpoint_dir,
        max_evaluations=1,
        callbacks=[keras.callbacks.TensorBoard(log_dir=log_dir)])
    sidecar_evaluator.start()
    # The eval model should have been restored to the same state as the
    # original model, so their weights should match; if not, restoring the
    # model didn't work.
    self.assertModelsSameVariables(model, eval_model)

    self.assertSummaryEventsWritten(os.path.join(log_dir, 'validation'))

  @ds_combinations.generate(
      combinations.combine(
          mode=['eager'],
          model_type=[ModelType.SEQUENTIAL, ModelType.SUBCLASS],
          build_model=[True, False]))
  def testSidecarEvaluatorOutputsSummarySavedWithCallback(
      self, model_type, build_model):
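    """SidecarEvaluator works with checkpoints saved by ModelCheckpoint."""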
    checkpoint_dir = os.path.join(self.get_temp_dir(), 'checkpoints')
    log_dir = os.path.join(self.get_temp_dir(), 'summary')
    # Create a model with synthetic data, and fit for one epoch.
    model = _test_model_builder(
        model_type=model_type, compile_model=True, build_model=False)
    data = np.random.random((1000, 32))
    labels = np.random.random((1000, 10))
    dataset = dataset_ops.Dataset.from_tensor_slices((data, labels))
    dataset = dataset.batch(_BATCH_SIZE)
    save_callback = keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(checkpoint_dir, 'ckpt-{epoch}'),
        save_weights_only=True)
    model.fit(dataset, epochs=1, callbacks=[save_callback])
    self.assertNotEmpty(
        file_io.list_directory_v2(checkpoint_dir),
        'Checkpoint should have been written and '
        'checkpoint_dir should not be empty.')

    # Create a new model used for evaluation.
    eval_model = _test_model_builder(
        model_type=model_type, compile_model=True, build_model=build_model)
    # Have a sidecar_evaluator evaluate once.
    sidecar_evaluator = sidecar_evaluator_lib.SidecarEvaluator(
        eval_model,
        data=dataset,
        checkpoint_dir=checkpoint_dir,
        max_evaluations=1,
        callbacks=[keras.callbacks.TensorBoard(log_dir=log_dir)])
    with self.assertLogs() as cm:
      sidecar_evaluator.start()

    metrics_logging = [
        line for line in cm.output if 'End of evaluation' in line
    ]
    self.assertLen(metrics_logging, 1)
    expected_logged_metrics = [
        'loss', 'categorical_accuracy', 'mean_squared_error_1',
        'mean_squared_error_2'
    ]
    for metric_name in expected_logged_metrics:
      self.assertRegex(metrics_logging[0], f'{metric_name}=')

    # The eval model should have been restored to the same state as the
    # original model, so their weights should match; if not, restoring the
    # model didn't work.
    self.assertModelsSameVariables(model, eval_model)

    # Check that `iterations` was restored: one epoch over 1000 examples with
    # batch size 32 takes 32 steps, which happens to equal _BATCH_SIZE.
    self.assertEqual(sidecar_evaluator._iterations.numpy(), _BATCH_SIZE)

    self.assertSummaryEventsWritten(os.path.join(log_dir, 'validation'))

if __name__ == '__main__':
  v2_compat.enable_v2_behavior()
  test.main()