• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for training routines."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import unittest
23
24from absl.testing import parameterized
25import numpy as np
26
27from tensorflow.python import keras
28from tensorflow.python.data.ops import dataset_ops
29from tensorflow.python.data.ops import iterator_ops
30from tensorflow.python.eager import context
31from tensorflow.python.framework import test_util as tf_test_util
32from tensorflow.python.keras import keras_parameterized
33from tensorflow.python.keras import metrics as metrics_module
34from tensorflow.python.keras import testing_utils
35from tensorflow.python.keras.engine import training_generator
36from tensorflow.python.keras.optimizer_v2 import rmsprop
37from tensorflow.python.platform import test
38from tensorflow.python.util import nest
39
40
def custom_generator(mode=2):
  """Yield an endless stream of random batches for the generator tests.

  A pool of 50 random samples is drawn once; batches of 10 consecutive rows
  are then served from it, wrapping back to the start after the last batch.

  Args:
    mode: Selects what each step yields. `1` yields only the inputs, `2`
      yields `(inputs, targets)`, and any other value yields
      `(inputs, targets, sample_weights)`.

  Yields:
    Inputs with 2 columns, targets with 4 columns and 1-D sample weights,
    packed according to `mode`.
  """
  batch_size = 10
  num_samples = 50
  # Draw order (inputs, targets, weights) is kept so seeded runs are stable.
  pool_x = np.random.random((num_samples, 2))
  pool_y = np.random.random((num_samples, 4))
  pool_w = np.random.random((num_samples,))
  offset = 0
  while True:
    rows = slice(offset, offset + batch_size)
    # Advance modulo the pool size so the stream cycles forever.
    offset = (offset + batch_size) % num_samples
    batch = (pool_x[rows], pool_y[rows], pool_w[rows])
    if mode == 1:
      yield batch[0]
    elif mode == 2:
      yield batch[:2]
    else:
      yield batch
62
63
class TestGeneratorMethods(keras_parameterized.TestCase):
  """Exercises the generator-based training APIs with plain Python generators."""

  @unittest.skipIf(
      os.name == 'nt',
      'use_multiprocessing=True does not work on windows properly.')
  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_fit_generator_method(self):
    """`fit_generator` runs with processes, threads, validation and workers=0."""
    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    # Pass `run_eagerly` as the other tests in this file do, so this test
    # follows the execution mode selected by `run_all_keras_modes`
    # (presumably surfaced through `testing_utils.should_run_eagerly()`).
    model.compile(
        loss='mse',
        optimizer=rmsprop.RMSprop(1e-3),
        metrics=['mae', metrics_module.CategoricalAccuracy()],
        run_eagerly=testing_utils.should_run_eagerly())

    model.fit_generator(custom_generator(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        workers=4,
                        use_multiprocessing=True)
    model.fit_generator(custom_generator(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False)
    model.fit_generator(custom_generator(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False,
                        validation_data=custom_generator(),
                        validation_steps=10)
    model.fit_generator(custom_generator(),
                        steps_per_epoch=5,
                        validation_data=custom_generator(),
                        validation_steps=1,
                        workers=0)

  @unittest.skipIf(
      os.name == 'nt',
      'use_multiprocessing=True does not work on windows properly.')
  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_evaluate_generator_method(self):
    """`evaluate_generator` runs with processes, threads and workers=0."""
    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(
        loss='mse',
        optimizer=rmsprop.RMSprop(1e-3),
        metrics=['mae', metrics_module.CategoricalAccuracy()],
        run_eagerly=testing_utils.should_run_eagerly())

    model.evaluate_generator(custom_generator(),
                             steps=5,
                             max_queue_size=10,
                             workers=2,
                             verbose=1,
                             use_multiprocessing=True)
    model.evaluate_generator(custom_generator(),
                             steps=5,
                             max_queue_size=10,
                             use_multiprocessing=False)
    model.evaluate_generator(custom_generator(),
                             steps=5,
                             max_queue_size=10,
                             use_multiprocessing=False,
                             workers=0)

  @unittest.skipIf(
      os.name == 'nt',
      'use_multiprocessing=True does not work on windows properly.')
  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_predict_generator_method(self):
    """`predict_generator` accepts (x, y) batches as well as bare inputs."""
    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    # The model is never compiled here, so the eager flag is set directly.
    model.run_eagerly = testing_utils.should_run_eagerly()

    model.predict_generator(custom_generator(),
                            steps=5,
                            max_queue_size=10,
                            workers=2,
                            use_multiprocessing=True)
    model.predict_generator(custom_generator(),
                            steps=5,
                            max_queue_size=10,
                            use_multiprocessing=False)
    model.predict_generator(custom_generator(),
                            steps=5,
                            max_queue_size=10,
                            workers=0)
    # Test generator with just inputs (no targets)
    model.predict_generator(custom_generator(mode=1),
                            steps=5,
                            max_queue_size=10,
                            workers=2,
                            use_multiprocessing=True)
    model.predict_generator(custom_generator(mode=1),
                            steps=5,
                            max_queue_size=10,
                            use_multiprocessing=False)
    model.predict_generator(custom_generator(mode=1),
                            steps=5,
                            max_queue_size=10,
                            workers=0)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_generator_methods_with_sample_weights(self):
    """All generator APIs accept (x, y, sample_weight) batches."""
    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(
        loss='mse',
        optimizer=rmsprop.RMSprop(1e-3),
        metrics=['mae', metrics_module.CategoricalAccuracy()],
        run_eagerly=testing_utils.should_run_eagerly())

    model.fit_generator(custom_generator(mode=3),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False)
    model.fit_generator(custom_generator(mode=3),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False,
                        validation_data=custom_generator(mode=3),
                        validation_steps=10)
    model.predict_generator(custom_generator(mode=3),
                            steps=5,
                            max_queue_size=10,
                            use_multiprocessing=False)
    model.evaluate_generator(custom_generator(mode=3),
                             steps=5,
                             max_queue_size=10,
                             use_multiprocessing=False)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_generator_methods_invalid_use_case(self):
    """Generators that yield a bare scalar are rejected by every API."""

    def invalid_generator():
      while True:
        yield 0

    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(loss='mse', optimizer=rmsprop.RMSprop(1e-3),
                  run_eagerly=testing_utils.should_run_eagerly())

    with self.assertRaises(ValueError):
      model.fit_generator(invalid_generator(),
                          steps_per_epoch=5,
                          epochs=1,
                          verbose=1,
                          max_queue_size=10,
                          use_multiprocessing=False)
    with self.assertRaises(ValueError):
      model.fit_generator(custom_generator(),
                          steps_per_epoch=5,
                          epochs=1,
                          verbose=1,
                          max_queue_size=10,
                          use_multiprocessing=False,
                          validation_data=invalid_generator(),
                          validation_steps=10)
    # predict_generator surfaces this misuse as AttributeError rather than
    # ValueError; the test pins that existing behavior.
    with self.assertRaises(AttributeError):
      model.predict_generator(invalid_generator(),
                              steps=5,
                              max_queue_size=10,
                              use_multiprocessing=False)
    with self.assertRaises(ValueError):
      model.evaluate_generator(invalid_generator(),
                               steps=5,
                               max_queue_size=10,
                               use_multiprocessing=False)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_generator_input_to_fit_eval_predict(self):
    """`fit`, `evaluate` and `predict` accept a Python generator directly."""
    val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)

    def ones_generator():
      while True:
        yield np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)

    model = testing_utils.get_small_mlp(
        num_hidden=10, num_classes=1, input_dim=10)

    model.compile(rmsprop.RMSprop(0.001), 'binary_crossentropy',
                  run_eagerly=testing_utils.should_run_eagerly())
    model.fit(
        ones_generator(),
        steps_per_epoch=2,
        validation_data=val_data,
        epochs=2)
    model.evaluate(ones_generator(), steps=2)
    model.predict(ones_generator(), steps=2)
269
270
class TestGeneratorMethodsWithSequences(keras_parameterized.TestCase):
  """Exercises the training APIs with `keras.utils.Sequence` inputs."""

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_training_with_sequences(self):
    """`fit_generator` consumes a Sequence with a generator for validation."""

    class DummySequence(keras.utils.Sequence):
      """Ten identical batches of zero inputs and one-valued targets."""

      def __getitem__(self, idx):
        return np.zeros([10, 2]), np.ones([10, 4])

      def __len__(self):
        return 10

    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    # Pass `run_eagerly` like the other tests in this file so the execution
    # mode chosen by `run_all_keras_modes` is honoured here too.
    model.compile(loss='mse', optimizer=rmsprop.RMSprop(1e-3),
                  run_eagerly=testing_utils.should_run_eagerly())

    model.fit_generator(DummySequence(),
                        steps_per_epoch=10,
                        validation_data=custom_generator(),
                        validation_steps=1,
                        max_queue_size=10,
                        workers=0,
                        use_multiprocessing=True)
    model.fit_generator(DummySequence(),
                        steps_per_epoch=10,
                        validation_data=custom_generator(),
                        validation_steps=1,
                        max_queue_size=10,
                        workers=0,
                        use_multiprocessing=False)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_sequence_input_to_fit_eval_predict(self):
    """`fit`/`evaluate`/`predict` accept a Sequence; `y`/`sample_weight` fail."""
    val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)

    class CustomSequence(keras.utils.Sequence):
      """Two identical batches of one-valued inputs and targets."""

      def __getitem__(self, idx):
        return np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)

      def __len__(self):
        return 2

    model = testing_utils.get_small_mlp(
        num_hidden=10, num_classes=1, input_dim=10)

    # Pass `run_eagerly` like the other tests in this file so the execution
    # mode chosen by `run_all_keras_modes` is honoured here too.
    model.compile(rmsprop.RMSprop(0.001), 'binary_crossentropy',
                  run_eagerly=testing_utils.should_run_eagerly())
    model.fit(CustomSequence(), validation_data=val_data, epochs=2)
    model.evaluate(CustomSequence())
    model.predict(CustomSequence())

    # A Sequence already supplies targets, so separate ones are rejected.
    with self.assertRaisesRegexp(ValueError, '`y` argument is not supported'):
      model.fit(CustomSequence(), y=np.ones([10, 1]))

    with self.assertRaisesRegexp(ValueError,
                                 '`sample_weight` argument is not supported'):
      model.fit(CustomSequence(), sample_weight=np.ones([10, 1]))
331
332
@tf_test_util.run_all_in_graph_and_eager_modes
class TestConvertToGeneratorLike(test.TestCase, parameterized.TestCase):
  """Checks `training_generator.convert_to_generator_like` on several inputs."""
  # A flat (inputs, targets) pair and a nested variant, 10 rows each.
  simple_inputs = (np.ones((10, 10)), np.ones((10, 1)))
  nested_inputs = ((np.ones((10, 10)), np.ones((10, 20))), (np.ones((10, 1)),
                                                            np.ones((10, 3))))

  def _make_dataset(self, inputs, batches):
    """Return a dataset that yields `inputs` unchanged `batches` times."""
    return dataset_ops.DatasetV2.from_tensors(inputs).repeat(batches)

  def _make_iterator(self, inputs, batches):
    """Return a one-shot iterator over `_make_dataset(inputs, batches)`."""
    return dataset_ops.make_one_shot_iterator(
        self._make_dataset(inputs, batches))

  def _make_generator(self, inputs, batches):
    """Return a Python generator that yields `inputs` `batches` times."""

    def _gen():
      for _ in range(batches):
        yield inputs

    return _gen()

  def _make_numpy(self, inputs, _):
    """Return the NumPy structure as-is.

    Presumably `convert_to_generator_like` batches it itself: batch_size=2
    over 10 rows matches the 5 expected batches.
    """
    return inputs

  @parameterized.named_parameters(
      ('simple_dataset', _make_dataset, simple_inputs),
      ('simple_iterator', _make_iterator, simple_inputs),
      ('simple_generator', _make_generator, simple_inputs),
      ('simple_numpy', _make_numpy, simple_inputs),
      ('nested_dataset', _make_dataset, nested_inputs),
      ('nested_iterator', _make_iterator, nested_inputs),
      ('nested_generator', _make_generator, nested_inputs),
      ('nested_numpy', _make_numpy, nested_inputs))
  def test_convert_to_generator_like(self, input_fn, inputs):
    """Every produced batch keeps the nested structure of `inputs`."""
    expected_batches = 5
    # `input_fn` is the plain function captured in the class body, so `self`
    # is passed explicitly.
    data = input_fn(self, inputs, expected_batches)

    # Dataset and Iterator not supported in Legacy Graph mode.
    # NOTE(review): `return` rather than `self.skipTest`, presumably so the
    # eager-mode run triggered by the class decorator still happens; confirm
    # against `run_all_in_graph_and_eager_modes` before changing.
    if (not context.executing_eagerly() and
        isinstance(data, (dataset_ops.DatasetV2, iterator_ops.Iterator))):
      return

    generator, steps = training_generator.convert_to_generator_like(
        data, batch_size=2, steps_per_epoch=expected_batches)
    self.assertEqual(steps, expected_batches)

    # Check every batch, not only the last one, so an early structural
    # mismatch is not masked by later batches.
    for _ in range(expected_batches):
      outputs = next(generator)
      nest.assert_same_structure(outputs, inputs)
382
383
if __name__ == '__main__':
  # Run this module's tests through TensorFlow's test entry point.
  test.main()
386