# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This module customizes `test_combinations` for `tf.distribute.Strategy`.

Additionally it provides `generate()`, `combine()` and `times()` with
`tf.distribute.Strategy` customizations as a default.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import re
import sys
import types
import unittest

from absl import app
import six


from tensorflow.python.client import session
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import multi_process_runner
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import combinations as framework_combinations
from tensorflow.python.framework import config
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_combinations as combinations_lib
from tensorflow.python.framework import test_util
from tensorflow.python.platform import flags
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export


# TODO(rchao): Rename `distribution` parameter to `strategy` or
# `distribute_strategy` in all tests.
class DistributionParameter(combinations_lib.ParameterModifier):
  """Transforms arguments of type `NamedDistribution`.

  Converts all arguments of type `NamedDistribution` to the value of their
  `strategy` property.
  """

  def modified_arguments(self, kwargs, requested_parameters):
    # Get the parameter that indicates if we need to set the `_use_var_policy`
    # flag on the strategy object. This is a temporary flag for testing the
    # variable policy rollout.
    use_var_policy = kwargs.get("use_var_policy", None)
    distribution_arguments = {}
    for k, v in kwargs.items():
      if isinstance(v, NamedDistribution):
        strategy = v.strategy
        if use_var_policy:
          strategy.extended._use_var_policy = use_var_policy
        distribution_arguments[k] = strategy
    return distribution_arguments

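# Example (illustrative sketch, not part of the original file): with
# `DistributionParameter` installed by `generate()` below, a test that is
# parameterized with a `NamedDistribution` receives the actual
# `tf.distribute.Strategy` instance:
#
#   @combinations.generate(
#       combinations.combine(distribution=[some_named_distribution]))
#   def testFoo(self, distribution):
#     # `distribution` is the strategy built by the NamedDistribution's
#     # `distribution_fn`, not the NamedDistribution wrapper itself.
#     ...
#
# `some_named_distribution` is a placeholder for a NamedDistribution defined by
# the test or taken from strategy_combinations.py.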

class ClusterParameters(combinations_lib.ParameterModifier):
  """Adds cluster parameters if a `NamedDistribution` has them.

  It needs to run before `DistributionParameter`.
  """

  def modified_arguments(self, kwargs, requested_parameters):
    strategy = None
    for _, v in kwargs.items():
      if isinstance(v, NamedDistribution):
        if strategy is not None and _num_total_workers(v.has_chief,
                                                       v.num_workers) > 1:
          raise ValueError("Only one NamedDistribution is supported for "
                           "multi-worker tests.")
        strategy = v

    if strategy:
      has_chief = strategy.has_chief
      num_workers = strategy.num_workers
      runner = strategy.runner
      share_gpu = strategy.share_gpu
      num_ps = strategy.num_ps
      if "has_chief" in kwargs and kwargs["has_chief"] != has_chief:
        raise ValueError(
            "both has_chief and strategy specified but are not compatible")
      if "num_workers" in kwargs and kwargs["num_workers"] != num_workers:
        raise ValueError(
            "both num_workers and strategy specified but are not compatible")
    else:
      has_chief = kwargs.get("has_chief", False)
      num_workers = kwargs.get("num_workers", 1)
      runner = kwargs.get("runner", None)
      share_gpu = kwargs.get("share_gpu", True)
      num_ps = kwargs.get("num_ps", 0)

    # Always set cluster parameters if they're requested, so that generate()
    # works when there's no strategy in the combinations.
    update = {}
    if "has_chief" in requested_parameters:
      update["has_chief"] = has_chief
    if "num_workers" in requested_parameters:
      update["num_workers"] = num_workers
    if "runner" in requested_parameters:
      update["runner"] = runner
    if "share_gpu" in requested_parameters:
      update["share_gpu"] = share_gpu
    if "num_ps" in requested_parameters:
      update["num_ps"] = num_ps
    return update


class DistributionCombination(combinations_lib.TestCombination):
  """Sets up distribution strategy for tests."""

  def should_execute_combination(self, kwargs):
    distributions = [
        v for v in kwargs.values() if isinstance(v, NamedDistribution)
    ]
    if test_util.is_xla_enabled() and any(d.no_xla for d in distributions):
      return (
          False,
          "n/a: skipping strategy combination with no_xla=True in XLA tests")
    return (True, None)

  def parameter_modifiers(self):
    return [
        DistributionParameter(),
        combinations_lib.OptionalParameter("use_var_policy"),
    ]


class ClusterCombination(combinations_lib.TestCombination):
  """Sets up multi-worker tests."""

  def parameter_modifiers(self):
    return [ClusterParameters()]


class GPUCombination(combinations_lib.TestCombination):
  """Enables tests to request GPU hardware and skips non-GPU combinations.

  This class expects test_combinations to be generated with `NamedDistribution`
  wrapping instances of `tf.distribute.Strategy`.

  Optionally, the `required_gpus` argument is supported. GPU hardware is
  required if its value is `True` or > 0. Similarly, `required_physical_gpus`
  requires that number of physical (as opposed to logical) GPUs.

  Attributes:
    GPU_TEST: The environment is considered to have GPU hardware available if
              the name of the program ends with "test_gpu", "test_2gpu",
              "test_xla_gpu" or "test_xla_2gpu".
  """

  GPU_TEST = re.search(r"(test_2?gpu|test_xla_2?gpu)$", sys.argv[0])

  def should_execute_combination(self, kwargs):
    distributions = [
        v for v in kwargs.values() if isinstance(v, NamedDistribution)
    ]
    required_gpus = kwargs.get("required_gpus", 0)
    required_physical_gpus = kwargs.get("required_physical_gpus", 0)

    if distributions and required_gpus:
      raise ValueError("Do not use `required_gpus` and arguments of type "
                       "NamedDistribution together.")

    number_of_required_gpus = max(
        [required_gpus] + [required_physical_gpus] +
        [d.required_physical_gpus or 0 for d in distributions] +
        [d.required_gpus or 0 for d in distributions])
    number_of_required_physical_gpus = max(
        [required_physical_gpus] +
        [d.required_physical_gpus or 0 for d in distributions])

    if required_physical_gpus and required_gpus:
      raise ValueError("Only one of `required_physical_gpus` (number of "
                       "physical GPUs required) and `required_gpus` (total "
                       "number of GPUs required) should be set.")
    if not number_of_required_gpus and GPUCombination.GPU_TEST:
      return (False, "Test that doesn't require GPUs.")
    elif (number_of_required_gpus > 0
          and context.num_gpus() < number_of_required_gpus):
      return (False, ("Only {} of {} required GPUs are available.".format(
          context.num_gpus(), number_of_required_gpus)))
    elif number_of_required_physical_gpus > len(
        config.list_physical_devices("GPU")):
      return (False,
              ("Only {} of {} required physical GPUs are available.".format(
                  len(config.list_physical_devices("GPU")),
                  number_of_required_physical_gpus)))
    else:
      return (True, None)

  def parameter_modifiers(self):
    return [combinations_lib.OptionalParameter("required_gpus"),
            combinations_lib.OptionalParameter("required_physical_gpus")]

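# Example (illustrative sketch, not part of the original file): a test can
# request GPUs directly through the combination, without any strategy:
#
#   @combinations.generate(combinations.combine(required_gpus=2))
#   def testNeedsTwoGpus(self):
#     ...
#
# The combination is reported as skipped, rather than failed, when fewer than
# two GPUs are available.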

class TPUCombination(combinations_lib.TestCombination):
  """Enables tests to request TPU hardware and skips non-TPU combinations.

  This class expects test_combinations to be generated with `NamedDistribution`
  wrapping instances of `tf.distribute.Strategy`.

  Optionally, the `required_tpus` parameter is supported. TPU hardware is
  required if its value is `True` or > 0.

  Optionally, the `use_cloud_tpu` parameter is supported. If TPU hardware is
  required by `required_tpus`, it specifically must be a Cloud TPU (specified
  with `--tpu`) if `use_cloud_tpu` is `True`.

  Attributes:
    TPU_TEST: The environment is considered to have TPU hardware available if
              the name of the program contains "test_tpu".
  """

  TPU_TEST = "test_tpu" in sys.argv[0]

  def should_execute_combination(self, kwargs):
    distributions = [
        v for v in kwargs.values() if isinstance(v, NamedDistribution)
    ]
    # TODO(isaprykin): Migrate all tests away from using 'required_tpu' in favor
    # of 'required_tpus'.
    if "required_tpus" in kwargs and "required_tpu" in kwargs:
      raise ValueError("Do not use `required_tpu`. Both `required_tpus` and "
                       "`required_tpu` were specified.")
    required_tpus = kwargs.get("required_tpus", None) or kwargs.get(
        "required_tpu", None)

    if distributions and required_tpus:
      raise ValueError("Do not use `required_tpus` and arguments of type "
                       "NamedDistribution together.")

    # TODO(isaprykin): Add support for a particular number of TPUs.  Right now
    # it's binary.
    number_of_required_tpus = max([required_tpus or 0] +
                                  [d.required_tpu or 0 for d in distributions])
    use_cloud_tpu = any([kwargs.get("use_cloud_tpu")] +
                        [d.use_cloud_tpu for d in distributions])
    tpu = hasattr(flags.FLAGS, "tpu") and flags.FLAGS.tpu or ""

    if not number_of_required_tpus and TPUCombination.TPU_TEST:
      return (False, "Test that doesn't require TPUs.")
    if number_of_required_tpus and not TPUCombination.TPU_TEST:
      return (False, "Test requires a TPU, but it's not available.")
    if use_cloud_tpu and not tpu:
      return (False, "Test requires a Cloud TPU, but none specified.")
    if not use_cloud_tpu and tpu:
      return (False, "Test requires local TPU, but Cloud TPU specified.")
    return (True, None)

  def parameter_modifiers(self):
    return [
        combinations_lib.OptionalParameter("required_tpus"),
        combinations_lib.OptionalParameter("required_tpu"),
        combinations_lib.OptionalParameter("use_cloud_tpu"),
    ]

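# Example (illustrative sketch, not part of the original file): requesting a
# Cloud TPU for a combination:
#
#   @combinations.generate(
#       combinations.combine(required_tpus=1, use_cloud_tpu=True))
#   def testOnCloudTpu(self):
#     ...
#
# This combination only runs in TPU test binaries that were launched with the
# `--tpu` flag.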

class NamedDistribution(object):
  """Wraps a `tf.distribute.Strategy` and adds a name for test titles."""

  def __init__(self,
               name,
               distribution_fn,
               required_gpus=None,
               required_physical_gpus=0,
               required_tpu=False,
               use_cloud_tpu=False,
               has_chief=False,
               num_workers=1,
               num_ps=0,
               share_gpu=True,
               pool_runner_fn=None,
               no_xla=False):
    """Initializes a NamedDistribution.

    Args:
      name: Name that will be a part of the name of the test case.
      distribution_fn: A callable that creates a `tf.distribute.Strategy`.
      required_gpus: The number of GPUs that the strategy requires. Only one of
        `required_gpus` and `required_physical_gpus` should be set.
      required_physical_gpus: Number of physical GPUs required. Only one of
        `required_gpus` and `required_physical_gpus` should be set.
      required_tpu: Whether the strategy requires TPU.
      use_cloud_tpu: Whether the strategy requires a Cloud TPU.
      has_chief: Whether the strategy requires a chief worker.
      num_workers: The number of workers that the strategy requires.
      num_ps: The number of parameter servers.
      share_gpu: Whether to share GPUs among workers.
      pool_runner_fn: An optional callable that returns a
        MultiProcessPoolRunner to run the test.
      no_xla: Whether to skip in XLA tests.
    """
    object.__init__(self)
    self._name = name
    self._distribution_fn = distribution_fn
    self.required_gpus = required_gpus
    self.required_physical_gpus = required_physical_gpus
    self.required_tpu = required_tpu
    self.use_cloud_tpu = use_cloud_tpu
    self.has_chief = has_chief
    self.num_workers = num_workers
    self.num_ps = num_ps
    self.share_gpu = share_gpu
    self._pool_runner_fn = pool_runner_fn
    self.no_xla = no_xla

  @property
  def runner(self):
    if self._pool_runner_fn is not None:
      return self._pool_runner_fn()
    return None

  @property
  def strategy(self):
    return self._distribution_fn()

  def __repr__(self):
    return self._name

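# Example (illustrative sketch, not part of the original file): defining a
# NamedDistribution. The lambda defers strategy construction until the test
# actually runs, so merely declaring combinations does not touch any devices.
#
#   mirrored_two_gpus = NamedDistribution(
#       "MirroredTwoGPUs",
#       lambda: tf.distribute.MirroredStrategy(["/gpu:0", "/gpu:1"]),
#       required_gpus=2)
#
# The name `mirrored_two_gpus` and the device list are hypothetical; the
# canonical instances live in strategy_combinations.py.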

# This is to allow adding combinations that run a function both as a
# tf.function and eagerly.
#
# @combinations.generate(
#   combinations.combine(
#     tf_function = [combinations.tf_function, combinations.no_tf_function]
#   )
# )
# def testXXX(tf_function):
#   @tf_function
#   def foo():
#     tf.add(1., 1.)
#
#   foo()
tf_function = combinations_lib.NamedObject("TfFunction", def_function.function)
no_tf_function = combinations_lib.NamedObject("NoTfFunction", lambda f: f)


def concat(*combined):
  """Concatenates combinations."""
  result = []
  for one in combined:
    result += one
  return result


@tf_export("__internal__.distribute.combinations.generate", v1=[])
def generate(combinations, test_combinations=()):
  # pylint: disable=g-doc-args,g-doc-return-or-yield
  """Distributed adapter of `tf.__internal__.test.combinations.generate`.

  All tests with distributed strategies should use this one instead of
  `tf.__internal__.test.combinations.generate`. This function adds support for
  strategy combinations, GPU/TPU requirements and multi-worker tests.

  See `tf.__internal__.test.combinations.generate` for usage.
  """
  # pylint: enable=g-doc-args,g-doc-return-or-yield
  default_combinations = (
      framework_combinations.EagerGraphCombination(),
      framework_combinations.TFVersionCombination(),
      ClusterCombination(),
      DistributionCombination(),
      GPUCombination(),
      TPUCombination(),
  )
  # We apply our own decoration to handle multi-worker tests before applying
  # framework.test_combinations.generate. The order is important since we need
  # framework.test_combinations.generate to apply all parameter modifiers first.
  combination_decorator = combinations_lib.generate(
      combinations, test_combinations=default_combinations + test_combinations)

  def decorator(test_method_or_class):
    if isinstance(test_method_or_class, type):
      # If it's a test class.
      class_object = test_method_or_class
      # Decorate each test method with _multi_worker_test.
      for name, test_method in six.iteritems(class_object.__dict__.copy()):
        if (name.startswith(unittest.TestLoader.testMethodPrefix) and
            isinstance(test_method, types.FunctionType)):
          setattr(class_object, name, _multi_worker_test(test_method))
      return combination_decorator(class_object)
    else:
      return combination_decorator(_multi_worker_test(test_method_or_class))

  return decorator

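# Example (illustrative sketch, not part of the original file): requesting a
# multi-worker cluster without any strategy. ClusterParameters fills in the
# cluster arguments and _multi_worker_test spawns the worker subprocesses.
#
#   @combinations.generate(
#       combinations.combine(mode=["eager"], has_chief=True, num_workers=2))
#   def testWithCluster(self):
#     # Runs once in the chief process and once in each of the two workers.
#     ...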

combine = combinations_lib.combine
times = combinations_lib.times
NamedObject = combinations_lib.NamedObject


# Identifies whether we're in the main process or worker processes.
# `_multi_worker_test` decoration behaves differently in the main process and
# the worker processes. See the documentation of _multi_worker_test for
# details.
_running_in_worker = False


def in_main_process():
  """Whether it's in the main test process.

  This is normally used to prepare the test environment, which should only
  happen in the main process.

  Returns:
    A boolean.
  """
  return not _running_in_worker


class TestEnvironment(object):
  """Holds the test environment information that is passed to the workers."""

  def __init__(self):
    self.tf_data_service_dispatcher = None
    # Note that this includes GPUs that may not be visible to the current
    # worker.
    self.total_phsyical_gpus = None

  def __setattr__(self, name, value):
    if not in_main_process():
      raise ValueError(
          "combinations.env() should only be modified in the main process. "
          "Condition your code on combinations.in_main_process().")
    super().__setattr__(name, value)


_env = TestEnvironment()


def env():
  """Returns the object that holds the test environment information.

  Tests should modify this in the main process if needed, and it will be
  passed to the worker processes each time a test case is run.

  Returns:
    A TestEnvironment object.
  """
  return _env

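# Example (hypothetical sketch, not part of the original file): shared state is
# prepared once in the main process; a copy of `env()` is then shipped to every
# worker process for each test case.
#
#   if combinations.in_main_process():
#     combinations.env().tf_data_service_dispatcher = _start_dispatcher()
#
# `_start_dispatcher` stands in for whatever helper the test suite uses to
# bring up a tf.data service dispatcher.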

def _set_total_phsyical_gpus():
  if in_main_process():
    env().total_phsyical_gpus = len(
        context.context().list_physical_devices("GPU"))


# This is needed in case CUDA is lazily loaded.
app.call_after_init(_set_total_phsyical_gpus)


_TestResult = collections.namedtuple("_TestResult", ["status", "message"])


def _test_runner(test_id, test_env):
  """Executes the test with the given test_id.

  This is a simple wrapper around `unittest.TextTestRunner` to be used with
  multi_process_runner. Similar to test.main(), but it executes only one test
  specified by test_id and returns the result. If the test fails, the function
  prints failures and errors to stdout.

  Args:
    test_id: TestCase.id()
    test_env: a TestEnvironment object.

  Returns:
    A _TestResult with the status of the test and, if it failed or was
    skipped, the associated message.
  """
  global _running_in_worker, _env
  # No need to restore the value of _running_in_worker since it should always
  # be True in worker processes.
  _running_in_worker = True
  _env = test_env
  test = unittest.defaultTestLoader.loadTestsFromName(test_id)
  runner = unittest.TextTestRunner()
  result = runner.run(test)
  # Treat expected failures as failures, so that the main process can get
  # them and fail as expected. Also treat errors as failures to simplify the
  # handling.
  failures = result.failures + result.expectedFailures + result.errors
  if failures:
    ret = _TestResult(status="failure", message=failures[0][1])
  elif result.skipped:
    ret = _TestResult(status="skipped", message=result.skipped[0][1])
  else:
    # Treat unexpectedSuccesses as OK so that the test case in the main process
    # succeeds as well.
    ret = _TestResult(status="ok", message=None)
  # Print tracebacks to stdout so that multi_process_runner can collect them
  # and stream them back to the main process.
  if ret.message:
    print(ret.message)
  return ret


def _multi_worker_test(test_method):
  """Decorates test_method so that it runs in each worker.

  We use `multi_process_runner` to simulate multiple workers. Since we run
  this function in the main process and all worker processes, this decoration
  behaves differently in the main process and the worker processes. In the
  main process, it spawns subprocesses and runs the test on each of them; in a
  worker process, it executes the test in the same way as a normal test, e.g.
  setUp()/tearDown() are called before/after the test.

  Args:
    test_method: a function which must be a test method.

  Returns:
    Decorated `test_method`. Note that the decorated function has additional
    arguments.
  """

  def decorator(self, has_chief, num_workers, num_ps, share_gpu, runner,
                **kwargs):
    if _num_total_workers(has_chief,
                          num_workers) == 1 or _running_in_worker or (
                              # Use in-process cluster for PS combinations
                              # when XLA is enabled.
                              test_util.is_xla_enabled() and num_ps > 0):
      # We're in a worker process or the test is for a single worker. In either
      # case we execute the test method directly instead of spawning
      # subprocesses.

      # For MultiWorkerMirroredStrategy(CollectiveAllReduceStrategy), install a
      # session that connects to the local server. This is necessary for multi
      # worker graph mode tests to work. Those tests cannot use their graphs or
      # sessions, including the one returned by self.cached_session(). Since
      # existing tests may already be doing so, we only install the session for
      # multi worker tests.
      with _multi_worker_session(kwargs):
        test_method(self, **kwargs)
      return

    # We're in the main process. We spawn subprocesses and run the *test* on
    # each of them. Note that we're not directly executing test_method passed
    # to _multi_worker_test, because we need setUp()/tearDown() to be called
    # and all the decorations on the test method. The conceptual call stack is:
    #   [main process]test.main()
    #     [main process]test_runner.run(test)
    #       [main process]wrapper by combinations.generate()
    #         [main process]_multi_worker_test.decorator()
    #           # A sub process goes through the same code path as the main
    #           # process.
    #           [sub process]_test_runner()
    #             [sub process]test_runner.run(test)
    #               [sub process]wrapper by combinations.generate()
    #                 [sub process]_multi_worker_test.decorator()
    #                   # _running_in_worker is True
    #                   [sub process]test_method()
    test_id = self.id()
    if runner:
      results = runner.run(_test_runner, args=(test_id, _env))
    else:
      cluster_spec = multi_worker_test_base.create_cluster_spec(
          has_chief=has_chief,
          num_workers=num_workers,
          num_ps=num_ps,
          has_eval=False)
      ephemeral_runner = multi_process_runner.MultiProcessRunner(
          _test_runner,
          cluster_spec,
          share_gpu=share_gpu,
          args=(test_id, _env),
          dependence_on_chief=has_chief)
      ephemeral_runner.start()
      results = ephemeral_runner.join().return_value

    skip_reason = None
    for result in results:
      if result.status == "failure":
        # We can't tell which worker the return value comes from, so we fail on
        # the first error.
        self.fail(result.message)
        break
      elif result.status == "skipped":
        # Record the skip reason, but do not actually skip the test in case
        # some processes fail instead.
        skip_reason = result.message
    if skip_reason is not None:
      self.skipTest(skip_reason)

  argspec = tf_inspect.getfullargspec(test_method)
  decorator_args = (argspec.args or []) + [
      "has_chief", "num_workers", "num_ps", "share_gpu", "runner"
  ]
  decorator_argspec = argspec._replace(args=decorator_args)
  return tf_decorator.make_decorator(
      test_method, decorator, decorator_argspec=decorator_argspec)


def _num_total_workers(has_chief, num_workers):
  """Returns the number of workers including the chief."""
  if has_chief:
    return num_workers + 1
  return num_workers


def _multi_worker_session(kwargs):
  """Returns a session context manager for MultiWorkerMirroredStrategy tests.

  Args:
    kwargs: a dict. Keyword arguments passed to the test.

  Returns:
    A context manager. If MultiWorkerMirroredStrategy is the one and only
    strategy in kwargs and the test is in graph mode, it is a session that is
    configured for that strategy. Otherwise, it is a no-op context manager.
  """
  strategy = None
  for _, v in kwargs.items():
    if isinstance(v, distribute_lib.StrategyBase):
      if strategy is not None:
        logging.warning(
            "The test uses multiple strategies. Skipping "
            "entering a session that is configured for the strategy.")
        return ops.NullContextmanager()
      strategy = v
  if context.executing_eagerly() or not isinstance(
      strategy, collective_all_reduce_strategy.CollectiveAllReduceStrategy):
    return ops.NullContextmanager()
  sess_config = copy.deepcopy(context.context().config)
  sess_config = strategy.update_config_proto(sess_config)
  target = strategy.cluster_resolver.master()
  return session.Session(config=sess_config, target=target).as_default()