• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for model.fit calls with a Dataset object passed as validation_data."""
16
17import io
18import sys
19
20from absl.testing import parameterized
21import numpy as np
22
23from tensorflow.python import keras
24from tensorflow.python.data.ops import dataset_ops
25from tensorflow.python.framework import test_util
26from tensorflow.python.keras import keras_parameterized
27from tensorflow.python.keras import testing_utils
28from tensorflow.python.keras.layers import core
29from tensorflow.python.platform import test
30
31
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
class ValidationDatasetNoLimitTest(keras_parameterized.TestCase):
  """Checks fit() with a validation Dataset and no validation step count."""

  def create_dataset(self, num_samples, batch_size):
    """Builds a shuffled, batched Dataset of (x, 3 * x) pairs.

    Args:
      num_samples: Number of rows drawn with `np.random.rand`.
      batch_size: Batch size; the shuffle buffer is 10 times this value.

    Returns:
      The dataset after `shuffle(10 * batch_size)` and `batch(batch_size)`.
    """
    features = np.random.rand(num_samples, 1)
    targets = features * 3
    slices = dataset_ops.Dataset.from_tensor_slices((features, targets))
    shuffled = slices.shuffle(10 * batch_size)
    return shuffled.batch(batch_size)

  def test_validation_dataset_with_no_step_arg(self):
    """Final validation MAE from fit() should match a separate evaluate()."""
    # A single Dense unit is enough to learn y=Mx.
    model = testing_utils.get_model_from_layers(
        [core.Dense(1)], input_shape=(1,))
    model.compile(
        loss="mse", optimizer="adam", metrics=["mean_absolute_error"])

    training_data = self.create_dataset(num_samples=200, batch_size=10)
    validation_data = self.create_dataset(num_samples=50, batch_size=25)

    fit_result = model.fit(
        x=training_data, validation_data=validation_data, epochs=2)
    eval_result = model.evaluate(x=validation_data)

    # When fit() walks the entire validation dataset, the last validation MAE
    # it records must agree with the last value returned by evaluating the
    # model on that same dataset.
    last_val_mae = fit_result.history["val_mean_absolute_error"][-1]
    self.assertAlmostEqual(last_val_mae, eval_result[-1], places=5)
60
61
class PrintTrainingInfoTest(keras_parameterized.TestCase,
                            parameterized.TestCase):
  """Checks fit()'s printed training summary and dict-keyed model inputs."""

  @test_util.run_v1_only("Only relevant in graph mode.")
  def test_print_info_with_datasets(self):
    """Print training info should work with val datasets (b/133391839)."""

    model = keras.models.Sequential([keras.layers.Dense(1, input_shape=(1,))])
    model.compile(loss="mse", optimizer="sgd")

    # 100 repeated elements in batches of 10 -> 10 training steps.
    dataset = dataset_ops.Dataset.from_tensors(
        ([1.], [1.])).repeat(100).batch(10)

    # 50 repeated elements in batches of 10 -> 5 validation steps.
    val_dataset = dataset_ops.Dataset.from_tensors(
        ([1.], [1.])).repeat(50).batch(10)

    # Swap sys.stdout for an in-memory buffer so that whatever fit() prints
    # can be inspected afterwards.
    mock_stdout = io.StringIO()
    with test.mock.patch.object(sys, "stdout", mock_stdout):
      model.fit(dataset, epochs=2, validation_data=val_dataset)

    self.assertIn(
        "Train on 10 steps, validate on 5 steps", mock_stdout.getvalue())

  @parameterized.named_parameters(
      ("with_validation", True), ("without_validation", False))
  @test_util.run_v1_only("Only relevant in graph mode.")
  def test_print_info_with_numpy(self, do_validation):
    """Print training info should work with NumPy inputs (b/133391839).

    Args:
      do_validation: Whether a `(x, y)` tuple of NumPy arrays is passed to
        `fit` as `validation_data`; when False, `None` is passed instead.
    """

    model = keras.models.Sequential([keras.layers.Dense(1, input_shape=(2,))])
    model.compile(loss="mse", optimizer="sgd")

    # 100 samples with 2 features each.
    dataset = np.arange(200).reshape(100, 2)

    if do_validation:
      # 50 validation samples with 2 features and one target each.
      val_data = (np.arange(100).reshape(50, 2), np.arange(50).reshape(50, 1))
    else:
      val_data = None

    mock_stdout = io.StringIO()
    with test.mock.patch.object(sys, "stdout", mock_stdout):
      model.fit(dataset, batch_size=10, epochs=2, validation_data=val_data)

    # Read the captured text once and reuse it for both checks.
    printed = mock_stdout.getvalue()
    self.assertIn("Train on 100 samples", printed)

    if do_validation:
      self.assertIn(", validate on 50 samples", printed)

  @keras_parameterized.run_all_keras_modes
  def test_dict_float64_input(self):
    """Fits a subclassed model fed a dict of `np.random.rand` arrays."""

    class MyModel(keras.Model):
      """Two Dense branches keyed "one"/"two", concatenated into one output."""

      def __init__(self):
        # No positional arguments: the instance is already bound by super().
        super(MyModel, self).__init__()
        self.dense1 = keras.layers.Dense(10, activation="relu")
        self.dense2 = keras.layers.Dense(10, activation="relu")
        self.concat = keras.layers.Concatenate()
        self.dense3 = keras.layers.Dense(1, activation="sigmoid")

      def call(self, inputs):
        d1 = self.dense1(inputs["one"])
        d2 = self.dense2(inputs["two"])
        concat = self.concat([d1, d2])
        return self.dense3(concat)

    model = MyModel()
    model.compile(
        loss="mae",
        optimizer="adam",
        run_eagerly=testing_utils.should_run_eagerly())

    # The test passes if fit() completes; nothing is asserted on the result.
    model.fit(
        x={
            "one": np.random.rand(100, 10, 1),
            "two": np.random.rand(100, 10, 1)
        },
        y=np.random.rand(100, 10, 1))

  def test_dict_validation_input(self):
    """Test case for GitHub issue 30122."""
    train_input_0 = np.random.rand(1000, 1)
    train_input_1 = np.random.rand(1000, 1)
    train_labels = np.random.rand(1000, 1)
    val_input_0 = np.random.rand(1000, 1)
    val_input_1 = np.random.rand(1000, 1)
    val_labels = np.random.rand(1000, 1)

    input_0 = keras.Input(shape=(None,), name="input_0")
    input_1 = keras.Input(shape=(None,), name="input_1")

    class my_model(keras.Model):
      """Two Dense branches keyed "input_0"/"input_1", merged to one output."""

      def __init__(self):
        # No positional arguments: the instance is already bound by super().
        super(my_model, self).__init__()
        self.hidden_layer_0 = keras.layers.Dense(100, activation="relu")
        self.hidden_layer_1 = keras.layers.Dense(100, activation="relu")
        self.concat = keras.layers.Concatenate()
        self.out_layer = keras.layers.Dense(1, activation="sigmoid")

      # NOTE(review): the default is a list, yet the body indexes `inputs` by
      # string key, so the default could not be used as-is. Presumably kept to
      # mirror the issue's reproduction -- confirm before changing it.
      def call(self, inputs=[input_0, input_1]):
        activation_0 = self.hidden_layer_0(inputs["input_0"])
        activation_1 = self.hidden_layer_1(inputs["input_1"])
        concat = self.concat([activation_0, activation_1])
        return self.out_layer(concat)

    model = my_model()
    model.compile(loss="mae", optimizer="adam")

    # Training and validation inputs are both dicts keyed like the lookups in
    # call(); the test passes if fit() completes.
    model.fit(
        x={
            "input_0": train_input_0,
            "input_1": train_input_1
        },
        y=train_labels,
        validation_data=({
            "input_0": val_input_0,
            "input_1": val_input_1
        }, val_labels))
181
182
# Hand control to the TensorFlow test runner when this file is run as a script.
if __name__ == "__main__":
  test.main()
185