1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Facilities for creating multiple test combinations. 16 17Here is a simple example for testing various optimizers in Eager and Graph: 18 19class AdditionExample(test.TestCase, parameterized.TestCase): 20 @combinations.generate( 21 combinations.combine(mode=["graph", "eager"], 22 optimizer=[AdamOptimizer(), 23 GradientDescentOptimizer()])) 24 def testOptimizer(self, optimizer): 25 ... f(optimizer)... 26 27This will run `testOptimizer` 4 times with the specified optimizers: 2 in 28Eager and 2 in Graph mode. 29The test is going to accept the same parameters as the ones used in `combine()`. 30The parameters need to match by name between the `combine()` call and the test 31signature. It is necessary to accept all parameters. See `OptionalParameter` 32for a way to implement optional parameters. 33 34`combine()` function is available for creating a cross product of various 35options. `times()` function exists for creating a product of N `combine()`-ed 36results. 37 38The execution of generated tests can be customized in a number of ways: 39- The test can be skipped if it is not running in the correct environment. 40- The arguments that are passed to the test can be additionally transformed. 41- The test can be run with specific Python context managers. 42These behaviors can be customized by providing instances of `TestCombination` to 43`generate()`. 44""" 45 46from __future__ import absolute_import 47from __future__ import division 48from __future__ import print_function 49 50from collections import OrderedDict 51import contextlib 52import re 53import types 54import unittest 55 56from absl.testing import parameterized 57import six 58 59from tensorflow.python.util import tf_inspect 60from tensorflow.python.util.tf_export import tf_export 61 62 63@tf_export("__internal__.test.combinations.TestCombination", v1=[]) 64class TestCombination(object): 65 """Customize the behavior of `generate()` and the tests that it executes. 66 67 Here is sequence of steps for executing a test combination: 68 1. The test combination is evaluated for whether it should be executed in 69 the given environment by calling `should_execute_combination`. 70 2. If the test combination is going to be executed, then the arguments for 71 all combined parameters are validated. Some arguments can be handled in 72 a special way. This is achieved by implementing that logic in 73 `ParameterModifier` instances that returned from `parameter_modifiers`. 74 3. Before executing the test, `context_managers` are installed 75 around it. 76 """ 77 78 def should_execute_combination(self, kwargs): 79 """Indicates whether the combination of test arguments should be executed. 80 81 If the environment doesn't satisfy the dependencies of the test 82 combination, then it can be skipped. 83 84 Args: 85 kwargs: Arguments that are passed to the test combination. 86 87 Returns: 88 A tuple boolean and an optional string. The boolean False indicates 89 that the test should be skipped. The string would indicate a textual 90 description of the reason. If the test is going to be executed, then 91 this method returns `None` instead of the string. 92 """ 93 del kwargs 94 return (True, None) 95 96 def parameter_modifiers(self): 97 """Returns `ParameterModifier` instances that customize the arguments.""" 98 return [] 99 100 def context_managers(self, kwargs): 101 """Return context managers for running the test combination. 102 103 The test combination will run under all context managers that all 104 `TestCombination` instances return. 105 106 Args: 107 kwargs: Arguments and their values that are passed to the test 108 combination. 109 110 Returns: 111 A list of instantiated context managers. 112 """ 113 del kwargs 114 return [] 115 116 117@tf_export("__internal__.test.combinations.ParameterModifier", v1=[]) 118class ParameterModifier(object): 119 """Customizes the behavior of a particular parameter. 120 121 Users should override `modified_arguments()` to modify the parameter they 122 want, eg: change the value of certain parameter or filter it from the params 123 passed to the test case. 124 125 See the sample usage below, it will change any negative parameters to zero 126 before it gets passed to test case. 127 ``` 128 class NonNegativeParameterModifier(ParameterModifier): 129 130 def modified_arguments(self, kwargs, requested_parameters): 131 updates = {} 132 for name, value in kwargs.items(): 133 if value < 0: 134 updates[name] = 0 135 return updates 136 ``` 137 """ 138 139 DO_NOT_PASS_TO_THE_TEST = object() 140 141 def __init__(self, parameter_name=None): 142 """Construct a parameter modifier that may be specific to a parameter. 143 144 Args: 145 parameter_name: A `ParameterModifier` instance may operate on a class of 146 parameters or on a parameter with a particular name. Only 147 `ParameterModifier` instances that are of a unique type or were 148 initialized with a unique `parameter_name` will be executed. 149 See `__eq__` and `__hash__`. 150 """ 151 object.__init__(self) 152 self._parameter_name = parameter_name 153 154 def modified_arguments(self, kwargs, requested_parameters): 155 """Replace user-provided arguments before they are passed to a test. 156 157 This makes it possible to adjust user-provided arguments before passing 158 them to the test method. 159 160 Args: 161 kwargs: The combined arguments for the test. 162 requested_parameters: The set of parameters that are defined in the 163 signature of the test method. 164 165 Returns: 166 A dictionary with updates to `kwargs`. Keys with values set to 167 `ParameterModifier.DO_NOT_PASS_TO_THE_TEST` are going to be deleted and 168 not passed to the test. 169 """ 170 del kwargs, requested_parameters 171 return {} 172 173 def __eq__(self, other): 174 """Compare `ParameterModifier` by type and `parameter_name`.""" 175 if self is other: 176 return True 177 elif type(self) is type(other): 178 return self._parameter_name == other._parameter_name 179 else: 180 return False 181 182 def __ne__(self, other): 183 return not self.__eq__(other) 184 185 def __hash__(self): 186 """Compare `ParameterModifier` by type or `parameter_name`.""" 187 if self._parameter_name: 188 return hash(self._parameter_name) 189 else: 190 return id(self.__class__) 191 192 193@tf_export("__internal__.test.combinations.OptionalParameter", v1=[]) 194class OptionalParameter(ParameterModifier): 195 """A parameter that is optional in `combine()` and in the test signature. 196 197 `OptionalParameter` is usually used with `TestCombination` in the 198 `parameter_modifiers()`. It allows `TestCombination` to skip certain 199 parameters when passing them to `combine()`, since the `TestCombination` might 200 consume the param and create some context based on the value it gets. 201 202 See the sample usage below: 203 204 ``` 205 class EagerGraphCombination(TestCombination): 206 207 def context_managers(self, kwargs): 208 mode = kwargs.pop("mode", None) 209 if mode is None: 210 return [] 211 elif mode == "eager": 212 return [context.eager_mode()] 213 elif mode == "graph": 214 return [ops.Graph().as_default(), context.graph_mode()] 215 else: 216 raise ValueError( 217 "'mode' has to be either 'eager' or 'graph', got {}".format(mode)) 218 219 def parameter_modifiers(self): 220 return [test_combinations.OptionalParameter("mode")] 221 ``` 222 223 When the test case is generated, the param "mode" will not be passed to the 224 test method, since it is consumed by the `EagerGraphCombination`. 225 """ 226 227 def modified_arguments(self, kwargs, requested_parameters): 228 if self._parameter_name in requested_parameters: 229 return {} 230 else: 231 return {self._parameter_name: ParameterModifier.DO_NOT_PASS_TO_THE_TEST} 232 233 234def generate(combinations, test_combinations=()): 235 """A decorator for generating combinations of a test method or a test class. 236 237 Parameters of the test method must match by name to get the corresponding 238 value of the combination. Tests must accept all parameters that are passed 239 other than the ones that are `OptionalParameter`. 240 241 Args: 242 combinations: a list of dictionaries created using combine() and times(). 243 test_combinations: a tuple of `TestCombination` instances that customize 244 the execution of generated tests. 245 246 Returns: 247 a decorator that will cause the test method or the test class to be run 248 under the specified conditions. 249 250 Raises: 251 ValueError: if any parameters were not accepted by the test method 252 """ 253 def decorator(test_method_or_class): 254 """The decorator to be returned.""" 255 256 # Generate good test names that can be used with --test_filter. 257 named_combinations = [] 258 for combination in combinations: 259 # We use OrderedDicts in `combine()` and `times()` to ensure stable 260 # order of keys in each dictionary. 261 assert isinstance(combination, OrderedDict) 262 name = "".join([ 263 "_{}_{}".format("".join(filter(str.isalnum, key)), 264 "".join(filter(str.isalnum, _get_name(value, i)))) 265 for i, (key, value) in enumerate(combination.items()) 266 ]) 267 named_combinations.append( 268 OrderedDict( 269 list(combination.items()) + 270 [("testcase_name", "_test{}".format(name))])) 271 272 if isinstance(test_method_or_class, type): 273 class_object = test_method_or_class 274 class_object._test_method_ids = test_method_ids = {} 275 for name, test_method in six.iteritems(class_object.__dict__.copy()): 276 if (name.startswith(unittest.TestLoader.testMethodPrefix) and 277 isinstance(test_method, types.FunctionType)): 278 delattr(class_object, name) 279 methods = {} 280 parameterized._update_class_dict_for_param_test_case( 281 class_object.__name__, methods, test_method_ids, name, 282 parameterized._ParameterizedTestIter( 283 _augment_with_special_arguments( 284 test_method, test_combinations=test_combinations), 285 named_combinations, parameterized._NAMED, name)) 286 for method_name, method in six.iteritems(methods): 287 setattr(class_object, method_name, method) 288 289 return class_object 290 else: 291 test_method = _augment_with_special_arguments( 292 test_method_or_class, test_combinations=test_combinations) 293 return parameterized.named_parameters(*named_combinations)(test_method) 294 295 return decorator 296 297 298def _augment_with_special_arguments(test_method, test_combinations): 299 def decorated(self, **kwargs): 300 """A wrapped test method that can treat some arguments in a special way.""" 301 original_kwargs = kwargs.copy() 302 303 # Skip combinations that are going to be executed in a different testing 304 # environment. 305 reasons_to_skip = [] 306 for combination in test_combinations: 307 should_execute, reason = combination.should_execute_combination( 308 original_kwargs.copy()) 309 if not should_execute: 310 reasons_to_skip.append(" - " + reason) 311 312 if reasons_to_skip: 313 self.skipTest("\n".join(reasons_to_skip)) 314 315 customized_parameters = [] 316 for combination in test_combinations: 317 customized_parameters.extend(combination.parameter_modifiers()) 318 customized_parameters = set(customized_parameters) 319 320 # The function for running the test under the total set of 321 # `context_managers`: 322 def execute_test_method(): 323 requested_parameters = tf_inspect.getfullargspec(test_method).args 324 for customized_parameter in customized_parameters: 325 for argument, value in customized_parameter.modified_arguments( 326 original_kwargs.copy(), requested_parameters).items(): 327 if value is ParameterModifier.DO_NOT_PASS_TO_THE_TEST: 328 kwargs.pop(argument, None) 329 else: 330 kwargs[argument] = value 331 332 omitted_arguments = set(requested_parameters).difference( 333 set(list(kwargs.keys()) + ["self"])) 334 if omitted_arguments: 335 raise ValueError("The test requires parameters whose arguments " 336 "were not passed: {} .".format(omitted_arguments)) 337 missing_arguments = set(list(kwargs.keys()) + ["self"]).difference( 338 set(requested_parameters)) 339 if missing_arguments: 340 raise ValueError("The test does not take parameters that were passed " 341 ": {} .".format(missing_arguments)) 342 343 kwargs_to_pass = {} 344 for parameter in requested_parameters: 345 if parameter == "self": 346 kwargs_to_pass[parameter] = self 347 else: 348 kwargs_to_pass[parameter] = kwargs[parameter] 349 test_method(**kwargs_to_pass) 350 351 # Install `context_managers` before running the test: 352 context_managers = [] 353 for combination in test_combinations: 354 for manager in combination.context_managers( 355 original_kwargs.copy()): 356 context_managers.append(manager) 357 358 if hasattr(contextlib, "nested"): # Python 2 359 # TODO(isaprykin): Switch to ExitStack when contextlib2 is available. 360 with contextlib.nested(*context_managers): 361 execute_test_method() 362 else: # Python 3 363 with contextlib.ExitStack() as context_stack: 364 for manager in context_managers: 365 context_stack.enter_context(manager) 366 execute_test_method() 367 368 return decorated 369 370 371@tf_export("__internal__.test.combinations.combine", v1=[]) 372def combine(**kwargs): 373 """Generate combinations based on its keyword arguments. 374 375 Two sets of returned combinations can be concatenated using +. Their product 376 can be computed using `times()`. 377 378 Args: 379 **kwargs: keyword arguments of form `option=[possibilities, ...]` 380 or `option=the_only_possibility`. 381 382 Returns: 383 a list of dictionaries for each combination. Keys in the dictionaries are 384 the keyword argument names. Each key has one value - one of the 385 corresponding keyword argument values. 386 """ 387 if not kwargs: 388 return [OrderedDict()] 389 390 sort_by_key = lambda k: k[0] 391 kwargs = OrderedDict(sorted(kwargs.items(), key=sort_by_key)) 392 first = list(kwargs.items())[0] 393 394 rest = dict(list(kwargs.items())[1:]) 395 rest_combined = combine(**rest) 396 397 key = first[0] 398 values = first[1] 399 if not isinstance(values, list): 400 values = [values] 401 402 return [ 403 OrderedDict(sorted(list(combined.items()) + [(key, v)], key=sort_by_key)) 404 for v in values 405 for combined in rest_combined 406 ] 407 408 409@tf_export("__internal__.test.combinations.times", v1=[]) 410def times(*combined): 411 """Generate a product of N sets of combinations. 412 413 times(combine(a=[1,2]), combine(b=[3,4])) == combine(a=[1,2], b=[3,4]) 414 415 Args: 416 *combined: N lists of dictionaries that specify combinations. 417 418 Returns: 419 a list of dictionaries for each combination. 420 421 Raises: 422 ValueError: if some of the inputs have overlapping keys. 423 """ 424 assert combined 425 426 if len(combined) == 1: 427 return combined[0] 428 429 first = combined[0] 430 rest_combined = times(*combined[1:]) 431 432 combined_results = [] 433 for a in first: 434 for b in rest_combined: 435 if set(a.keys()).intersection(set(b.keys())): 436 raise ValueError("Keys need to not overlap: {} vs {}".format( 437 a.keys(), b.keys())) 438 439 combined_results.append(OrderedDict(list(a.items()) + list(b.items()))) 440 return combined_results 441 442 443@tf_export("__internal__.test.combinations.NamedObject", v1=[]) 444class NamedObject(object): 445 """A class that translates an object into a good test name.""" 446 447 def __init__(self, name, obj): 448 object.__init__(self) 449 self._name = name 450 self._obj = obj 451 452 def __getattr__(self, name): 453 return getattr(self._obj, name) 454 455 def __call__(self, *args, **kwargs): 456 return self._obj(*args, **kwargs) 457 458 def __iter__(self): 459 return self._obj.__iter__() 460 461 def __repr__(self): 462 return self._name 463 464 465def _get_name(value, index): 466 return re.sub("0[xX][0-9a-fA-F]+", str(index), str(value)) 467