# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This module customizes `test_combinations` for `tf.distribute.Strategy`.

Additionally it provides `generate()`, `combine()` and `times()` with
`tf.distribute.Strategy` customizations as a default.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import re
import sys
import types
import unittest

from absl import app
import six

from tensorflow.python.client import session
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import combinations as framework_combinations
from tensorflow.python.framework import config
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_combinations as combinations_lib
from tensorflow.python.framework import test_util
from tensorflow.python.platform import flags
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


# TODO(rchao): Rename `distribution` parameter to `strategy` or
# `distribute_strategy` in all tests.
class DistributionParameter(combinations_lib.ParameterModifier):
  """Transforms arguments of type `NamedDistribution`.

  Converts all arguments of type `NamedDistribution` to the value of their
  `strategy` property.
  """

  def modified_arguments(self, kwargs, requested_parameters):
    # Get the parameter that indicates if we need to set the `_use_var_policy`
    # flag on the strategy object. This is a temporary flag for testing the
    # variable policy rollout.
    use_var_policy = kwargs.get("use_var_policy", None)
    distribution_arguments = {}
    for k, v in kwargs.items():
      if isinstance(v, NamedDistribution):
        strategy = v.strategy
        if use_var_policy:
          strategy.extended._use_var_policy = use_var_policy
        distribution_arguments[k] = strategy
    return distribution_arguments
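

# For illustration only (a hedged sketch; `some_named_strategy` and `testFoo`
# are hypothetical, not defined in this module): a test parameterized as
#
#   @combinations.generate(
#       combinations.combine(distribution=[some_named_strategy],
#                            mode=["eager"]))
#   def testFoo(self, distribution):
#     self.assertIsInstance(distribution, tf.distribute.Strategy)
#
# receives the constructed `tf.distribute.Strategy` in `distribution`, because
# `DistributionParameter` replaces each `NamedDistribution` argument with the
# value of its `strategy` property.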


class ClusterParameters(combinations_lib.ParameterModifier):
  """Adds cluster parameters if a `NamedDistribution` specifies them.

  It needs to be placed before `DistributionParameter`.
  """

  def modified_arguments(self, kwargs, requested_parameters):
    strategy = None
    for _, v in kwargs.items():
      if isinstance(v, NamedDistribution):
        if strategy is not None and _num_total_workers(v.has_chief,
                                                       v.num_workers) > 1:
          raise ValueError("Only support one NamedDistribution for multi "
                           "worker tests.")
        strategy = v

    if strategy:
      has_chief = strategy.has_chief
      num_workers = strategy.num_workers
      runner = strategy.runner
      share_gpu = strategy.share_gpu
      num_ps = strategy.num_ps
      if "has_chief" in kwargs and kwargs["has_chief"] != has_chief:
        raise ValueError(
            "both has_chief and strategy specified but are not compatible")
      if "num_workers" in kwargs and kwargs["num_workers"] != num_workers:
        raise ValueError(
            "both num_workers and strategy specified but are not compatible")
    else:
      has_chief = kwargs.get("has_chief", False)
      num_workers = kwargs.get("num_workers", 1)
      runner = kwargs.get("runner", None)
      share_gpu = kwargs.get("share_gpu", True)
      num_ps = kwargs.get("num_ps", 0)

    # Always set cluster parameters if they're requested, so that generate()
    # works when there's no strategy in the combinations.
    update = {}
    if "has_chief" in requested_parameters:
      update["has_chief"] = has_chief
    if "num_workers" in requested_parameters:
      update["num_workers"] = num_workers
    if "runner" in requested_parameters:
      update["runner"] = runner
    if "share_gpu" in requested_parameters:
      update["share_gpu"] = share_gpu
    if "num_ps" in requested_parameters:
      update["num_ps"] = num_ps
    return update


class DistributionCombination(combinations_lib.TestCombination):
  """Sets up distribution strategy for tests."""

  def should_execute_combination(self, kwargs):
    distributions = [
        v for v in kwargs.values() if isinstance(v, NamedDistribution)
    ]
    if test_util.is_xla_enabled() and any(d.no_xla for d in distributions):
      return (
          False,
          "n/a: skipping strategy combination with no_xla=True in XLA tests")
    return (True, None)

  def parameter_modifiers(self):
    return [
        DistributionParameter(),
        combinations_lib.OptionalParameter("use_var_policy"),
    ]


class ClusterCombination(combinations_lib.TestCombination):
  """Sets up multi worker tests."""

  def parameter_modifiers(self):
    return [ClusterParameters()]
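

# For illustration only (a hedged sketch; the test name and values are
# arbitrary): when no multi-worker `NamedDistribution` is involved, cluster
# parameters can be requested directly in the combination, e.g.
#
#   @combinations.generate(
#       combinations.combine(has_chief=True, num_workers=2, mode=["eager"]))
#   def testWithChiefAndTwoWorkers(self):
#     ...
#
# `ClusterParameters` then supplies `has_chief`, `num_workers`, `num_ps`,
# `share_gpu` and `runner` to the multi-worker machinery (see
# `_multi_worker_test` below), which runs the test once per task.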


class GPUCombination(combinations_lib.TestCombination):
  """Enables tests to request GPU hardware and skips non-GPU combinations.

  This class expects test_combinations to be generated with `NamedDistribution`
  wrapping instances of `tf.distribute.Strategy`.

  Optionally, the `required_gpus` argument is supported. GPU hardware is
  required if its value is `True` or > 0.

  Attributes:
    GPU_TEST: The environment is considered to have GPU hardware available if
      the name of the program ends with "test_gpu", "test_2gpu",
      "test_xla_gpu" or "test_xla_2gpu".
  """

  GPU_TEST = re.search(r"(test_2?gpu|test_xla_2?gpu)$", sys.argv[0])

  def should_execute_combination(self, kwargs):
    distributions = [
        v for v in kwargs.values() if isinstance(v, NamedDistribution)
    ]
    required_gpus = kwargs.get("required_gpus", 0)
    required_physical_gpus = kwargs.get("required_physical_gpus", 0)

    if distributions and required_gpus:
      raise ValueError("Do not use `required_gpus` and arguments of type "
                       "NamedDistribution together.")

    number_of_required_gpus = max(
        [required_gpus] + [required_physical_gpus] +
        [d.required_physical_gpus or 0 for d in distributions] +
        [d.required_gpus or 0 for d in distributions])
    number_of_required_physical_gpus = max(
        [required_physical_gpus] +
        [d.required_physical_gpus or 0 for d in distributions])

    if required_physical_gpus and required_gpus:
      raise ValueError("Only one of `required_physical_gpus` (number of "
                       "physical GPUs required) and `required_gpus` (total "
                       "number of GPUs required) should be set.")
    if not number_of_required_gpus and GPUCombination.GPU_TEST:
      return (False, "Test that doesn't require GPUs.")
    elif (number_of_required_gpus > 0
          and context.num_gpus() < number_of_required_gpus):
      return (False, ("Only {} of {} required GPUs are available.".format(
          context.num_gpus(), number_of_required_gpus)))
    elif number_of_required_physical_gpus > len(
        config.list_physical_devices("GPU")):
      return (False,
              ("Only {} of {} required physical GPUs are available.".format(
                  len(config.list_physical_devices("GPU")),
                  number_of_required_physical_gpus)))
    else:
      return (True, None)

  def parameter_modifiers(self):
    return [combinations_lib.OptionalParameter("required_gpus"),
            combinations_lib.OptionalParameter("required_physical_gpus")]
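

# For illustration only (a hedged sketch; the test name is arbitrary): a test
# that needs two GPUs but no particular strategy can request them with
#
#   @combinations.generate(
#       combinations.combine(required_gpus=2, mode=["eager"]))
#   def testNeedsTwoGPUs(self):
#     ...
#
# `GPUCombination` skips the combination when fewer GPUs are available and, in
# GPU test binaries, skips combinations that do not request any GPUs.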


class TPUCombination(combinations_lib.TestCombination):
  """Allows tests to request TPU hardware and skips non-TPU combinations.

  This class expects test_combinations to be generated with `NamedDistribution`
  wrapping instances of `tf.distribute.Strategy`.

  Optionally, the `required_tpus` parameter is supported. TPU hardware is
  required if its value is `True` or > 0.

  Optionally, the `use_cloud_tpu` parameter is supported. If TPU hardware is
  required by `required_tpus`, it specifically must be a Cloud TPU (specified
  with `--tpu`) if `use_cloud_tpu` is `True`.

  Attributes:
    TPU_TEST: The environment is considered to have TPU hardware available if
      the name of the program contains "test_tpu".
  """

  TPU_TEST = "test_tpu" in sys.argv[0]

  def should_execute_combination(self, kwargs):
    distributions = [
        v for v in kwargs.values() if isinstance(v, NamedDistribution)
    ]
    # TODO(isaprykin): Migrate all tests away from using 'required_tpu' in
    # favor of 'required_tpus'.
    if "required_tpus" in kwargs and "required_tpu" in kwargs:
      raise ValueError("Do not use `required_tpu`. Both `required_tpus` and "
                       "`required_tpu` were specified.")
    required_tpus = kwargs.get("required_tpus", None) or kwargs.get(
        "required_tpu", None)

    if distributions and required_tpus:
      raise ValueError("Do not use `required_tpus` and arguments of type "
                       "NamedDistribution together.")

    # TODO(isaprykin): Add support for a particular number of TPUs. Right now
    # it's binary.
    number_of_required_tpus = max([required_tpus or 0] +
                                  [d.required_tpu or 0 for d in distributions])
    use_cloud_tpu = any([kwargs.get("use_cloud_tpu")] +
                        [d.use_cloud_tpu for d in distributions])
    tpu = hasattr(flags.FLAGS, "tpu") and flags.FLAGS.tpu or ""

    if not number_of_required_tpus and TPUCombination.TPU_TEST:
      return (False, "Test that doesn't require TPUs.")
    if number_of_required_tpus and not TPUCombination.TPU_TEST:
      return (False, "Test requires a TPU, but it's not available.")
    if use_cloud_tpu and not tpu:
      return (False, "Test requires a Cloud TPU, but none specified.")
    if not use_cloud_tpu and tpu:
      return (False, "Test requires local TPU, but Cloud TPU specified.")
    return (True, None)

  def parameter_modifiers(self):
    return [
        combinations_lib.OptionalParameter("required_tpus"),
        combinations_lib.OptionalParameter("required_tpu"),
        combinations_lib.OptionalParameter("use_cloud_tpu"),
    ]


class NamedDistribution(object):
  """Wraps a `tf.distribute.Strategy` and adds a name for test titles."""

  def __init__(self,
               name,
               distribution_fn,
               required_gpus=None,
               required_physical_gpus=0,
               required_tpu=False,
               use_cloud_tpu=False,
               has_chief=False,
               num_workers=1,
               num_ps=0,
               share_gpu=True,
               pool_runner_fn=None,
               no_xla=False):
    """Initializes a NamedDistribution.

    Args:
      name: Name that will be a part of the name of the test case.
      distribution_fn: A callable that creates a `tf.distribute.Strategy`.
      required_gpus: The number of GPUs that the strategy requires. Only one of
        `required_gpus` and `required_physical_gpus` should be set.
      required_physical_gpus: Number of physical GPUs required. Only one of
        `required_gpus` and `required_physical_gpus` should be set.
      required_tpu: Whether the strategy requires TPU.
      use_cloud_tpu: Whether the strategy requires cloud TPU.
      has_chief: Whether the strategy requires a chief worker.
      num_workers: The number of workers that the strategy requires.
      num_ps: The number of parameter servers.
      share_gpu: Whether to share GPUs among workers.
      pool_runner_fn: An optional callable that returns a
        MultiProcessPoolRunner to run the test.
      no_xla: Whether to skip in XLA tests.
    """
    object.__init__(self)
    self._name = name
    self._distribution_fn = distribution_fn
    self.required_gpus = required_gpus
    self.required_physical_gpus = required_physical_gpus
    self.required_tpu = required_tpu
    self.use_cloud_tpu = use_cloud_tpu
    self.has_chief = has_chief
    self.num_workers = num_workers
    self.num_ps = num_ps
    self.share_gpu = share_gpu
    self._pool_runner_fn = pool_runner_fn
    self.no_xla = no_xla

  @property
  def runner(self):
    if self._pool_runner_fn is not None:
      return self._pool_runner_fn()
    return None

  @property
  def strategy(self):
    return self._distribution_fn()

  def __repr__(self):
    return self._name
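

# For illustration only (a hedged sketch; the name and constructor arguments
# below are made up, not definitions from this module): a strategy is typically
# registered for tests as
#
#   mirrored_two_gpus = combinations.NamedDistribution(
#       "MirroredTwoGPUs",
#       lambda: tf.distribute.MirroredStrategy(["/gpu:0", "/gpu:1"]),
#       required_gpus=2)
#
# The callable is only invoked through the `strategy` property, once per test
# case, so no strategy object is created at module import time.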


# This is to allow adding combinations that run a function both as a
# tf.function and eagerly.
#
# @combinations.generate(
#     combinations.combine(
#         tf_function = [combinations.tf_function, combinations.no_tf_function]
#     )
# )
# def testXXX(tf_function):
#   @tf_function
#   def foo():
#     tf.add(1., 1.)
#
#   foo()
tf_function = combinations_lib.NamedObject("TfFunction", def_function.function)
no_tf_function = combinations_lib.NamedObject("NoTfFunction", lambda f: f)


def concat(*combined):
  """Concatenates combinations."""
  result = []
  for one in combined:
    result += one
  return result


@tf_export("__internal__.distribute.combinations.generate", v1=[])
def generate(combinations, test_combinations=()):
  # pylint: disable=g-doc-args,g-doc-return-or-yield
  """Distributed adapter of `tf.__internal__.test.combinations.generate`.

  All tests with distributed strategy should use this one instead of
  `tf.__internal__.test.combinations.generate`. In addition, this function
  supports strategy combinations, GPU/TPU requirements and multi worker tests.

  See `tf.__internal__.test.combinations.generate` for usage.
  """
  # pylint: enable=g-doc-args,g-doc-return-or-yield
  default_combinations = (
      framework_combinations.EagerGraphCombination(),
      framework_combinations.TFVersionCombination(),
      ClusterCombination(),
      DistributionCombination(),
      GPUCombination(),
      TPUCombination(),
  )
  # We apply our own decoration to handle multi worker tests before applying
  # framework.test_combinations.generate. The order is important since we need
  # framework.test_combinations.generate to apply all parameter modifiers
  # first.
  combination_decorator = combinations_lib.generate(
      combinations, test_combinations=default_combinations + test_combinations)

  def decorator(test_method_or_class):
    if isinstance(test_method_or_class, type):
      # If it's a test class.
      class_object = test_method_or_class
      # Decorate each test method with _multi_worker_test.
      for name, test_method in six.iteritems(class_object.__dict__.copy()):
        if (name.startswith(unittest.TestLoader.testMethodPrefix) and
            isinstance(test_method, types.FunctionType)):
          setattr(class_object, name, _multi_worker_test(test_method))
      return combination_decorator(class_object)
    else:
      return combination_decorator(_multi_worker_test(test_method_or_class))

  return decorator


combine = combinations_lib.combine
times = combinations_lib.times
NamedObject = combinations_lib.NamedObject


# Identifies whether we're in the main process or worker processes.
# `_multi_worker_test` decoration behaves differently in the main process and
# the worker processes. See the documentation of _multi_worker_test for detail.
_running_in_worker = False


def in_main_process():
  """Whether it's in the main test process.

  This is normally used to prepare the test environment, which should only
  happen in the main process.

  Returns:
    A boolean.
  """
  return not _running_in_worker
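

# For illustration only (a hedged sketch; the dispatcher address is made up):
# per-run environment setup should be guarded so that it happens exactly once,
# in the main process, and recorded on `env()` (defined below) so that worker
# subprocesses see the same values, e.g.
#
#   if combinations.in_main_process():
#     combinations.env().tf_data_service_dispatcher = "grpc://localhost:5050"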


class TestEnvironment(object):

  def __init__(self):
    self.tf_data_service_dispatcher = None
    # Note that this includes GPUs that may not be visible to the current
    # worker.
    self.total_phsyical_gpus = None

  def __setattr__(self, name, value):
    if not in_main_process():
      raise ValueError(
          "combinations.env() should only be modified in the main process. "
          "Condition your code on combinations.in_main_process().")
    super().__setattr__(name, value)


_env = TestEnvironment()


def env():
  """Returns the object that holds the test environment information.

  Tests should modify this in the main process if needed, and it will be passed
  to the worker processes each time a test case is run.

  Returns:
    A TestEnvironment object.
  """
  return _env


def _set_total_phsyical_gpus():
  if in_main_process():
    env().total_phsyical_gpus = len(
        context.context().list_physical_devices("GPU"))


# This is needed in case CUDA is lazily loaded.
app.call_after_init(_set_total_phsyical_gpus)


_TestResult = collections.namedtuple("_TestResult", ["status", "message"])


def _test_runner(test_id, test_env):
  """Executes the test with the given test_id.

  This is a simple wrapper around TestRunner to be used with
  multi_process_runner. Similar to test.main(), but it executes only one test
  specified by test_id and reports whether the test succeeds. If the test
  fails, the function prints failures and errors to stdout.

  Args:
    test_id: TestCase.id()
    test_env: a TestEnvironment object.

  Returns:
    A `_TestResult` whose status is "ok", "skipped" or "failure", with the
    corresponding message if the test did not pass.
  """
  global _running_in_worker, _env
  # No need to restore the value of _running_in_worker since it should always
  # be True in worker processes.
  _running_in_worker = True
  _env = test_env
  test = unittest.defaultTestLoader.loadTestsFromName(test_id)
  runner = unittest.TextTestRunner()
  result = runner.run(test)
  # Treat expected failures as failures, so that the main process can get
  # them and fail as expected. Also treat errors as failures to simplify the
  # handling.
  failures = result.failures + result.expectedFailures + result.errors
  if failures:
    ret = _TestResult(status="failure", message=failures[0][1])
  elif result.skipped:
    ret = _TestResult(status="skipped", message=result.skipped[0][1])
  else:
    # Treat unexpectedSuccesses as OK so that the test case in the main process
    # succeeds as well.
    ret = _TestResult(status="ok", message=None)
  # Print tracebacks to stdout and multi_process_runner will collect
  # them and stream back to the main process.
  if ret.message:
    print(ret.message)
  return ret


def _multi_worker_test(test_method):
  """Decorates test_method so that it runs in each worker.

  We use `multi_process_runner` to simulate multiple workers. Since we run this
  function in the main process and in all worker processes, this decoration
  behaves differently in the main process and the worker processes. In the main
  process, it spawns subprocesses and runs the test on each of them; in a
  worker process, it executes the test in the same way as a normal test, e.g.
  setUp()/tearDown() are called before/after the test.

  Args:
    test_method: a function which must be a test method.

  Returns:
    Decorated `test_method`. Note that the decorated function has additional
    arguments.
  """

  def decorator(self, has_chief, num_workers, num_ps, share_gpu, runner,
                **kwargs):
    if _num_total_workers(has_chief,
                          num_workers) == 1 or _running_in_worker or (
                              # Use in-process cluster for PS combinations
                              # when XLA is enabled.
                              test_util.is_xla_enabled() and num_ps > 0):
      # We're in a worker process or the test is for a single worker. In either
      # case we execute the test method directly instead of spawning
      # subprocesses.

      # For MultiWorkerMirroredStrategy (CollectiveAllReduceStrategy), install
      # a session that connects to the local server. This is necessary for
      # multi worker graph mode tests to work. Those tests cannot use their
      # graphs or sessions, including the one returned by
      # self.cached_session(). Since existing tests may already be doing so, we
      # only install the session for multi worker tests.
      with _multi_worker_session(kwargs):
        test_method(self, **kwargs)
      return

    # We're in the main process. We spawn subprocesses and run the *test* on
    # each of them. Note that we're not directly executing test_method passed
    # to _multi_worker_test, because we need setUp()/tearDown() to be called
    # and all the decorations on the test method. The conceptual call stack is:
    #   [main process]test.main()
    #     [main process]test_runner.run(test)
    #       [main process]wrapper by combinations.generate()
    #         [main process]_multi_worker_test.decorator()
    #           # A sub process goes through the same code path as the main
    #           # process.
    #           [sub process]_test_runner()
    #             [sub process]test_runner.run(test)
    #               [sub process]wrapper by combinations.generate()
    #                 [sub process]_multi_worker_test.decorator()
    #                   # _running_in_worker is True
    #                   [sub process]test_method()
    test_id = self.id()
    if runner:
      results = runner.run(_test_runner, args=(test_id, _env))
    else:
      cluster_spec = multi_worker_test_base.create_cluster_spec(
          has_chief=has_chief,
          num_workers=num_workers,
          num_ps=num_ps,
          has_eval=False)
      ephemeral_runner = multi_process_runner.MultiProcessRunner(
          _test_runner,
          cluster_spec,
          share_gpu=share_gpu,
          args=(test_id, _env),
          dependence_on_chief=has_chief)
      ephemeral_runner.start()
      results = ephemeral_runner.join().return_value

    skip_reason = None
    for result in results:
      if result.status == "failure":
        # We can't tell which worker the return value came from, so we fail on
        # the first error.
        self.fail(result.message)
        break
      elif result.status == "skipped":
        # Record the skip reason, but do not actually skip the test in case
        # some processes fail instead.
        skip_reason = result.message
    if skip_reason is not None:
      self.skipTest(skip_reason)

  argspec = tf_inspect.getfullargspec(test_method)
  decorator_args = (argspec.args or []) + [
      "has_chief", "num_workers", "num_ps", "share_gpu", "runner"
  ]
  decorator_argspec = argspec._replace(args=decorator_args)
  return tf_decorator.make_decorator(
      test_method, decorator, decorator_argspec=decorator_argspec)


def _num_total_workers(has_chief, num_workers):
  """Returns the number of workers, including the chief."""
  if has_chief:
    return num_workers + 1
  return num_workers
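

# For illustration only (a hedged sketch; the helper name is made up and the
# MultiProcessPoolRunner arguments are assumptions, not verified against its
# API): a `NamedDistribution` can reuse worker processes across test cases by
# supplying `pool_runner_fn`, e.g.
#
#   _pool = None
#
#   def _get_pool_runner():
#     global _pool
#     if _pool is None:
#       _pool = multi_process_runner.MultiProcessPoolRunner(
#           multi_worker_test_base.create_cluster_spec(num_workers=2))
#     return _pool
#
# In that case the decorator above calls `runner.run(_test_runner, ...)`
# instead of creating an ephemeral `MultiProcessRunner` for each test case.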


def _multi_worker_session(kwargs):
  """Returns a context manager that enters a session configured for MultiWorkerMirroredStrategy.

  Args:
    kwargs: a dict. Keyword arguments passed to the test.

  Returns:
    A context manager. If MultiWorkerMirroredStrategy is the one and only
    strategy in kwargs and it's in graph mode, it's the session that is
    configured for that strategy. Otherwise, it's a no-op context manager.
  """
  strategy = None
  for _, v in kwargs.items():
    if isinstance(v, distribute_lib.StrategyBase):
      if strategy is not None:
        logging.warning(
            "The test uses multiple strategies. Skipping "
            "entering a session that is configured for the strategy.")
        return ops.NullContextmanager()
      strategy = v
  if context.executing_eagerly() or not isinstance(
      strategy, collective_all_reduce_strategy.CollectiveAllReduceStrategy):
    return ops.NullContextmanager()
  sess_config = copy.deepcopy(context.context().config)
  sess_config = strategy.update_config_proto(sess_config)
  target = strategy.cluster_resolver.master()
  return session.Session(config=sess_config, target=target).as_default()