1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for training routines.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22import unittest 23 24from absl.testing import parameterized 25import numpy as np 26 27from tensorflow.python import keras 28from tensorflow.python.data.ops import dataset_ops 29from tensorflow.python.data.ops import iterator_ops 30from tensorflow.python.eager import context 31from tensorflow.python.framework import test_util as tf_test_util 32from tensorflow.python.keras import keras_parameterized 33from tensorflow.python.keras import metrics as metrics_module 34from tensorflow.python.keras import testing_utils 35from tensorflow.python.keras.engine import training_generator 36from tensorflow.python.keras.optimizer_v2 import rmsprop 37from tensorflow.python.platform import test 38from tensorflow.python.util import nest 39 40 41def custom_generator(mode=2): 42 batch_size = 10 43 num_samples = 50 44 arr_data = np.random.random((num_samples, 2)) 45 arr_labels = np.random.random((num_samples, 4)) 46 arr_weights = np.random.random((num_samples,)) 47 i = 0 48 while True: 49 batch_index = i * batch_size % num_samples 50 i += 1 51 start = batch_index 52 end = start + batch_size 53 x = arr_data[start: end] 54 y = arr_labels[start: end] 55 w = arr_weights[start: end] 56 if mode == 1: 57 yield x 58 elif mode == 2: 59 yield x, y 60 else: 61 yield x, y, w 62 63 64class TestGeneratorMethods(keras_parameterized.TestCase): 65 66 @unittest.skipIf( 67 os.name == 'nt', 68 'use_multiprocessing=True does not work on windows properly.') 69 @keras_parameterized.run_with_all_model_types 70 @keras_parameterized.run_all_keras_modes 71 def test_fit_generator_method(self): 72 model = testing_utils.get_small_mlp( 73 num_hidden=3, num_classes=4, input_dim=2) 74 model.compile( 75 loss='mse', 76 optimizer=rmsprop.RMSprop(1e-3), 77 metrics=['mae', metrics_module.CategoricalAccuracy()]) 78 79 model.fit_generator(custom_generator(), 80 steps_per_epoch=5, 81 epochs=1, 82 verbose=1, 83 max_queue_size=10, 84 workers=4, 85 use_multiprocessing=True) 86 model.fit_generator(custom_generator(), 87 steps_per_epoch=5, 88 epochs=1, 89 verbose=1, 90 max_queue_size=10, 91 use_multiprocessing=False) 92 model.fit_generator(custom_generator(), 93 steps_per_epoch=5, 94 epochs=1, 95 verbose=1, 96 max_queue_size=10, 97 use_multiprocessing=False, 98 validation_data=custom_generator(), 99 validation_steps=10) 100 model.fit_generator(custom_generator(), 101 steps_per_epoch=5, 102 validation_data=custom_generator(), 103 validation_steps=1, 104 workers=0) 105 106 @unittest.skipIf( 107 os.name == 'nt', 108 'use_multiprocessing=True does not work on windows properly.') 109 @keras_parameterized.run_with_all_model_types 110 @keras_parameterized.run_all_keras_modes 111 def test_evaluate_generator_method(self): 112 model = testing_utils.get_small_mlp( 113 num_hidden=3, num_classes=4, input_dim=2) 114 model.compile( 115 loss='mse', 116 optimizer=rmsprop.RMSprop(1e-3), 117 metrics=['mae', metrics_module.CategoricalAccuracy()], 118 run_eagerly=testing_utils.should_run_eagerly()) 119 120 model.evaluate_generator(custom_generator(), 121 steps=5, 122 max_queue_size=10, 123 workers=2, 124 verbose=1, 125 use_multiprocessing=True) 126 model.evaluate_generator(custom_generator(), 127 steps=5, 128 max_queue_size=10, 129 use_multiprocessing=False) 130 model.evaluate_generator(custom_generator(), 131 steps=5, 132 max_queue_size=10, 133 use_multiprocessing=False, 134 workers=0) 135 136 @unittest.skipIf( 137 os.name == 'nt', 138 'use_multiprocessing=True does not work on windows properly.') 139 @keras_parameterized.run_with_all_model_types 140 @keras_parameterized.run_all_keras_modes 141 def test_predict_generator_method(self): 142 model = testing_utils.get_small_mlp( 143 num_hidden=3, num_classes=4, input_dim=2) 144 model.run_eagerly = testing_utils.should_run_eagerly() 145 146 model.predict_generator(custom_generator(), 147 steps=5, 148 max_queue_size=10, 149 workers=2, 150 use_multiprocessing=True) 151 model.predict_generator(custom_generator(), 152 steps=5, 153 max_queue_size=10, 154 use_multiprocessing=False) 155 model.predict_generator(custom_generator(), 156 steps=5, 157 max_queue_size=10, 158 workers=0) 159 # Test generator with just inputs (no targets) 160 model.predict_generator(custom_generator(mode=1), 161 steps=5, 162 max_queue_size=10, 163 workers=2, 164 use_multiprocessing=True) 165 model.predict_generator(custom_generator(mode=1), 166 steps=5, 167 max_queue_size=10, 168 use_multiprocessing=False) 169 model.predict_generator(custom_generator(mode=1), 170 steps=5, 171 max_queue_size=10, 172 workers=0) 173 174 @keras_parameterized.run_with_all_model_types 175 @keras_parameterized.run_all_keras_modes 176 def test_generator_methods_with_sample_weights(self): 177 model = testing_utils.get_small_mlp( 178 num_hidden=3, num_classes=4, input_dim=2) 179 model.compile( 180 loss='mse', 181 optimizer=rmsprop.RMSprop(1e-3), 182 metrics=['mae', metrics_module.CategoricalAccuracy()], 183 run_eagerly=testing_utils.should_run_eagerly()) 184 185 model.fit_generator(custom_generator(mode=3), 186 steps_per_epoch=5, 187 epochs=1, 188 verbose=1, 189 max_queue_size=10, 190 use_multiprocessing=False) 191 model.fit_generator(custom_generator(mode=3), 192 steps_per_epoch=5, 193 epochs=1, 194 verbose=1, 195 max_queue_size=10, 196 use_multiprocessing=False, 197 validation_data=custom_generator(mode=3), 198 validation_steps=10) 199 model.predict_generator(custom_generator(mode=3), 200 steps=5, 201 max_queue_size=10, 202 use_multiprocessing=False) 203 model.evaluate_generator(custom_generator(mode=3), 204 steps=5, 205 max_queue_size=10, 206 use_multiprocessing=False) 207 208 @keras_parameterized.run_with_all_model_types 209 @keras_parameterized.run_all_keras_modes 210 def test_generator_methods_invalid_use_case(self): 211 212 def invalid_generator(): 213 while 1: 214 yield 0 215 216 model = testing_utils.get_small_mlp( 217 num_hidden=3, num_classes=4, input_dim=2) 218 model.compile(loss='mse', optimizer=rmsprop.RMSprop(1e-3), 219 run_eagerly=testing_utils.should_run_eagerly()) 220 221 with self.assertRaises(ValueError): 222 model.fit_generator(invalid_generator(), 223 steps_per_epoch=5, 224 epochs=1, 225 verbose=1, 226 max_queue_size=10, 227 use_multiprocessing=False) 228 with self.assertRaises(ValueError): 229 model.fit_generator(custom_generator(), 230 steps_per_epoch=5, 231 epochs=1, 232 verbose=1, 233 max_queue_size=10, 234 use_multiprocessing=False, 235 validation_data=invalid_generator(), 236 validation_steps=10) 237 with self.assertRaises(AttributeError): 238 model.predict_generator(invalid_generator(), 239 steps=5, 240 max_queue_size=10, 241 use_multiprocessing=False) 242 with self.assertRaises(ValueError): 243 model.evaluate_generator(invalid_generator(), 244 steps=5, 245 max_queue_size=10, 246 use_multiprocessing=False) 247 248 @keras_parameterized.run_with_all_model_types 249 @keras_parameterized.run_all_keras_modes 250 def test_generator_input_to_fit_eval_predict(self): 251 val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32) 252 253 def ones_generator(): 254 while True: 255 yield np.ones([10, 10], np.float32), np.ones([10, 1], np.float32) 256 257 model = testing_utils.get_small_mlp( 258 num_hidden=10, num_classes=1, input_dim=10) 259 260 model.compile(rmsprop.RMSprop(0.001), 'binary_crossentropy', 261 run_eagerly=testing_utils.should_run_eagerly()) 262 model.fit( 263 ones_generator(), 264 steps_per_epoch=2, 265 validation_data=val_data, 266 epochs=2) 267 model.evaluate(ones_generator(), steps=2) 268 model.predict(ones_generator(), steps=2) 269 270 271class TestGeneratorMethodsWithSequences(keras_parameterized.TestCase): 272 273 @keras_parameterized.run_with_all_model_types 274 @keras_parameterized.run_all_keras_modes 275 def test_training_with_sequences(self): 276 277 class DummySequence(keras.utils.Sequence): 278 279 def __getitem__(self, idx): 280 return np.zeros([10, 2]), np.ones([10, 4]) 281 282 def __len__(self): 283 return 10 284 285 model = testing_utils.get_small_mlp( 286 num_hidden=3, num_classes=4, input_dim=2) 287 model.compile(loss='mse', optimizer=rmsprop.RMSprop(1e-3)) 288 289 model.fit_generator(DummySequence(), 290 steps_per_epoch=10, 291 validation_data=custom_generator(), 292 validation_steps=1, 293 max_queue_size=10, 294 workers=0, 295 use_multiprocessing=True) 296 model.fit_generator(DummySequence(), 297 steps_per_epoch=10, 298 validation_data=custom_generator(), 299 validation_steps=1, 300 max_queue_size=10, 301 workers=0, 302 use_multiprocessing=False) 303 304 @keras_parameterized.run_with_all_model_types 305 @keras_parameterized.run_all_keras_modes 306 def test_sequence_input_to_fit_eval_predict(self): 307 val_data = np.ones([10, 10], np.float32), np.ones([10, 1], np.float32) 308 309 class CustomSequence(keras.utils.Sequence): 310 311 def __getitem__(self, idx): 312 return np.ones([10, 10], np.float32), np.ones([10, 1], np.float32) 313 314 def __len__(self): 315 return 2 316 317 model = testing_utils.get_small_mlp( 318 num_hidden=10, num_classes=1, input_dim=10) 319 320 model.compile(rmsprop.RMSprop(0.001), 'binary_crossentropy') 321 model.fit(CustomSequence(), validation_data=val_data, epochs=2) 322 model.evaluate(CustomSequence()) 323 model.predict(CustomSequence()) 324 325 with self.assertRaisesRegexp(ValueError, '`y` argument is not supported'): 326 model.fit(CustomSequence(), y=np.ones([10, 1])) 327 328 with self.assertRaisesRegexp(ValueError, 329 '`sample_weight` argument is not supported'): 330 model.fit(CustomSequence(), sample_weight=np.ones([10, 1])) 331 332 333@tf_test_util.run_all_in_graph_and_eager_modes 334class TestConvertToGeneratorLike(test.TestCase, parameterized.TestCase): 335 simple_inputs = (np.ones((10, 10)), np.ones((10, 1))) 336 nested_inputs = ((np.ones((10, 10)), np.ones((10, 20))), (np.ones((10, 1)), 337 np.ones((10, 3)))) 338 339 def _make_dataset(self, inputs, batches): 340 return dataset_ops.DatasetV2.from_tensors(inputs).repeat(batches) 341 342 def _make_iterator(self, inputs, batches): 343 return dataset_ops.make_one_shot_iterator( 344 self._make_dataset(inputs, batches)) 345 346 def _make_generator(self, inputs, batches): 347 348 def _gen(): 349 for _ in range(batches): 350 yield inputs 351 352 return _gen() 353 354 def _make_numpy(self, inputs, _): 355 return inputs 356 357 @parameterized.named_parameters( 358 ('simple_dataset', _make_dataset, simple_inputs), 359 ('simple_iterator', _make_iterator, simple_inputs), 360 ('simple_generator', _make_generator, simple_inputs), 361 ('simple_numpy', _make_numpy, simple_inputs), 362 ('nested_dataset', _make_dataset, nested_inputs), 363 ('nested_iterator', _make_iterator, nested_inputs), 364 ('nested_generator', _make_generator, nested_inputs), 365 ('nested_numpy', _make_numpy, nested_inputs)) 366 def test_convert_to_generator_like(self, input_fn, inputs): 367 expected_batches = 5 368 data = input_fn(self, inputs, expected_batches) 369 370 # Dataset and Iterator not supported in Legacy Graph mode. 371 if (not context.executing_eagerly() and 372 isinstance(data, (dataset_ops.DatasetV2, iterator_ops.Iterator))): 373 return 374 375 generator, steps = training_generator.convert_to_generator_like( 376 data, batch_size=2, steps_per_epoch=expected_batches) 377 self.assertEqual(steps, expected_batches) 378 379 for _ in range(expected_batches): 380 outputs = next(generator) 381 nest.assert_same_structure(outputs, inputs) 382 383 384if __name__ == '__main__': 385 test.main() 386