1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utilities for unit-testing Keras.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import functools 23import itertools 24import unittest 25 26from absl.testing import parameterized 27 28from tensorflow.python import keras 29from tensorflow.python import tf2 30from tensorflow.python.eager import context 31from tensorflow.python.keras import testing_utils 32from tensorflow.python.platform import test 33from tensorflow.python.util import nest 34 35 36class TestCase(test.TestCase, parameterized.TestCase): 37 38 def tearDown(self): 39 keras.backend.clear_session() 40 super(TestCase, self).tearDown() 41 42 43# TODO(kaftan): Possibly enable 'subclass_custom_build' when tests begin to pass 44# it. Or perhaps make 'subclass' always use a custom build method. 45def run_with_all_model_types( 46 test_or_class=None, 47 exclude_models=None): 48 """Execute the decorated test with all Keras model types. 49 50 This decorator is intended to be applied either to individual test methods in 51 a `keras_parameterized.TestCase` class, or directly to a test class that 52 extends it. Doing so will cause the contents of the individual test 53 method (or all test methods in the class) to be executed multiple times - once 54 for each Keras model type. 55 56 The Keras model types are: ['functional', 'subclass', 'sequential'] 57 58 Note: if stacking this decorator with absl.testing's parameterized decorators, 59 those should be at the bottom of the stack. 60 61 Various methods in `testing_utils` to get models will auto-generate a model 62 of the currently active Keras model type. This allows unittests to confirm 63 the equivalence between different Keras models. 64 65 For example, consider the following unittest: 66 67 ```python 68 class MyTests(testing_utils.KerasTestCase): 69 70 @testing_utils.run_with_all_model_types( 71 exclude_models = ['sequential']) 72 def test_foo(self): 73 model = testing_utils.get_small_mlp(1, 4, input_dim=3) 74 optimizer = RMSPropOptimizer(learning_rate=0.001) 75 loss = 'mse' 76 metrics = ['mae'] 77 model.compile(optimizer, loss, metrics=metrics) 78 79 inputs = np.zeros((10, 3)) 80 targets = np.zeros((10, 4)) 81 dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) 82 dataset = dataset.repeat(100) 83 dataset = dataset.batch(10) 84 85 model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) 86 87 if __name__ == "__main__": 88 tf.test.main() 89 ``` 90 91 This test tries building a small mlp as both a functional model and as a 92 subclass model. 93 94 We can also annotate the whole class if we want this to apply to all tests in 95 the class: 96 ```python 97 @testing_utils.run_with_all_model_types(exclude_models = ['sequential']) 98 class MyTests(testing_utils.KerasTestCase): 99 100 def test_foo(self): 101 model = testing_utils.get_small_mlp(1, 4, input_dim=3) 102 optimizer = RMSPropOptimizer(learning_rate=0.001) 103 loss = 'mse' 104 metrics = ['mae'] 105 model.compile(optimizer, loss, metrics=metrics) 106 107 inputs = np.zeros((10, 3)) 108 targets = np.zeros((10, 4)) 109 dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) 110 dataset = dataset.repeat(100) 111 dataset = dataset.batch(10) 112 113 model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) 114 115 if __name__ == "__main__": 116 tf.test.main() 117 ``` 118 119 120 Args: 121 test_or_class: test method or class to be annotated. If None, 122 this method returns a decorator that can be applied to a test method or 123 test class. If it is not None this returns the decorator applied to the 124 test or class. 125 exclude_models: A collection of Keras model types to not run. 126 (May also be a single model type not wrapped in a collection). 127 Defaults to None. 128 129 Returns: 130 Returns a decorator that will run the decorated test method multiple times: 131 once for each desired Keras model type. 132 133 Raises: 134 ImportError: If abseil parameterized is not installed or not included as 135 a target dependency. 136 """ 137 model_types = ['functional', 'subclass', 'sequential'] 138 params = [('_%s' % model, model) for model in model_types 139 if model not in nest.flatten(exclude_models)] 140 141 def single_method_decorator(f): 142 """Decorator that constructs the test cases.""" 143 # Use named_parameters so it can be individually run from the command line 144 @parameterized.named_parameters(*params) 145 @functools.wraps(f) 146 def decorated(self, model_type, *args, **kwargs): 147 """A run of a single test case w/ the specified model type.""" 148 if model_type == 'functional': 149 _test_functional_model_type(f, self, *args, **kwargs) 150 elif model_type == 'subclass': 151 _test_subclass_model_type(f, self, *args, **kwargs) 152 elif model_type == 'sequential': 153 _test_sequential_model_type(f, self, *args, **kwargs) 154 else: 155 raise ValueError('Unknown model type: %s' % (model_type,)) 156 return decorated 157 158 return _test_or_class_decorator(test_or_class, single_method_decorator) 159 160 161def _test_functional_model_type(f, test_or_class, *args, **kwargs): 162 with testing_utils.model_type_scope('functional'): 163 f(test_or_class, *args, **kwargs) 164 165 166def _test_subclass_model_type(f, test_or_class, *args, **kwargs): 167 with testing_utils.model_type_scope('subclass'): 168 f(test_or_class, *args, **kwargs) 169 170 171def _test_sequential_model_type(f, test_or_class, *args, **kwargs): 172 with testing_utils.model_type_scope('sequential'): 173 f(test_or_class, *args, **kwargs) 174 175 176def run_all_keras_modes( 177 test_or_class=None, 178 config=None, 179 always_skip_v1=False): 180 """Execute the decorated test with all keras execution modes. 181 182 This decorator is intended to be applied either to individual test methods in 183 a `keras_parameterized.TestCase` class, or directly to a test class that 184 extends it. Doing so will cause the contents of the individual test 185 method (or all test methods in the class) to be executed multiple times - 186 once executing in legacy graph mode, once running eagerly and with 187 `should_run_eagerly` returning True, and once running eagerly with 188 `should_run_eagerly` returning False. 189 190 If Tensorflow v2 behavior is enabled, legacy graph mode will be skipped, and 191 the test will only run twice. 192 193 Note: if stacking this decorator with absl.testing's parameterized decorators, 194 those should be at the bottom of the stack. 195 196 For example, consider the following unittest: 197 198 ```python 199 class MyTests(testing_utils.KerasTestCase): 200 201 @testing_utils.run_all_keras_modes 202 def test_foo(self): 203 model = testing_utils.get_small_functional_mlp(1, 4, input_dim=3) 204 optimizer = RMSPropOptimizer(learning_rate=0.001) 205 loss = 'mse' 206 metrics = ['mae'] 207 model.compile(optimizer, loss, metrics=metrics, 208 run_eagerly=testing_utils.should_run_eagerly()) 209 210 inputs = np.zeros((10, 3)) 211 targets = np.zeros((10, 4)) 212 dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)) 213 dataset = dataset.repeat(100) 214 dataset = dataset.batch(10) 215 216 model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1) 217 218 if __name__ == "__main__": 219 tf.test.main() 220 ``` 221 222 This test will try compiling & fitting the small functional mlp using all 223 three Keras execution modes. 224 225 Args: 226 test_or_class: test method or class to be annotated. If None, 227 this method returns a decorator that can be applied to a test method or 228 test class. If it is not None this returns the decorator applied to the 229 test or class. 230 config: An optional config_pb2.ConfigProto to use to configure the 231 session when executing graphs. 232 always_skip_v1: If True, does not try running the legacy graph mode even 233 when Tensorflow v2 behavior is not enabled. 234 235 Returns: 236 Returns a decorator that will run the decorated test method multiple times. 237 238 Raises: 239 ImportError: If abseil parameterized is not installed or not included as 240 a target dependency. 241 """ 242 params = [('_v2_eager', 'v2_eager'), 243 ('_v2_function', 'v2_function')] 244 if not (always_skip_v1 or tf2.enabled()): 245 params.append(('_v1_graph', 'v1_graph')) 246 247 def single_method_decorator(f): 248 """Decorator that constructs the test cases.""" 249 250 # Use named_parameters so it can be individually run from the command line 251 @parameterized.named_parameters(*params) 252 @functools.wraps(f) 253 def decorated(self, run_mode, *args, **kwargs): 254 """A run of a single test case w/ specified run mode.""" 255 if run_mode == 'v1_graph': 256 _v1_graph_test(f, self, config, *args, **kwargs) 257 elif run_mode == 'v2_function': 258 _v2_graph_functions_test(f, self, *args, **kwargs) 259 elif run_mode == 'v2_eager': 260 _v2_eager_test(f, self, *args, **kwargs) 261 else: 262 return ValueError('Unknown run mode %s' % run_mode) 263 264 return decorated 265 266 return _test_or_class_decorator(test_or_class, single_method_decorator) 267 268 269def _v1_graph_test(f, test_or_class, config, *args, **kwargs): 270 with context.graph_mode(), testing_utils.run_eagerly_scope(False): 271 with test_or_class.test_session(use_gpu=True, config=config): 272 f(test_or_class, *args, **kwargs) 273 274 275def _v2_graph_functions_test(f, test_or_class, *args, **kwargs): 276 with context.eager_mode(): 277 with testing_utils.run_eagerly_scope(False): 278 f(test_or_class, *args, **kwargs) 279 280 281def _v2_eager_test(f, test_or_class, *args, **kwargs): 282 with context.eager_mode(): 283 with testing_utils.run_eagerly_scope(True): 284 f(test_or_class, *args, **kwargs) 285 286 287def _test_or_class_decorator(test_or_class, single_method_decorator): 288 """Decorate a test or class with a decorator intended for one method. 289 290 If the test_or_class is a class: 291 This will apply the decorator to all test methods in the class. 292 293 If the test_or_class is an iterable of already-parameterized test cases: 294 This will apply the decorator to all the cases, and then flatten the 295 resulting cross-product of test cases. This allows stacking the Keras 296 parameterized decorators w/ each other, and to apply them to test methods 297 that have already been marked with an absl parameterized decorator. 298 299 Otherwise, treat the obj as a single method and apply the decorator directly. 300 301 Args: 302 test_or_class: A test method (that may have already been decorated with a 303 parameterized decorator, or a test class that extends 304 keras_parameterized.TestCase 305 single_method_decorator: 306 A parameterized decorator intended for a single test method. 307 Returns: 308 The decorated result. 309 """ 310 def _decorate_test_or_class(obj): 311 if isinstance(obj, collections.Iterable): 312 return itertools.chain.from_iterable( 313 single_method_decorator(method) for method in obj) 314 if isinstance(obj, type): 315 cls = obj 316 for name, value in cls.__dict__.copy().items(): 317 if callable(value) and name.startswith( 318 unittest.TestLoader.testMethodPrefix): 319 setattr(cls, name, single_method_decorator(value)) 320 321 cls = type(cls).__new__(type(cls), cls.__name__, cls.__bases__, 322 cls.__dict__.copy()) 323 return cls 324 325 return single_method_decorator(obj) 326 327 if test_or_class is not None: 328 return _decorate_test_or_class(test_or_class) 329 330 return _decorate_test_or_class 331