# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test covering sidecar_evaluator.py."""

import enum
import os

from absl import logging
from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python.compat import v2_compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import combinations as ds_combinations
from tensorflow.python.framework import test_combinations as combinations
from tensorflow.python.keras.distribute import sidecar_evaluator as sidecar_evaluator_lib
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import test
from tensorflow.python.summary import summary_iterator
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.tracking import util as tracking_util

_BATCH_SIZE = 32
# Number of synthetic examples the training tests fit on.
_NUM_SAMPLES = 1000
# Batches in one pass over the synthetic dataset (ceiling division). The tests
# fit for exactly one epoch and expect the restored iteration counter, and the
# step recorded in the summaries, to equal this. With the values above it is 32,
# which only coincidentally equals `_BATCH_SIZE`.
_STEPS_PER_EPOCH = (_NUM_SAMPLES + _BATCH_SIZE - 1) // _BATCH_SIZE


class TestModel(keras.Model):
  """Subclassed Keras model consisting of a single 10-unit `Dense` layer."""

  def __init__(self):
    super().__init__(name='test_model')
    self.dense = keras.layers.Dense(10)

  def call(self, inputs):
    """Applies the dense layer to `inputs`."""
    return self.dense(inputs)


class DictMetric(keras.metrics.MeanSquaredError):
  """Mean squared error metric whose `result()` is a dict of two entries."""

  def result(self):
    """Returns the parent's result under two distinct keys."""
    res = super().result()
    return {'mean_squared_error_1': res, 'mean_squared_error_2': res}


class ModelType(enum.Enum):
  """Kinds of model `_test_model_builder` knows how to create."""
  SEQUENTIAL = 'sequential'
  SUBCLASS = 'subclass'


def _test_model_builder(model_type: ModelType, compile_model, build_model):
  """Creates a test model, optionally compiling and/or building it.

  Args:
    model_type: The `ModelType` selecting a `keras.Sequential` model or a
      `TestModel` instance.
    compile_model: If truthy, compile with SGD, 'mse' loss, and the
      `CategoricalAccuracy` and `DictMetric` metrics.
    build_model: If truthy, build the model for inputs of shape `(None, 32)`.

  Returns:
    The created model.

  Raises:
    ValueError: If `model_type` is not a supported `ModelType`.
  """
  if model_type == ModelType.SEQUENTIAL:
    model = keras.Sequential([keras.layers.Dense(10)])
  elif model_type == ModelType.SUBCLASS:
    model = TestModel()
  else:
    # Fail with a clear message rather than an UnboundLocalError on `model`.
    raise ValueError(f'Unsupported model_type: {model_type!r}')

  if compile_model:
    model.compile(
        gradient_descent.SGD(),
        loss='mse',
        metrics=[keras.metrics.CategoricalAccuracy(),
                 DictMetric()])
  if build_model:
    model.build((None, 32))

  return model


def _make_synthetic_dataset():
  """Returns a batched dataset of random features and labels.

  Features have shape `(_NUM_SAMPLES, 32)` and labels `(_NUM_SAMPLES, 10)`;
  the dataset is batched with `_BATCH_SIZE`.
  """
  data = np.random.random((_NUM_SAMPLES, 32))
  labels = np.random.random((_NUM_SAMPLES, 10))
  dataset = dataset_ops.Dataset.from_tensor_slices((data, labels))
  return dataset.batch(_BATCH_SIZE)


class SidecarEvaluatorTest(test.TestCase, parameterized.TestCase):
  """Tests for `SidecarEvaluator`."""

  def assertSummaryEventsWritten(self, log_dir):
    """Asserts that `log_dir` holds the expected evaluation summary events.

    Every event with a positive step must be recorded at `_STEPS_PER_EPOCH`,
    and the tags of those events must be exactly the four expected
    evaluation tags.

    Args:
      log_dir: Directory that should contain the summary files.
    """
    # Asserts summary files do get written when log_dir is provided.
    summary_files = file_io.list_directory_v2(log_dir)
    self.assertNotEmpty(
        summary_files, 'Summary should have been written and '
        'log_dir should not be empty.')

    # Asserts the content of the summary file.
    event_pb_written = False
    event_tags = []
    for summary_file in summary_files:
      for event_pb in summary_iterator.summary_iterator(
          os.path.join(log_dir, summary_file)):
        # Only events with a positive step are inspected.
        if event_pb.step > 0:
          self.assertEqual(event_pb.step, _STEPS_PER_EPOCH)
          event_tags.append(event_pb.summary.value[0].tag)
          event_pb_written = True
    self.assertCountEqual(event_tags, [
        'evaluation_categorical_accuracy_vs_iterations',
        'evaluation_loss_vs_iterations',
        'evaluation_mean_squared_error_1_vs_iterations',
        'evaluation_mean_squared_error_2_vs_iterations',
    ])

    # Verifying at least one non-zeroth step is written to summary.
    self.assertTrue(event_pb_written)

  def assertModelsSameVariables(self, model_a, model_b):
    """Asserts both models have the same number of variables, equal in value."""
    # Check both have the same number of variables.
    self.assertEqual(len(model_a.variables), len(model_b.variables))

    # Check variable values to be equal.
    for var_a, var_b in zip(model_a.variables, model_b.variables):
      self.assertAllEqual(var_a.numpy(), var_b.numpy())

  @ds_combinations.generate(
      combinations.combine(
          mode=['eager'], model_type=[ModelType.SEQUENTIAL,
                                      ModelType.SUBCLASS]))
  def testIterationsNotSavedWillRaiseError(self, model_type):
    """Starting from a checkpoint that tracks only the model raises."""
    model = _test_model_builder(
        model_type=model_type, compile_model=False, build_model=True)

    checkpoint_dir = self.get_temp_dir()
    # The checkpoint tracks only the model; no optimizer is saved with it.
    checkpoint = tracking_util.Checkpoint(model=model)
    checkpoint_manager = checkpoint_management.CheckpointManager(
        checkpoint, checkpoint_dir, max_to_keep=2)
    checkpoint_manager.save()

    sidecar_evaluator = sidecar_evaluator_lib.SidecarEvaluator(
        model, data=None, checkpoint_dir=checkpoint_dir)
    with self.assertRaisesRegex(
        RuntimeError, '`iterations` cannot be loaded '
        'from the checkpoint file.'):
      sidecar_evaluator.start()

  @ds_combinations.generate(
      combinations.combine(
          mode=['eager'], model_type=[ModelType.SEQUENTIAL,
                                      ModelType.SUBCLASS]))
  def testModelNotBuiltRaiseError(self, model_type):
    """Starting with a model that was neither compiled nor built raises."""
    model = _test_model_builder(
        model_type=model_type, compile_model=False, build_model=False)

    checkpoint_dir = self.get_temp_dir()
    checkpoint = tracking_util.Checkpoint(model=model)
    checkpoint_manager = checkpoint_management.CheckpointManager(
        checkpoint, checkpoint_dir, max_to_keep=2)
    checkpoint_manager.save()

    sidecar_evaluator = sidecar_evaluator_lib.SidecarEvaluator(
        model, data=None, checkpoint_dir=checkpoint_dir)
    with self.assertRaisesRegex(AssertionError, 'Nothing to load.'):
      sidecar_evaluator.start()

  @ds_combinations.generate(
      combinations.combine(
          mode=['eager'],
          model_type=[ModelType.SEQUENTIAL, ModelType.SUBCLASS],
          build_model=[True, False]))
  def testSidecarEvaluatorOutputsSummary(self, model_type, build_model):
    """Evaluates once from a manually saved checkpoint and checks summaries."""
    # Create a model with synthetic data, and fit for one epoch.
    model = _test_model_builder(
        model_type=model_type, compile_model=True, build_model=False)
    dataset = _make_synthetic_dataset()
    model.fit(dataset, epochs=1)

    # Save a checkpoint.
    checkpoint_dir = os.path.join(self.get_temp_dir(), 'ckpt')
    log_dir = os.path.join(self.get_temp_dir(), 'summary')
    logging.info('checkpoint_dir = %s, log_dir = %s', checkpoint_dir, log_dir)
    checkpoint = tracking_util.Checkpoint(
        model=model, optimizer=model.optimizer)
    checkpoint_manager = checkpoint_management.CheckpointManager(
        checkpoint, checkpoint_dir, max_to_keep=2)
    logging.info('Checkpoint manager saved to: %s', checkpoint_manager.save())
    self.assertNotEmpty(
        file_io.list_directory_v2(checkpoint_dir),
        'Checkpoint should have been written and '
        'checkpoint_dir should not be empty.')

    # Create a new model used for evaluation.
    eval_model = _test_model_builder(
        model_type=model_type, compile_model=True, build_model=build_model)
    # Have a sidecar_evaluator evaluate once.
    sidecar_evaluator = sidecar_evaluator_lib.SidecarEvaluator(
        eval_model,
        data=dataset,
        checkpoint_dir=checkpoint_dir,
        max_evaluations=1,
        callbacks=[keras.callbacks.TensorBoard(log_dir=log_dir)])
    sidecar_evaluator.start()
    # Eval model has been restored to the same state as the original model, so
    # their weights should match. If not, restoration of the model didn't
    # work.
    self.assertModelsSameVariables(model, eval_model)

    self.assertSummaryEventsWritten(os.path.join(log_dir, 'validation'))

  @ds_combinations.generate(
      combinations.combine(
          mode=['eager'],
          model_type=[ModelType.SEQUENTIAL, ModelType.SUBCLASS],
          build_model=[True, False]))
  def testSidecarEvaluatorOutputsSummarySavedWithCallback(
      self, model_type, build_model):
    """Evaluates once from a `ModelCheckpoint`-written checkpoint."""
    checkpoint_dir = os.path.join(self.get_temp_dir(), 'checkpoints')
    log_dir = os.path.join(self.get_temp_dir(), 'summary')
    # Create a model with synthetic data, and fit for one epoch.
    model = _test_model_builder(
        model_type=model_type, compile_model=True, build_model=False)
    dataset = _make_synthetic_dataset()
    save_callback = keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(checkpoint_dir, 'ckpt-{epoch}'),
        save_weights_only=True)
    model.fit(dataset, epochs=1, callbacks=[save_callback])
    self.assertNotEmpty(
        file_io.list_directory_v2(checkpoint_dir),
        'Checkpoint should have been written and '
        'checkpoint_dir should not be empty.')

    # Create a new model used for evaluation.
    eval_model = _test_model_builder(
        model_type=model_type, compile_model=True, build_model=build_model)
    # Have a sidecar_evaluator evaluate once.
    sidecar_evaluator = sidecar_evaluator_lib.SidecarEvaluator(
        eval_model,
        data=dataset,
        checkpoint_dir=checkpoint_dir,
        max_evaluations=1,
        callbacks=[keras.callbacks.TensorBoard(log_dir=log_dir)])
    with self.assertLogs() as cm:
      sidecar_evaluator.start()

    # Exactly one captured log line should report the end of evaluation, and
    # it should mention every expected metric.
    metrics_logging = [
        line for line in cm.output if 'End of evaluation' in line
    ]
    self.assertLen(metrics_logging, 1)
    expected_logged_metrics = [
        'loss', 'categorical_accuracy', 'mean_squared_error_1',
        'mean_squared_error_2'
    ]
    for metric_name in expected_logged_metrics:
      self.assertRegex(metrics_logging[0], f'{metric_name}=')

    # Eval model has been restored to the same state as the original model, so
    # their weights should match. If not, restoration of the model didn't
    # work.
    self.assertModelsSameVariables(model, eval_model)

    # Check the iterations is restored: one epoch of training corresponds to
    # `_STEPS_PER_EPOCH` iterations.
    self.assertEqual(sidecar_evaluator._iterations.numpy(), _STEPS_PER_EPOCH)

    self.assertSummaryEventsWritten(os.path.join(log_dir, 'validation'))


if __name__ == '__main__':
  v2_compat.enable_v2_behavior()
  test.main()