# Lint as: python3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `DatasetCreator` with `Model.fit` across usages and strategies."""

import os

from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.keras import callbacks as callbacks_lib
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.layers import core as core_layers
from tensorflow.python.keras.layers.preprocessing import string_lookup
from tensorflow.python.keras.optimizer_v2 import gradient_descent
from tensorflow.python.keras.utils import dataset_creator
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging


class DatasetCreatorModelFitTestBase(test.TestCase, parameterized.TestCase):
  """The base class for DatasetCreator with Model.fit tests."""

  def _get_dataset_fn(self, use_lookup_layer):

    if use_lookup_layer:

      filepath = os.path.join(self.get_temp_dir(), "vocab")
      with open(filepath, "w") as f:
        f.write("\n".join(["earth", "wind", "and", "fire"]))

      def dataset_fn(input_context):
        del input_context
        lookup_layer = string_lookup.StringLookup(
            num_oov_indices=1, vocabulary=filepath)
        x = np.array([["earth", "wind", "and", "fire"],
                      ["fire", "and", "earth", "michigan"]])
        y = np.array([0, 1])
        map_fn = lambda x, y: (lookup_layer(x), y)
        return dataset_ops.DatasetV2.from_tensor_slices(
            (x, y)).shuffle(10).repeat().batch(2).map(map_fn)

    else:

      def dataset_fn(input_context):
        del input_context
        x = random_ops.random_uniform((10, 10))
        y = random_ops.random_uniform((10,))
        return dataset_ops.DatasetV2.from_tensor_slices(
            (x, y)).shuffle(10).repeat().batch(2)

    return dataset_fn

  def _model_compile(self,
                     strategy,
                     steps_per_execution=1,
                     run_eagerly=False,
                     with_normalization_layer=False,
                     use_lookup_layer=False):

    class ResultAssertingCallback(callbacks_lib.Callback):
      """A callback that asserts the result of the tests."""

      def __init__(self):
        self._prev_epoch = -1

      def on_epoch_end(self, epoch, logs=None):
        logging.info("testModelFit: epoch=%r, logs=%r", epoch, logs)
        if epoch <= self._prev_epoch:
          raise RuntimeError("Epoch is supposed to be larger than previous.")
        self._prev_epoch = epoch
        is_loss_float = (
            logs.get("loss", None) is not None and
            isinstance(logs["loss"], (float, np.floating)))
        if not is_loss_float:
          raise RuntimeError("loss is supposed to be in the logs and float.")

    with strategy.scope():
      model = sequential.Sequential([core_layers.Dense(10)])
      if with_normalization_layer:
        norm = keras.layers.BatchNormalization(
            axis=-1, input_shape=(4, 4, 3), momentum=0.8)
        model.add(norm)
      model.add(core_layers.Dense(1, activation="sigmoid"))
      self._accuracy_metric = keras.metrics.Accuracy()

    model.compile(
        gradient_descent.SGD(),
        loss="binary_crossentropy",
        metrics=[self._accuracy_metric],
        steps_per_execution=steps_per_execution,
        run_eagerly=run_eagerly)
    return model, [ResultAssertingCallback()]

  def _model_fit(self,
                 strategy,
                 steps_per_execution=1,
                 validation_data=None,
                 x=None,
                 y=None,
                 shuffle=True,
                 batch_size=None,
                 steps_per_epoch=10,
                 run_eagerly=False,
                 with_normalization_layer=False,
                 callbacks=None,
                 use_lookup_layer=False):
    if callbacks is None:
      callbacks = []

    model, default_callbacks = self._model_compile(strategy,
                                                   steps_per_execution,
                                                   run_eagerly,
                                                   with_normalization_layer,
                                                   use_lookup_layer)
    callbacks += default_callbacks

    if x is None:
      x = dataset_creator.DatasetCreator(self._get_dataset_fn(use_lookup_layer))

    if validation_data is None:
      validation_data = dataset_creator.DatasetCreator(
          self._get_dataset_fn(use_lookup_layer))

    model.fit(
        x,
        y,
        shuffle=shuffle,
        batch_size=batch_size,
        epochs=10,
        steps_per_epoch=steps_per_epoch,
        callbacks=callbacks,
        validation_data=validation_data,
        validation_steps=steps_per_epoch)
    return model

  def _model_evaluate(self,
                      strategy,
                      steps_per_execution=1,
                      x=None,
                      y=None,
                      batch_size=None,
                      steps=10,
                      run_eagerly=False,
                      with_normalization_layer=False,
                      callbacks=None):
    if callbacks is None:
      callbacks = []

    model, default_callbacks = self._model_compile(
        strategy,
        steps_per_execution,
        run_eagerly,
        with_normalization_layer,
    )
    callbacks += default_callbacks

    def dataset_fn(input_context):
      del input_context
      x = random_ops.random_uniform((10, 10))
      y = random_ops.random_uniform((10, 1))
      return dataset_ops.DatasetV2.from_tensor_slices(
          (x, y)).shuffle(10).repeat().batch(8)

    if x is None:
      x = dataset_creator.DatasetCreator(dataset_fn)

    model.evaluate(
        x=x, y=y, steps=steps, callbacks=callbacks, batch_size=batch_size)
    return model

  def _model_predict(
      self,
      strategy,
      model=None,
      steps_per_execution=1,
      test_data=None,
      steps=10,
      with_normalization_layer=False,
  ):
    callbacks = []

    if model is None:
      model, default_callbacks = self._model_compile(
          strategy,
          steps_per_execution,
          with_normalization_layer=with_normalization_layer,
      )
      callbacks += default_callbacks

    def create_test_data():
      x = constant_op.constant([1., 2., 3., 1., 5., 1.])
      return dataset_ops.DatasetV2.from_tensor_slices(x).repeat().batch(2)

    if test_data is None:
      test_data = create_test_data()

    predictions = model.predict(x=test_data, steps=steps, callbacks=callbacks)
    predictions = np.around(predictions, 4)
    return model, predictions
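
# Example usage (an illustrative sketch, not part of the original suite): a
# concrete test class would subclass `DatasetCreatorModelFitTestBase`, supply
# a `tf.distribute` strategy, and call the helpers above. The class name, the
# strategy choice, and the final assertion below are assumptions made for
# illustration only.
#
#   from tensorflow.python.distribute import mirrored_strategy
#
#   class ExampleDatasetCreatorModelFitTest(DatasetCreatorModelFitTestBase):
#
#     def test_model_fit(self):
#       strategy = mirrored_strategy.MirroredStrategy()
#       model = self._model_fit(strategy)
#       # `_model_fit` runs 10 epochs of 10 steps each, so the optimizer is
#       # expected to have applied 100 update steps.
#       self.assertEqual(model.optimizer.iterations.numpy(), 100)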