# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for training routines."""
16
17import itertools
18
19from absl.testing import parameterized
20import numpy as np
21
22from tensorflow.python.data.ops import dataset_ops
23from tensorflow.python.data.ops import iterator_ops
24from tensorflow.python.eager import context
25from tensorflow.python.keras import combinations
26from tensorflow.python.keras import keras_parameterized
27from tensorflow.python.keras import layers as layers_module
28from tensorflow.python.keras import losses
29from tensorflow.python.keras import metrics as metrics_module
30from tensorflow.python.keras import testing_utils
31from tensorflow.python.keras.engine import input_layer
32from tensorflow.python.keras.engine import training
33from tensorflow.python.keras.engine import training_generator_v1
34from tensorflow.python.keras.optimizer_v2 import rmsprop
35from tensorflow.python.keras.utils import data_utils
36from tensorflow.python.platform import test
37from tensorflow.python.util import nest
38
39
40def custom_generator(mode=2):
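  """Yields random batches from 50 synthetic samples, 10 samples per batch.

  Depending on `mode`, each batch is `x` only (mode 1), `(x, y)` (mode 2),
  or `(x, y, sample_weights)` (any other value).
  """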
  batch_size = 10
  num_samples = 50
  arr_data = np.random.random((num_samples, 2))
  arr_labels = np.random.random((num_samples, 4))
  arr_weights = np.random.random((num_samples,))
  i = 0
  while True:
    batch_index = i * batch_size % num_samples
    i += 1
    start = batch_index
    end = start + batch_size
    x = arr_data[start: end]
    y = arr_labels[start: end]
    w = arr_weights[start: end]
    if mode == 1:
      yield x
    elif mode == 2:
      yield x, y
    else:
      yield x, y, w


def custom_generator_changing_batch_size(mode=2):
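  """Like `custom_generator`, but the batch size shrinks over time.

  Each step the batch size drops by one (from 10 down to 1), exercising
  code paths that must handle batches whose size changes between steps.
  """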
  batch_size = 10
  cur_batch_size = 11
  num_samples = 50
  arr_data = np.random.random((num_samples, 2))
  arr_labels = np.random.random((num_samples, 4))
  arr_weights = np.random.random((num_samples,))
  i = 0
  while True:
    if cur_batch_size > 1:
      cur_batch_size -= 1
    batch_index = i * batch_size % num_samples
    i += 1
    start = batch_index
    end = start + cur_batch_size
    x = arr_data[start: end]
    y = arr_labels[start: end]
    w = arr_weights[start: end]
    if mode == 1:
      yield x
    elif mode == 2:
      yield x, y
    else:
      yield x, y, w


custom_generator_threads = data_utils.threadsafe_generator(custom_generator)


class TestGeneratorMethods(keras_parameterized.TestCase):

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @data_utils.dont_use_multiprocessing_pool
  def test_fit_generator_method(self):
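    """Smoke test for fit_generator across worker/multiprocessing setups."""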
    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(
        loss='mse',
        optimizer=rmsprop.RMSprop(1e-3),
        metrics=['mae', metrics_module.CategoricalAccuracy()])

    model.fit_generator(custom_generator_threads(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        workers=4,
                        use_multiprocessing=True)
    model.fit_generator(custom_generator(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False)
    model.fit_generator(custom_generator(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False,
                        validation_data=custom_generator(),
                        validation_steps=10)
    model.fit_generator(custom_generator(),
                        steps_per_epoch=5,
                        validation_data=custom_generator(),
                        validation_steps=1,
                        workers=0)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @data_utils.dont_use_multiprocessing_pool
  def test_evaluate_generator_method(self):
    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(
        loss='mse',
        optimizer=rmsprop.RMSprop(1e-3),
        metrics=['mae', metrics_module.CategoricalAccuracy()],
        run_eagerly=testing_utils.should_run_eagerly())

    model.evaluate_generator(custom_generator_threads(),
                             steps=5,
                             max_queue_size=10,
                             workers=2,
                             verbose=1,
                             use_multiprocessing=True)
    model.evaluate_generator(custom_generator(),
                             steps=5,
                             max_queue_size=10,
                             use_multiprocessing=False)
    model.evaluate_generator(custom_generator(),
                             steps=5,
                             max_queue_size=10,
                             use_multiprocessing=False,
                             workers=0)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @data_utils.dont_use_multiprocessing_pool
  def test_predict_generator_method(self):
    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.run_eagerly = testing_utils.should_run_eagerly()

    model.predict_generator(custom_generator_threads(),
                            steps=5,
                            max_queue_size=10,
                            workers=2,
                            use_multiprocessing=True)
    model.predict_generator(custom_generator(),
                            steps=5,
                            max_queue_size=10,
                            use_multiprocessing=False)
    model.predict_generator(custom_generator(),
                            steps=5,
                            max_queue_size=10,
                            workers=0)
    # Test generator with just inputs (no targets)
    model.predict_generator(custom_generator_threads(mode=1),
                            steps=5,
                            max_queue_size=10,
                            workers=2,
                            use_multiprocessing=True)
    model.predict_generator(custom_generator(mode=1),
                            steps=5,
                            max_queue_size=10,
                            use_multiprocessing=False)
    model.predict_generator(custom_generator(mode=1),
                            steps=5,
                            max_queue_size=10,
                            workers=0)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_generator_methods_with_sample_weights(self):
    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(
        loss='mse',
        optimizer=rmsprop.RMSprop(1e-3),
        metrics=['mae', metrics_module.CategoricalAccuracy()],
        run_eagerly=testing_utils.should_run_eagerly())

    model.fit_generator(custom_generator(mode=3),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False)
    model.fit_generator(custom_generator(mode=3),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False,
                        validation_data=custom_generator(mode=3),
                        validation_steps=10)
    model.predict_generator(custom_generator(mode=3),
                            steps=5,
                            max_queue_size=10,
                            use_multiprocessing=False)
    model.evaluate_generator(custom_generator(mode=3),
                             steps=5,
                             max_queue_size=10,
                             use_multiprocessing=False)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_generator_methods_invalid_use_case(self):
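    """Generators that yield 4-tuples must raise a ValueError."""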
    def invalid_generator():
      while 1:
        yield (0, 0, 0, 0)

    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(
        loss='mse',
        optimizer=rmsprop.RMSprop(1e-3),
        run_eagerly=testing_utils.should_run_eagerly())

    with self.assertRaises(ValueError):
      model.fit_generator(invalid_generator(),
                          steps_per_epoch=5,
                          epochs=1,
                          verbose=1,
                          max_queue_size=10,
                          use_multiprocessing=False)
    with self.assertRaises(ValueError):
      model.fit_generator(custom_generator(),
                          steps_per_epoch=5,
                          epochs=1,
                          verbose=1,
                          max_queue_size=10,
                          use_multiprocessing=False,
                          validation_data=invalid_generator(),
                          validation_steps=10)
    with self.assertRaises(ValueError):
      model.predict_generator(invalid_generator(),
                              steps=5,
                              max_queue_size=10,
                              use_multiprocessing=False)
    with self.assertRaises(ValueError):
      model.evaluate_generator(invalid_generator(),
                               steps=5,
                               max_queue_size=10,
                               use_multiprocessing=False)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  def test_generator_input_to_fit_eval_predict(self):
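    """fit/evaluate/predict accept Python generators directly as inputs."""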
    val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)

    def ones_generator():
      while True:
        yield np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)

    model = testing_utils.get_small_mlp(
        num_hidden=10, num_classes=1, input_dim=10)

    model.compile(
        rmsprop.RMSprop(0.001),
        'binary_crossentropy',
        run_eagerly=testing_utils.should_run_eagerly())
    model.fit(
        ones_generator(),
        steps_per_epoch=2,
        validation_data=val_data,
        epochs=2)
    model.evaluate(ones_generator(), steps=2)
    model.predict(ones_generator(), steps=2)

    # Test with a changing batch size
    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(
        loss='mse',
        optimizer=rmsprop.RMSprop(1e-3),
        metrics=['mae', metrics_module.CategoricalAccuracy()])
    model.fit_generator(custom_generator_changing_batch_size(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False)
    model.fit_generator(custom_generator_changing_batch_size(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_queue_size=10,
                        use_multiprocessing=False,
                        validation_data=custom_generator_changing_batch_size(),
                        validation_steps=10)

    model.fit(
        custom_generator_changing_batch_size(),
        steps_per_epoch=5,
        validation_data=custom_generator_changing_batch_size(),
        validation_steps=10,
        epochs=2)
    model.evaluate(custom_generator_changing_batch_size(), steps=5)
    model.predict(custom_generator_changing_batch_size(), steps=5)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @data_utils.dont_use_multiprocessing_pool
  def test_generator_dynamic_shapes(self):
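    """Trains a tiny RNN on batches of variable-length padded sequences."""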

    x = [
        'I think juice is great',
        'unknown is the best language since slicedbread',
        'a a a a a a a',
        'matmul',
        'Yaks are also quite nice',
    ]
    y = [1, 0, 0, 1, 1]

    vocab = {
        word: i + 1 for i, word in
        enumerate(
            sorted(set(itertools.chain(*[i.split() for i in x]))))
    }

    def data_gen(batch_size=2):
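      """Yields (padded_token_ids, labels) batches.

      Each batch is padded to the length of its longest sentence, so batch
      shapes vary from step to step.
      """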
      np.random.seed(0)
      data = list(zip(x, y)) * 10
      np.random.shuffle(data)

      def pack_and_pad(queue):
        x = [[vocab[j] for j in i[0].split()] for i in queue]
        pad_len = max(len(i) for i in x)
        x = np.array([i + [0] * (pad_len - len(i)) for i in x])
        y = np.array([i[1] for i in queue])
        del queue[:]
        return x, y[:, np.newaxis]

      queue = []
      for i, element in enumerate(data):
        queue.append(element)
        if not (i + 1) % batch_size:
          yield pack_and_pad(queue)

      if queue:
        # Last partial batch
        yield pack_and_pad(queue)

    model = testing_utils.get_model_from_layers([
        layers_module.Embedding(input_dim=len(vocab) + 1, output_dim=4),
        layers_module.SimpleRNN(units=1),
        layers_module.Activation('sigmoid')
    ], input_shape=(None,))

    model.compile(loss=losses.binary_crossentropy, optimizer='sgd')
    model.fit(data_gen(), epochs=1, steps_per_epoch=5)


class TestGeneratorMethodsWithSequences(keras_parameterized.TestCase):

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @data_utils.dont_use_multiprocessing_pool
  def test_training_with_sequences(self):

    class DummySequence(data_utils.Sequence):

      def __getitem__(self, idx):
        return np.zeros([10, 2]), np.ones([10, 4])

      def __len__(self):
        return 10

    model = testing_utils.get_small_mlp(
        num_hidden=3, num_classes=4, input_dim=2)
    model.compile(loss='mse', optimizer=rmsprop.RMSprop(1e-3))

    model.fit_generator(DummySequence(),
                        steps_per_epoch=10,
                        validation_data=custom_generator(),
                        validation_steps=1,
                        max_queue_size=10,
                        workers=0,
                        use_multiprocessing=True)
    model.fit_generator(DummySequence(),
                        steps_per_epoch=10,
                        validation_data=custom_generator(),
                        validation_steps=1,
                        max_queue_size=10,
                        workers=0,
                        use_multiprocessing=False)

  @keras_parameterized.run_with_all_model_types
  @keras_parameterized.run_all_keras_modes
  @data_utils.dont_use_multiprocessing_pool
  def test_sequence_input_to_fit_eval_predict(self):
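    """Sequence inputs work with `fit`/`evaluate`/`predict`.

    Passing `y` or `sample_weight` alongside a Sequence raises a ValueError.
    """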
    val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)

    class CustomSequence(data_utils.Sequence):

      def __getitem__(self, idx):
        return np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)

      def __len__(self):
        return 2

    class CustomSequenceChangingBatchSize(data_utils.Sequence):

      def __getitem__(self, idx):
        batch_size = 10 - idx
        return (np.ones([batch_size, 10], np.float32),
                np.ones([batch_size, 1], np.float32))

      def __len__(self):
        return 2

    model = testing_utils.get_small_mlp(
        num_hidden=10, num_classes=1, input_dim=10)

    model.compile(rmsprop.RMSprop(0.001), 'binary_crossentropy')
    model.fit(CustomSequence(), validation_data=val_data, epochs=2)
    model.evaluate(CustomSequence())
    model.predict(CustomSequence())

    with self.assertRaisesRegex(ValueError, '`y` argument is not supported'):
      model.fit(CustomSequence(), y=np.ones([10, 1]))

    with self.assertRaisesRegex(ValueError,
                                '`sample_weight` argument is not supported'):
      model.fit(CustomSequence(), sample_weight=np.ones([10, 1]))

    model.compile(rmsprop.RMSprop(0.001), 'binary_crossentropy')
    model.fit(CustomSequenceChangingBatchSize(),
              validation_data=val_data, epochs=2)
    model.evaluate(CustomSequenceChangingBatchSize())
    model.predict(CustomSequenceChangingBatchSize())

  @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
  def test_sequence_on_epoch_end(self):
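    """`Sequence.on_epoch_end` is called once at the end of each epoch."""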

    class MySequence(data_utils.Sequence):

      def __init__(self):
        self.epochs = 0

      def __getitem__(self, idx):
        return np.ones([10, 10], np.float32), np.ones([10, 1], np.float32)

      def __len__(self):
        return 2

      def on_epoch_end(self):
        self.epochs += 1

    inputs = input_layer.Input(10)
    outputs = layers_module.Dense(1)(inputs)
    model = training.Model(inputs, outputs)
    model.compile('sgd', 'mse')
    my_seq = MySequence()
    model.fit(my_seq, epochs=2)
    self.assertEqual(my_seq.epochs, 2)


@combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TestConvertToGeneratorLike(test.TestCase, parameterized.TestCase):
  simple_inputs = (np.ones((10, 10)), np.ones((10, 1)))
  nested_inputs = ((np.ones((10, 10)), np.ones((10, 20))),
                   (np.ones((10, 1)), np.ones((10, 3))))

  def _make_dataset(self, inputs, batches):
    return dataset_ops.DatasetV2.from_tensors(inputs).repeat(batches)

  def _make_iterator(self, inputs, batches):
    return dataset_ops.make_one_shot_iterator(
        self._make_dataset(inputs, batches))

  def _make_generator(self, inputs, batches):

    def _gen():
      for _ in range(batches):
        yield inputs

    return _gen()

  def _make_numpy(self, inputs, _):
    return inputs

  @parameterized.named_parameters(
      ('simple_dataset', _make_dataset, simple_inputs),
      ('simple_iterator', _make_iterator, simple_inputs),
      ('simple_generator', _make_generator, simple_inputs),
      ('simple_numpy', _make_numpy, simple_inputs),
      ('nested_dataset', _make_dataset, nested_inputs),
      ('nested_iterator', _make_iterator, nested_inputs),
      ('nested_generator', _make_generator, nested_inputs),
      ('nested_numpy', _make_numpy, nested_inputs))
  def test_convert_to_generator_like(self, input_fn, inputs):
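    """Every supported input type converts to a generator-like object.

    The conversion should report the expected number of steps and yield
    batches with the same structure as the original inputs.
    """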
    expected_batches = 5
    data = input_fn(self, inputs, expected_batches)

    # Dataset and Iterator not supported in Legacy Graph mode.
    if (not context.executing_eagerly() and
        isinstance(data, (dataset_ops.DatasetV2, iterator_ops.Iterator))):
      return

    generator, steps = training_generator_v1.convert_to_generator_like(
        data, batch_size=2, steps_per_epoch=expected_batches)
    self.assertEqual(steps, expected_batches)

    for _ in range(expected_batches):
      outputs = next(generator)
    nest.assert_same_structure(outputs, inputs)


if __name__ == '__main__':
  test.main()