# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for model.fit calls with a Dataset object passed as validation_data."""

import io
import sys

from absl.testing import parameterized
import numpy as np

from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.layers import core
from tensorflow.python.platform import test


@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
class ValidationDatasetNoLimitTest(keras_parameterized.TestCase):
  """fit() must consume the whole validation dataset when no steps are given.

  The decorators run every test once per model type and per Keras execution
  mode.
  """

  def create_dataset(self, num_samples, batch_size):
    """Builds a shuffled, batched dataset of (x, 3 * x) pairs.

    Args:
      num_samples: Number of samples; each input is an array of shape (1,)
        drawn from np.random.rand.
      batch_size: Batch size of the returned dataset. The shuffle buffer is
        10 * batch_size elements.

    Returns:
      A `Dataset` yielding (input, target) batches.
    """
    input_data = np.random.rand(num_samples, 1)
    expected_data = input_data * 3
    dataset = dataset_ops.Dataset.from_tensor_slices((input_data,
                                                      expected_data))
    return dataset.shuffle(10 * batch_size).batch(batch_size)

  def test_validation_dataset_with_no_step_arg(self):
    """Final val MAE from fit() equals evaluate() on the same dataset."""
    # Create a model that learns y=Mx.
    layers = [core.Dense(1)]
    model = testing_utils.get_model_from_layers(layers, input_shape=(1,))
    # Pass run_eagerly so the eager variant produced by run_all_keras_modes
    # really executes eagerly (matches the other mode-parameterized tests).
    model.compile(
        loss="mse",
        optimizer="adam",
        metrics=["mean_absolute_error"],
        run_eagerly=testing_utils.should_run_eagerly())

    train_dataset = self.create_dataset(num_samples=200, batch_size=10)
    eval_dataset = self.create_dataset(num_samples=50, batch_size=25)

    history = model.fit(x=train_dataset, validation_data=eval_dataset, epochs=2)
    evaluation = model.evaluate(x=eval_dataset)

    # If the fit call used the entire dataset, then the final val MAE error
    # from the fit history should be equal to the final element in the output
    # of evaluating the model on the same eval dataset.
    self.assertAlmostEqual(history.history["val_mean_absolute_error"][-1],
                           evaluation[-1], places=5)


class PrintTrainingInfoTest(keras_parameterized.TestCase,
                            parameterized.TestCase):
  """Tests for fit() progress output and dict-structured inputs."""

  @test_util.run_v1_only("Only relevant in graph mode.")
  def test_print_info_with_datasets(self):
    """Print training info should work with val datasets (b/133391839)."""

    model = keras.models.Sequential([keras.layers.Dense(1, input_shape=(1,))])
    model.compile(loss="mse", optimizer="sgd")

    # 100 elements / batch 10 -> 10 train steps; 50 / 10 -> 5 val steps.
    dataset = dataset_ops.Dataset.from_tensors(
        ([1.], [1.])).repeat(100).batch(10)

    val_dataset = dataset_ops.Dataset.from_tensors(
        ([1.], [1.])).repeat(50).batch(10)

    mock_stdout = io.StringIO()
    with test.mock.patch.object(sys, "stdout", mock_stdout):
      model.fit(dataset, epochs=2, validation_data=val_dataset)

    self.assertIn(
        "Train on 10 steps, validate on 5 steps", mock_stdout.getvalue())

  @parameterized.named_parameters(
      ("with_validation", True), ("without_validation", False))
  @test_util.run_v1_only("Only relevant in graph mode.")
  def test_print_info_with_numpy(self, do_validation):
    """Print training info should work with NumPy inputs (b/133391839).

    Args:
      do_validation: Whether to pass NumPy validation data to fit().
    """

    model = keras.models.Sequential([keras.layers.Dense(1, input_shape=(2,))])
    model.compile(loss="mse", optimizer="sgd")

    dataset = np.arange(200).reshape(100, 2)

    if do_validation:
      val_data = (np.arange(100).reshape(50, 2), np.arange(50).reshape(50, 1))
    else:
      val_data = None

    mock_stdout = io.StringIO()
    with test.mock.patch.object(sys, "stdout", mock_stdout):
      model.fit(dataset, batch_size=10, epochs=2, validation_data=val_data)

    self.assertIn("Train on 100 samples", mock_stdout.getvalue())

    if do_validation:
      self.assertIn(", validate on 50 samples", mock_stdout.getvalue())

  @keras_parameterized.run_all_keras_modes
  def test_dict_float64_input(self):
    """fit() accepts a dict of float64 NumPy arrays as x."""

    class MyModel(keras.Model):

      def __init__(self):
        # Plain no-arg super call; previously `self` was passed as an extra
        # positional argument to keras.Model.__init__.
        super(MyModel, self).__init__()
        self.dense1 = keras.layers.Dense(10, activation="relu")
        self.dense2 = keras.layers.Dense(10, activation="relu")
        self.concat = keras.layers.Concatenate()
        self.dense3 = keras.layers.Dense(1, activation="sigmoid")

      def call(self, inputs):
        d1 = self.dense1(inputs["one"])
        d2 = self.dense2(inputs["two"])
        concat = self.concat([d1, d2])
        return self.dense3(concat)

    model = MyModel()
    model.compile(
        loss="mae",
        optimizer="adam",
        run_eagerly=testing_utils.should_run_eagerly())

    model.fit(
        x={
            "one": np.random.rand(100, 10, 1),
            "two": np.random.rand(100, 10, 1)
        },
        y=np.random.rand(100, 10, 1))

  def test_dict_validation_input(self):
    """Test case for GitHub issue 30122: dict inputs as validation_data."""
    train_input_0 = np.random.rand(1000, 1)
    train_input_1 = np.random.rand(1000, 1)
    train_labels = np.random.rand(1000, 1)
    val_input_0 = np.random.rand(1000, 1)
    val_input_1 = np.random.rand(1000, 1)
    val_labels = np.random.rand(1000, 1)

    class MyModel(keras.Model):

      def __init__(self):
        super(MyModel, self).__init__()
        self.hidden_layer_0 = keras.layers.Dense(100, activation="relu")
        self.hidden_layer_1 = keras.layers.Dense(100, activation="relu")
        self.concat = keras.layers.Concatenate()
        self.out_layer = keras.layers.Dense(1, activation="sigmoid")

      # `inputs` is a dict keyed by "input_0"/"input_1". The former list
      # default was both a mutable default and the wrong type for this body.
      def call(self, inputs):
        activation_0 = self.hidden_layer_0(inputs["input_0"])
        activation_1 = self.hidden_layer_1(inputs["input_1"])
        concat = self.concat([activation_0, activation_1])
        return self.out_layer(concat)

    model = MyModel()
    model.compile(loss="mae", optimizer="adam")

    model.fit(
        x={
            "input_0": train_input_0,
            "input_1": train_input_1
        },
        y=train_labels,
        validation_data=({
            "input_0": val_input_0,
            "input_1": val_input_1
        }, val_labels))


if __name__ == "__main__":
  test.main()