1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""This module customizes `test_combinations` for `tf.distribute.Strategy`. 16 17Additionally it provides `generate()`, `combine()` and `times()` with 18`tf.distribute.Strategy` customizations as a default. 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25import collections 26import copy 27import re 28import sys 29import types 30import unittest 31 32import six 33 34from tensorflow.python.client import session 35from tensorflow.python.distribute import collective_all_reduce_strategy 36from tensorflow.python.distribute import distribute_lib 37from tensorflow.python.distribute import multi_process_runner 38from tensorflow.python.distribute import multi_worker_test_base 39from tensorflow.python.eager import context 40from tensorflow.python.eager import def_function 41from tensorflow.python.framework import combinations as framework_combinations 42from tensorflow.python.framework import ops 43from tensorflow.python.framework import test_combinations as combinations_lib 44from tensorflow.python.framework import test_util 45from tensorflow.python.platform import flags 46from tensorflow.python.platform import tf_logging as logging 47from tensorflow.python.util import tf_decorator 48from 
tensorflow.python.util import tf_inspect 49from tensorflow.python.util.tf_export import tf_export 50 51 52# TODO(rchao): Rename `distribution` parameter to `strategy` or 53# `distribute_strategy` in all tests. 54class DistributionParameter(combinations_lib.ParameterModifier): 55 """Transforms arguments of type `NamedDistribution`. 56 57 Convert all arguments of type `NamedDistribution` to the value of their 58 `strategy` property. 59 """ 60 61 def modified_arguments(self, kwargs, requested_parameters): 62 # Get the parameter that indicates if we need to set the `_use_policy` flag 63 # on the strategy object. This is a temporary flag for testing the variable 64 # policy rollout. 65 use_var_policy = kwargs.get("use_var_policy", None) 66 distribution_arguments = {} 67 for k, v in kwargs.items(): 68 if isinstance(v, NamedDistribution): 69 strategy = v.strategy 70 if use_var_policy: 71 strategy.extended._use_var_policy = use_var_policy 72 distribution_arguments[k] = strategy 73 return distribution_arguments 74 75 76class ClusterParameters(combinations_lib.ParameterModifier): 77 """Adds cluster parameters if a `NamedDistribution` has it. 78 79 It needs to be before DistributionParameter. 
80 """ 81 82 def modified_arguments(self, kwargs, requested_parameters): 83 strategy = None 84 for _, v in kwargs.items(): 85 if isinstance(v, NamedDistribution): 86 if strategy is not None and _num_total_workers(v.has_chief, 87 v.num_workers) > 1: 88 raise ValueError("Only support one NamedDistribution for multi worker" 89 "tests.") 90 strategy = v 91 92 if strategy: 93 has_chief = strategy.has_chief 94 num_workers = strategy.num_workers 95 runner = strategy.runner 96 if "has_chief" in kwargs and kwargs["has_chief"] != has_chief: 97 raise ValueError( 98 "both has_chief and strategy specified but are not compatible") 99 if "num_workers" in kwargs and kwargs["num_workers"] != num_workers: 100 raise ValueError( 101 "both num_workers and strategy specified but are not compatible") 102 else: 103 has_chief = kwargs.get("has_chief", False) 104 num_workers = kwargs.get("num_workers", 1) 105 runner = kwargs.get("runner", None) 106 107 # Always set cluster parameters if they're requested. So that generate() 108 # works when there's no startegy in the combinations. 
109 update = {} 110 if "has_chief" in requested_parameters: 111 update["has_chief"] = has_chief 112 if "num_workers" in requested_parameters: 113 update["num_workers"] = num_workers 114 if "runner" in requested_parameters: 115 update["runner"] = runner 116 return update 117 118 119class DistributionCombination(combinations_lib.TestCombination): 120 """Sets up distribution strategy for tests.""" 121 122 def should_execute_combination(self, kwargs): 123 distributions = [ 124 v for v in kwargs.values() if isinstance(v, NamedDistribution) 125 ] 126 if test_util.is_xla_enabled() and any(d.no_xla for d in distributions): 127 return ( 128 False, 129 "n/a: skipping strategy combination with no_xla=True in XLA tests") 130 return (True, None) 131 132 def parameter_modifiers(self): 133 return [ 134 DistributionParameter(), 135 combinations_lib.OptionalParameter("use_var_policy"), 136 ] 137 138 139class ClusterCombination(combinations_lib.TestCombination): 140 """Sets up multi worker tests.""" 141 142 def parameter_modifiers(self): 143 return [ClusterParameters()] 144 145 146class GPUCombination(combinations_lib.TestCombination): 147 """Enable tests to request GPU hardware and skip non-GPU combinations. 148 149 This class expects test_combinations to be generated with `NamedDistribution` 150 wrapping instances of `tf.distribute.Strategy`. 151 152 Optionally, the `required_gpus` argument is supported. GPU hardware is 153 required, if its value is `True` or > 0. 154 155 Attributes: 156 GPU_TEST: The environment is considered to have GPU hardware available if 157 the name of the program contains "test_gpu" or "test_xla_gpu". 
158 """ 159 160 GPU_TEST = re.search(r"(test_gpu|test_xla_gpu)$", sys.argv[0]) 161 162 def should_execute_combination(self, kwargs): 163 distributions = [ 164 v for v in kwargs.values() if isinstance(v, NamedDistribution) 165 ] 166 required_gpus = kwargs.get("required_gpus", None) 167 168 if distributions and required_gpus: 169 raise ValueError("Do not use `required_gpus` and arguments of type " 170 "NamedDistribution together.") 171 172 number_of_required_gpus = max([required_gpus or 0] + 173 [d.required_gpus or 0 for d in distributions]) 174 175 if not number_of_required_gpus and GPUCombination.GPU_TEST: 176 return (False, "Test that doesn't require GPUs.") 177 elif (number_of_required_gpus > 0 178 and context.num_gpus() < number_of_required_gpus): 179 return (False, ("Only {} of {} required GPUs are available.".format( 180 context.num_gpus(), number_of_required_gpus))) 181 else: 182 return (True, None) 183 184 def parameter_modifiers(self): 185 return [combinations_lib.OptionalParameter("required_gpus")] 186 187 188class TPUCombination(combinations_lib.TestCombination): 189 """Allow to request TPU hardware and skip non-TPU combinations. 190 191 This class expects test_combinations to be generated with `NamedDistribution` 192 wrapping instances of `tf.distribute.Strategy`. 193 194 Optionally, the `required_tpus` parameter is supported. TPU hardware is 195 required, if its argument is `True` or > 0. 196 197 Optionally, the `use_cloud_tpu` parameter is supported. If TPU hardware is 198 required by `required_tpus`, it specifically must be a Cloud TPU (specified 199 with `--tpu`) if `use_cloud_tpu` is `True`. 200 201 Attributes: 202 TPU_TEST: The environment is considered to have TPU hardware available if 203 the name of the program contains "test_tpu". 
204 """ 205 206 TPU_TEST = "test_tpu" in sys.argv[0] 207 208 def should_execute_combination(self, kwargs): 209 distributions = [ 210 v for v in kwargs.values() if isinstance(v, NamedDistribution) 211 ] 212 # TODO(isaprykin): Migrate all tests away from using 'required_tpu' in favor 213 # of 'required_tpus'. 214 if "required_tpus" in kwargs and "required_tpu" in kwargs: 215 raise ValueError("Do not use `required_tpu`. Both `required_tpus` and " 216 "`required_tpu` were specified.") 217 required_tpus = kwargs.get("required_tpus", None) or kwargs.get( 218 "required_tpu", None) 219 220 if distributions and required_tpus: 221 raise ValueError("Do not use `required_tpus` and arguments of type " 222 "NamedDistribution together.") 223 224 # TODO(isaprykin): Add support for a particular number of TPUs. Right now 225 # it's binary. 226 number_of_required_tpus = max([required_tpus or 0] + 227 [d.required_tpu or 0 for d in distributions]) 228 use_cloud_tpu = any([kwargs.get("use_cloud_tpu")] + 229 [d.use_cloud_tpu for d in distributions]) 230 tpu = hasattr(flags.FLAGS, "tpu") and flags.FLAGS.tpu or "" 231 232 if not number_of_required_tpus and TPUCombination.TPU_TEST: 233 return (False, "Test that doesn't require TPUs.") 234 if number_of_required_tpus and not TPUCombination.TPU_TEST: 235 return (False, "Test requires a TPU, but it's not available.") 236 if use_cloud_tpu and not tpu: 237 return (False, "Test requires a Cloud TPU, but none specified.") 238 if not use_cloud_tpu and tpu: 239 return (False, "Test requires local TPU, but Cloud TPU specified.") 240 return (True, None) 241 242 def parameter_modifiers(self): 243 return [ 244 combinations_lib.OptionalParameter("required_tpus"), 245 combinations_lib.OptionalParameter("required_tpu"), 246 combinations_lib.OptionalParameter("use_cloud_tpu"), 247 ] 248 249 250class NamedDistribution(object): 251 """Wraps a `tf.distribute.Strategy` and adds a name for test titles.""" 252 253 def __init__(self, 254 name, 255 distribution_fn, 
256 required_gpus=None, 257 required_tpu=False, 258 use_cloud_tpu=False, 259 has_chief=False, 260 num_workers=1, 261 pool_runner_fn=None, 262 no_xla=False): 263 """Initialize NamedDistribution. 264 265 Args: 266 name: Name that will be a part of the name of the test case. 267 distribution_fn: A callable that creates a `tf.distribute.Strategy`. 268 required_gpus: The number of GPUs that the strategy requires. 269 required_tpu: Whether the strategy requires TPU. 270 use_cloud_tpu: Whether the strategy requires cloud TPU. 271 has_chief: Whether the strategy requires a chief worker. 272 num_workers: The number of workers that the strategy requires. 273 pool_runner_fn: An optional callable that returns a MultiProcessPoolRunner 274 to run the test. 275 no_xla: Whether to skip in XLA tests. 276 """ 277 object.__init__(self) 278 self._name = name 279 self._distribution_fn = distribution_fn 280 self.required_gpus = required_gpus 281 self.required_tpu = required_tpu 282 self.use_cloud_tpu = use_cloud_tpu 283 self.has_chief = has_chief 284 self.num_workers = num_workers 285 self._pool_runner_fn = pool_runner_fn 286 self.no_xla = no_xla 287 288 @property 289 def runner(self): 290 if self._pool_runner_fn is not None: 291 return self._pool_runner_fn() 292 return None 293 294 @property 295 def strategy(self): 296 return self._distribution_fn() 297 298 def __repr__(self): 299 return self._name 300 301 302# This is to allow adding combinations that runs a function both as a 303# tf.function and eagerly. 304# 305# @combinations.generate( 306# combinations.combine( 307# tf_function = [combinations.tf_function, combinations.no_tf_function] 308# ) 309# ) 310# def testXXX(tf_function): 311# @tf_function 312# def foo(): 313# tf.add(1., 1.) 
#
#   foo()
tf_function = combinations_lib.NamedObject("TfFunction", def_function.function)
no_tf_function = combinations_lib.NamedObject("NoTfFunction", lambda f: f)


def concat(*combined):
  """Concats combinations.

  Args:
    *combined: any number of iterables of combinations.

  Returns:
    A new list holding the items of every argument, in order.
  """
  result = []
  for one in combined:
    result += one
  return result


@tf_export("__internal__.distribute.combinations.generate", v1=[])
def generate(combinations, test_combinations=()):
  # pylint: disable=g-doc-args,g-doc-return-or-yield
  """Distributed adapter of `tf.__internal__.test.combinations.generate`.

  All tests with distributed strategy should use this one instead of
  `tf.__internal__.test.combinations.generate`. This function has support of
  strategy combinations, GPU/TPU and multi worker support.

  See `tf.__internal__.test.combinations.generate` for usage.
  """
  # pylint: enable=g-doc-args,g-doc-return-or-yield
  default_combinations = (
      framework_combinations.EagerGraphCombination(),
      framework_combinations.TFVersionCombination(),
      ClusterCombination(),
      DistributionCombination(),
      GPUCombination(),
      TPUCombination(),
  )
  # We apply our own decoration to handle multi worker tests before applying
  # framework.test_combinations.generate. The order is important since we need
  # framework.test_combinations.generate to apply all parameter modifiers first.
  #
  # `test_combinations` is coerced to a tuple so that callers may pass any
  # iterable (e.g. a list); `tuple + list` would raise TypeError.
  combination_decorator = combinations_lib.generate(
      combinations,
      test_combinations=default_combinations + tuple(test_combinations))

  def decorator(test_method_or_class):
    if isinstance(test_method_or_class, type):
      # If it's a test class.
      class_object = test_method_or_class
      # Decorate each test method with _multi_worker_test. Iterate over a copy
      # of the class dict because setattr() mutates it.
      for name, test_method in six.iteritems(class_object.__dict__.copy()):
        if (name.startswith(unittest.TestLoader.testMethodPrefix) and
            isinstance(test_method, types.FunctionType)):
          setattr(class_object, name, _multi_worker_test(test_method))
      return combination_decorator(class_object)
    else:
      return combination_decorator(_multi_worker_test(test_method_or_class))

  return decorator


combine = combinations_lib.combine
times = combinations_lib.times
NamedObject = combinations_lib.NamedObject


# Identifies whether we're in the main process or worker processes.
# `_multi_worker_test` decoration behaves differently in the main process and
# the worker processes. See the documentation of _multi_worker_test for detail.
_running_in_worker = False


def in_main_process():
  """Whether it's in the main test process.

  This is normally used to prepare the test environment which should only happen
  in the main process.

  Returns:
    A boolean.
  """
  return not _running_in_worker


class TestEnvironment(object):
  """Holds the test environment information.

  Attributes may only be set in the main process; setting one elsewhere raises
  ValueError.
  """

  def __init__(self):
    self.tf_data_service_dispatcher = None

  def __setattr__(self, name, value):
    if not in_main_process():
      raise ValueError(
          "combinations.env() should only be modified in the main process. "
          "Condition your code on combinations.in_main_process().")
    super().__setattr__(name, value)


_env = TestEnvironment()


def env():
  """Returns the object that holds the test environment information.

  Tests should modify this in the main process if needed, and it will be
  passed to the worker processes each time a test case is run.

  Returns:
    a TestEnvironment object.
  """
  return _env


_TestResult = collections.namedtuple("_TestResult", ["status", "message"])


def _test_runner(test_id, test_env):
  """Executes the test with the given test_id.

  This is a simple wrapper around TestRunner to be used with
  multi_process_runner. Similar to test.main(), but it executes only one test
  specified by test_id and returns whether the test succeeds. If the test fails,
  the function prints failures and errors to stdout.

  Args:
    test_id: TestCase.id()
    test_env: a TestEnvironment object.

  Returns:
    A `_TestResult` whose `status` is "failure", "skipped" or "ok" and whose
    `message` is the first failure traceback or skip reason, or None when the
    status is "ok".
  """
  global _running_in_worker, _env
  # No need to restore the value of _running_in_worker since it should always be
  # True in worker processes.
  _running_in_worker = True
  _env = test_env
  test = unittest.defaultTestLoader.loadTestsFromName(test_id)
  runner = unittest.TextTestRunner()
  result = runner.run(test)
  # Treat expected failures as failures, so that the main process can get
  # them and fail as expected. Also treat errors as failures to simplify the
  # handling.
  failures = result.failures + result.expectedFailures + result.errors
  if failures:
    ret = _TestResult(status="failure", message=failures[0][1])
  elif result.skipped:
    ret = _TestResult(status="skipped", message=result.skipped[0][1])
  else:
    # Treat unexpectedSuccesses as OK so that the test case in the main process
    # succeed as well.
    ret = _TestResult(status="ok", message=None)
  # Print tracebacks to stdout and multi_process_runner will collect
  # them and stream back to the main process.
  if ret.message:
    print(ret.message)
  return ret


def _multi_worker_test(test_method):
  """Decorate test_method so that it runs in each worker.

  We use `multi_process_runner` to simulate multiple workers. Since we run
  this function in the main process and all worker processes, this decoration
  behaves differently in the main process and worker processes. In the main
  process, it spawns subprocesses and runs the test on each of them; in a worker
  process, it executes test in the same way as a normal test, e.g.
  setUp()/tearDown() are called before/after the test.

  Args:
    test_method: a function which must be a test method.

  Returns:
    Decorated `test_method`. Note that the decorated function has additional
    arguments.
  """

  def decorator(self, has_chief, num_workers, runner, **kwargs):
    if _num_total_workers(has_chief, num_workers) == 1 or _running_in_worker:
      # We're in worker process or the test is for single worker. Either case we
      # execute the test method directly instead of spawning subprocesses.

      # For MultiWorkerMirroredStrategy(CollectiveAllReduceStrategy), install a
      # session that connects to the local server. This is necessary for multi
      # worker graph mode tests to work. Those tests cannot use their graphs or
      # sessions, including the one returned by self.cached_session(). Since
      # existing tests may already be doing so, we only install the session for
      # multi worker tests.
      with _multi_worker_session(kwargs):
        test_method(self, **kwargs)
      return

    # We're in the main process. We spawn subprocesses and run the *test* on
    # each of them. Note that we're not directly executing test_method passed to
    # _multi_worker_test, because we need setUp()/tearDown() to be called and
    # all the decorations on the test method. The conceptual call stack is:
    #   [main process]test.main()
    #     [main process]test_runner.run(test)
    #       [main process]wrapper by combinations.generate()
    #         [main process]_multi_worker_test.decorator()
    #           # A sub process goes through the same code path as the main
    #           # process.
    #           [sub process]_test_runner()
    #             [sub process]test_runner.run(test)
    #               [sub process]wrapper by combinations.generate()
    #                 [sub process]_multi_worker_test.decorator()
    #                   # _running_in_worker is True
    #                   [sub process]test_method()
    test_id = self.id()
    if runner:
      results = runner.run(_test_runner, args=(test_id, _env))
    else:
      cluster_spec = multi_worker_test_base.create_cluster_spec(
          has_chief=has_chief,
          num_workers=num_workers,
          num_ps=0,
          has_eval=False)
      results = multi_process_runner.run(
          _test_runner, cluster_spec, args=(test_id, _env)).return_value

    skip_reason = None
    for result in results:
      if result.status == "failure":
        # We can't tell which worker the return value come from, so we fail on
        # the first error. self.fail() raises, which also ends the loop.
        self.fail(result.message)
      elif result.status == "skipped":
        # Record the skip reason, but do not actually skip the test in case some
        # processes fail instead.
        skip_reason = result.message
    if skip_reason is not None:
      self.skipTest(skip_reason)

  # Expose the extra cluster parameters in the decorated function's argspec.
  argspec = tf_inspect.getfullargspec(test_method)
  decorator_args = (argspec.args or []) + ["has_chief", "num_workers", "runner"]
  decorator_argspec = argspec._replace(args=decorator_args)
  return tf_decorator.make_decorator(
      test_method, decorator, decorator_argspec=decorator_argspec)


def _num_total_workers(has_chief, num_workers):
  """Returns the number of workers including the chief."""
  if has_chief:
    return num_workers + 1
  return num_workers


def _multi_worker_session(kwargs):
  """Returns a context manager entering a MultiWorkerMirroredStrategy session.

  Args:
    kwargs: a dict. Keyword arguments passed to the test.

  Returns:
    A context manager. If MultiWorkerMirroredStrategy is the one and only one
    strategy in kwargs and it's in graph mode, it's the session that is
    configured for that strategy. Otherwise, it's a no-op context manager.
  """
  strategy = None
  for v in kwargs.values():
    if isinstance(v, distribute_lib.StrategyBase):
      if strategy is not None:
        logging.warning(
            "The test uses multiple strategies. Skipping "
            "entering a session that is configured for the strategy.")
        return ops.NullContextmanager()
      strategy = v
  if context.executing_eagerly() or not isinstance(
      strategy, collective_all_reduce_strategy.CollectiveAllReduceStrategy):
    return ops.NullContextmanager()
  # Work on a copy so the global context's config is left untouched.
  sess_config = copy.deepcopy(context.context().config)
  sess_config = strategy.update_config_proto(sess_config)
  target = strategy.cluster_resolver.master()
  return session.Session(config=sess_config, target=target).as_default()