# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for unit-testing Keras."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections.abc as collections_abc
import functools
import itertools
import unittest

from absl.testing import parameterized

from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
from tensorflow.python.util import nest

try:
  import h5py  # pylint:disable=g-import-not-at-top
except ImportError:
  h5py = None


class TestCase(test.TestCase, parameterized.TestCase):

  def tearDown(self):
    keras.backend.clear_session()
    super(TestCase, self).tearDown()


def run_with_all_saved_model_formats(
    test_or_class=None,
    exclude_formats=None):
  """Execute the decorated test with all Keras saved model formats.

  This decorator is intended to be applied either to individual test methods in
  a `keras_parameterized.TestCase` class, or directly to a test class that
  extends it. Doing so will cause the contents of the individual test
  method (or all test methods in the class) to be executed multiple times -
  once for each Keras saved model format.

  The Keras saved model formats include:
  1. HDF5: 'h5'
  2. SavedModel: 'tf'

  Note: if stacking this decorator with absl.testing's parameterized decorators,
  those should be at the bottom of the stack.

  Various methods in `testing_utils` to get the file path for saved models will
  auto-generate a string of the two saved model formats. This allows unittests
  to confirm the equivalence between the two Keras saved model formats.

  For example, consider the following unittest:

  ```python
  class MyTests(keras_parameterized.TestCase):

    @keras_parameterized.run_with_all_saved_model_formats
    def test_foo(self):
      save_format = testing_utils.get_save_format()
      saved_model_dir = '/tmp/saved_model/'
      model = keras.models.Sequential()
      model.add(keras.layers.Dense(2, input_shape=(3,)))
      model.add(keras.layers.Dense(3))
      model.compile(loss='mse', optimizer='sgd', metrics=['acc'])

      keras.models.save_model(model, saved_model_dir, save_format=save_format)
      model = keras.models.load_model(saved_model_dir)

  if __name__ == "__main__":
    tf.test.main()
  ```

  This test saves and reloads the model in both the 'h5' and 'tf' (SavedModel)
  formats.
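
  To run the test against only a subset of formats, `exclude_formats` can be
  passed when the decorator is called (a minimal sketch; the test body is
  elided):

  ```python
  class MyTests(keras_parameterized.TestCase):

    @keras_parameterized.run_with_all_saved_model_formats(
        exclude_formats=['h5'])
    def test_foo(self):
      ...
  ```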

  We can also annotate the whole class if we want this to apply to all tests in
  the class:

  ```python
  @keras_parameterized.run_with_all_saved_model_formats
  class MyTests(keras_parameterized.TestCase):

    def test_foo(self):
      save_format = testing_utils.get_save_format()
      saved_model_dir = '/tmp/saved_model/'
      model = keras.models.Sequential()
      model.add(keras.layers.Dense(2, input_shape=(3,)))
      model.add(keras.layers.Dense(3))
      model.compile(loss='mse', optimizer='sgd', metrics=['acc'])

      keras.models.save_model(model, saved_model_dir, save_format=save_format)
      model = keras.models.load_model(saved_model_dir)

  if __name__ == "__main__":
    tf.test.main()
  ```

  Args:
    test_or_class: test method or class to be annotated. If None,
      this method returns a decorator that can be applied to a test method or
      test class. If it is not None this returns the decorator applied to the
      test or class.
    exclude_formats: A collection of Keras saved model formats to not run.
      (May also be a single format not wrapped in a collection).
      Defaults to None.

  Returns:
    Returns a decorator that will run the decorated test method multiple times:
    once for each desired Keras saved model format.

  Raises:
    ImportError: If abseil parameterized is not installed or not included as
      a target dependency.
  """
  if exclude_formats is None:
    exclude_formats = []
  # Exclude the 'h5' format if h5py isn't available.
  if h5py is None:
    exclude_formats.append('h5')
  saved_model_formats = ['h5', 'tf', 'tf_no_traces']
  params = [('_%s' % saved_format, saved_format)
            for saved_format in saved_model_formats
            if saved_format not in nest.flatten(exclude_formats)]

  def single_method_decorator(f):
    """Decorator that constructs the test cases."""
    # Use named_parameters so it can be individually run from the command line
    @parameterized.named_parameters(*params)
    @functools.wraps(f)
    def decorated(self, saved_format, *args, **kwargs):
      """A run of a single test case w/ the specified saved model format."""
      if saved_format == 'h5':
        _test_h5_saved_model_format(f, self, *args, **kwargs)
      elif saved_format == 'tf':
        _test_tf_saved_model_format(f, self, *args, **kwargs)
      elif saved_format == 'tf_no_traces':
        _test_tf_saved_model_format_no_traces(f, self, *args, **kwargs)
      else:
        raise ValueError('Unknown saved model format: %s' % (saved_format,))
    return decorated

  return _test_or_class_decorator(test_or_class, single_method_decorator)


def _test_h5_saved_model_format(f, test_or_class, *args, **kwargs):
  with testing_utils.saved_model_format_scope('h5'):
    f(test_or_class, *args, **kwargs)


def _test_tf_saved_model_format(f, test_or_class, *args, **kwargs):
  with testing_utils.saved_model_format_scope('tf'):
    f(test_or_class, *args, **kwargs)


def _test_tf_saved_model_format_no_traces(f, test_or_class, *args, **kwargs):
  with testing_utils.saved_model_format_scope('tf', save_traces=False):
    f(test_or_class, *args, **kwargs)


def run_with_all_weight_formats(test_or_class=None, exclude_formats=None):
  """Runs all tests with the supported formats for saving weights."""
  exclude_formats = exclude_formats or []
  exclude_formats.append('tf_no_traces')  # Only applies to saving models.
  return run_with_all_saved_model_formats(test_or_class, exclude_formats)
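

# Illustrative usage of `run_with_all_weight_formats` (a commented-out sketch;
# the test class, method name, and file path below are hypothetical): the
# decorated test runs once per weight-saving format, with
# `testing_utils.get_save_format()` reporting the active one.
#
#   class MyWeightTests(keras_parameterized.TestCase):
#
#     @keras_parameterized.run_with_all_weight_formats
#     def test_weights_roundtrip(self):
#       save_format = testing_utils.get_save_format()
#       model = keras.models.Sequential(
#           [keras.layers.Dense(2, input_shape=(3,))])
#       weights_path = self.get_temp_dir() + '/weights'
#       model.save_weights(weights_path, save_format=save_format)
#       model.load_weights(weights_path)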


# TODO(kaftan): Possibly enable 'subclass_custom_build' when tests begin to pass
# it. Or perhaps make 'subclass' always use a custom build method.
def run_with_all_model_types(
    test_or_class=None,
    exclude_models=None):
  """Execute the decorated test with all Keras model types.

  This decorator is intended to be applied either to individual test methods in
  a `keras_parameterized.TestCase` class, or directly to a test class that
  extends it. Doing so will cause the contents of the individual test
  method (or all test methods in the class) to be executed multiple times -
  once for each Keras model type.

  The Keras model types are: ['functional', 'subclass', 'sequential']

  Note: if stacking this decorator with absl.testing's parameterized decorators,
  those should be at the bottom of the stack.

  Various methods in `testing_utils` to get models will auto-generate a model
  of the currently active Keras model type. This allows unittests to confirm
  the equivalence between different Keras models.

  For example, consider the following unittest:

  ```python
  class MyTests(keras_parameterized.TestCase):

    @keras_parameterized.run_with_all_model_types(
        exclude_models=['sequential'])
    def test_foo(self):
      model = testing_utils.get_small_mlp(1, 4, input_dim=3)
      optimizer = RMSPropOptimizer(learning_rate=0.001)
      loss = 'mse'
      metrics = ['mae']
      model.compile(optimizer, loss, metrics=metrics)

      inputs = np.zeros((10, 3))
      targets = np.zeros((10, 4))
      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
      dataset = dataset.repeat(100)
      dataset = dataset.batch(10)

      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)

  if __name__ == "__main__":
    tf.test.main()
  ```

  This test tries building a small MLP as both a functional model and a
  subclass model.

  We can also annotate the whole class if we want this to apply to all tests in
  the class:

  ```python
  @keras_parameterized.run_with_all_model_types(exclude_models=['sequential'])
  class MyTests(keras_parameterized.TestCase):

    def test_foo(self):
      model = testing_utils.get_small_mlp(1, 4, input_dim=3)
      optimizer = RMSPropOptimizer(learning_rate=0.001)
      loss = 'mse'
      metrics = ['mae']
      model.compile(optimizer, loss, metrics=metrics)

      inputs = np.zeros((10, 3))
      targets = np.zeros((10, 4))
      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
      dataset = dataset.repeat(100)
      dataset = dataset.batch(10)

      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)

  if __name__ == "__main__":
    tf.test.main()
  ```

  Args:
    test_or_class: test method or class to be annotated. If None,
      this method returns a decorator that can be applied to a test method or
      test class. If it is not None this returns the decorator applied to the
      test or class.
    exclude_models: A collection of Keras model types to not run.
      (May also be a single model type not wrapped in a collection).
      Defaults to None.

  Returns:
    Returns a decorator that will run the decorated test method multiple times:
    once for each desired Keras model type.

  Raises:
    ImportError: If abseil parameterized is not installed or not included as
      a target dependency.
  """
  model_types = ['functional', 'subclass', 'sequential']
  params = [('_%s' % model, model) for model in model_types
            if model not in nest.flatten(exclude_models)]

  def single_method_decorator(f):
    """Decorator that constructs the test cases."""
    # Use named_parameters so it can be individually run from the command line
    @parameterized.named_parameters(*params)
    @functools.wraps(f)
    def decorated(self, model_type, *args, **kwargs):
      """A run of a single test case w/ the specified model type."""
      if model_type == 'functional':
        _test_functional_model_type(f, self, *args, **kwargs)
      elif model_type == 'subclass':
        _test_subclass_model_type(f, self, *args, **kwargs)
      elif model_type == 'sequential':
        _test_sequential_model_type(f, self, *args, **kwargs)
      else:
        raise ValueError('Unknown model type: %s' % (model_type,))
    return decorated

  return _test_or_class_decorator(test_or_class, single_method_decorator)


def _test_functional_model_type(f, test_or_class, *args, **kwargs):
  with testing_utils.model_type_scope('functional'):
    f(test_or_class, *args, **kwargs)


def _test_subclass_model_type(f, test_or_class, *args, **kwargs):
  with testing_utils.model_type_scope('subclass'):
    f(test_or_class, *args, **kwargs)


def _test_sequential_model_type(f, test_or_class, *args, **kwargs):
  with testing_utils.model_type_scope('sequential'):
    f(test_or_class, *args, **kwargs)


def run_all_keras_modes(test_or_class=None,
                        config=None,
                        always_skip_v1=False,
                        always_skip_eager=False,
                        **kwargs):
  """Execute the decorated test with all Keras execution modes.

  This decorator is intended to be applied either to individual test methods in
  a `keras_parameterized.TestCase` class, or directly to a test class that
  extends it. Doing so will cause the contents of the individual test
  method (or all test methods in the class) to be executed multiple times -
  once executing in legacy graph mode, once running eagerly and with
  `should_run_eagerly` returning True, and once running eagerly with
  `should_run_eagerly` returning False.

  If TensorFlow v2 behavior is enabled, legacy graph mode will be skipped, and
  the test will only run twice.

  Note: if stacking this decorator with absl.testing's parameterized decorators,
  those should be at the bottom of the stack.

  For example, consider the following unittest:

  ```python
  class MyTests(keras_parameterized.TestCase):

    @keras_parameterized.run_all_keras_modes
    def test_foo(self):
      model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
      optimizer = RMSPropOptimizer(learning_rate=0.001)
      loss = 'mse'
      metrics = ['mae']
      model.compile(
          optimizer, loss, metrics=metrics,
          run_eagerly=testing_utils.should_run_eagerly())

      inputs = np.zeros((10, 3))
      targets = np.zeros((10, 4))
      dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
      dataset = dataset.repeat(100)
      dataset = dataset.batch(10)

      model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)

  if __name__ == "__main__":
    tf.test.main()
  ```

  This test will try compiling & fitting the small functional MLP using all
  three Keras execution modes.

  Args:
    test_or_class: test method or class to be annotated. If None,
      this method returns a decorator that can be applied to a test method or
      test class. If it is not None this returns the decorator applied to the
      test or class.
    config: An optional config_pb2.ConfigProto to use to configure the
      session when executing graphs.
    always_skip_v1: If True, does not try running the legacy graph mode even
      when TensorFlow v2 behavior is not enabled.
    always_skip_eager: If True, does not execute the decorated test
      with eager execution modes.
    **kwargs: Additional kwargs for configuring tests for in-progress Keras
      behaviors/refactorings that we haven't fully rolled out yet.

  Returns:
    Returns a decorator that will run the decorated test method multiple times.

  Raises:
    ImportError: If abseil parameterized is not installed or not included as
      a target dependency.
  """
  if kwargs:
    raise ValueError('Unrecognized keyword args: {}'.format(kwargs))

  params = [('_v2_function', 'v2_function')]
  if not always_skip_eager:
    params.append(('_v2_eager', 'v2_eager'))
  if not (always_skip_v1 or tf2.enabled()):
    params.append(('_v1_session', 'v1_session'))

  def single_method_decorator(f):
    """Decorator that constructs the test cases."""

    # Use named_parameters so it can be individually run from the command line
    @parameterized.named_parameters(*params)
    @functools.wraps(f)
    def decorated(self, run_mode, *args, **kwargs):
      """A run of a single test case w/ the specified run mode."""
      if run_mode == 'v1_session':
        _v1_session_test(f, self, config, *args, **kwargs)
      elif run_mode == 'v2_eager':
        _v2_eager_test(f, self, *args, **kwargs)
      elif run_mode == 'v2_function':
        _v2_function_test(f, self, *args, **kwargs)
      else:
        raise ValueError('Unknown run mode %s' % run_mode)

    return decorated

  return _test_or_class_decorator(test_or_class, single_method_decorator)


def _v1_session_test(f, test_or_class, config, *args, **kwargs):
  with ops.get_default_graph().as_default():
    with testing_utils.run_eagerly_scope(False):
      with test_or_class.test_session(config=config):
        f(test_or_class, *args, **kwargs)


def _v2_eager_test(f, test_or_class, *args, **kwargs):
  with context.eager_mode():
    with testing_utils.run_eagerly_scope(True):
      f(test_or_class, *args, **kwargs)


def _v2_function_test(f, test_or_class, *args, **kwargs):
  with context.eager_mode():
    with testing_utils.run_eagerly_scope(False):
      f(test_or_class, *args, **kwargs)
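

# Illustrative sketch (commented out; the test class below is hypothetical):
# `run_all_keras_modes` can also be invoked with arguments, e.g. to skip the
# legacy graph mode even when TensorFlow v2 behavior is not enabled:
#
#   class MyEagerAndFunctionTests(keras_parameterized.TestCase):
#
#     @keras_parameterized.run_all_keras_modes(always_skip_v1=True)
#     def test_foo(self):
#       model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3)
#       model.compile(
#           'sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())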


def _test_or_class_decorator(test_or_class, single_method_decorator):
  """Decorate a test or class with a decorator intended for one method.

  If the test_or_class is a class:
    This will apply the decorator to all test methods in the class.

  If the test_or_class is an iterable of already-parameterized test cases:
    This will apply the decorator to all the cases, and then flatten the
    resulting cross-product of test cases. This allows stacking the Keras
    parameterized decorators w/ each other, and to apply them to test methods
    that have already been marked with an absl parameterized decorator.

  Otherwise, treat the object as a single method and apply the decorator
  directly.

  Args:
    test_or_class: A test method (that may have already been decorated with a
      parameterized decorator), or a test class that extends
      keras_parameterized.TestCase.
    single_method_decorator:
      A parameterized decorator intended for a single test method.
  Returns:
    The decorated result.
  """
  def _decorate_test_or_class(obj):
    if isinstance(obj, collections_abc.Iterable):
      return itertools.chain.from_iterable(
          single_method_decorator(method) for method in obj)
    if isinstance(obj, type):
      cls = obj
      for name, value in cls.__dict__.copy().items():
        if callable(value) and name.startswith(
            unittest.TestLoader.testMethodPrefix):
          setattr(cls, name, single_method_decorator(value))

      cls = type(cls).__new__(type(cls), cls.__name__, cls.__bases__,
                              cls.__dict__.copy())
      return cls

    return single_method_decorator(obj)

  if test_or_class is not None:
    return _decorate_test_or_class(test_or_class)

  return _decorate_test_or_class
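

# Illustrative note (commented out; `StackedTests` is a hypothetical test
# class): because `_test_or_class_decorator` also accepts an iterable of
# already-parameterized test cases, the decorators in this module can be
# stacked to produce the cross-product of cases, e.g. every model type
# combined with every execution mode:
#
#   class StackedTests(keras_parameterized.TestCase):
#
#     @keras_parameterized.run_with_all_model_types
#     @keras_parameterized.run_all_keras_modes
#     def test_foo(self):
#       model = testing_utils.get_small_mlp(1, 4, input_dim=3)
#       model.compile(
#           'sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())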