• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""This module customizes `test_combinations` for `tf.keras` related tests."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22
23from tensorflow.python import tf2
24from tensorflow.python.framework import combinations
25from tensorflow.python.framework import test_combinations
26from tensorflow.python.keras import testing_utils
27
# Names of the Keras model types recognized by the `model_type` test
# parameter in this module.
KERAS_MODEL_TYPES = [
    'functional',
    'subclass',
    'sequential',
]
29
30
def keras_mode_combinations(mode=None, run_eagerly=None):
  """Returns the default test combinations for tf.keras tests.

  Note that if tf2 is enabled, 'graph' (v1 session) combinations are not
  generated by default.

  Args:
    mode: List of modes to run the tests. The valid options are 'graph' and
      'eager'. When not specified (None), defaults to ['eager'] if tf2 is
      enabled and to ['graph', 'eager'] otherwise. Only the entries 'graph'
      and 'eager' produce combinations, so an empty list yields an empty
      result.
    run_eagerly: List of `run_eagerly` values to combine with 'eager' mode.
      Defaults to [True, False] when not specified (None). For 'graph' mode
      the only `run_eagerly` value generated is False, regardless of this
      argument.

  Returns:
    A list containing all the combinations to be used to generate test cases,
    with the 'eager' combinations (if any) ahead of the 'graph' ones.
  """
  if mode is None:
    mode = ['eager'] if tf2.enabled() else ['graph', 'eager']
  if run_eagerly is None:
    run_eagerly = [True, False]

  result = []
  if 'eager' in mode:
    result += combinations.combine(mode=['eager'], run_eagerly=run_eagerly)
  if 'graph' in mode:
    # In 'graph' mode `run_eagerly` is always pinned to False.
    result += combinations.combine(mode=['graph'], run_eagerly=[False])
  return result
58
59
def keras_model_type_combinations():
  """Returns the test combinations varying `model_type`.

  Returns:
    The result of `combinations.combine` called with `model_type` set to
    `KERAS_MODEL_TYPES`.
  """
  model_types = KERAS_MODEL_TYPES
  return combinations.combine(model_type=model_types)
62
63
class KerasModeCombination(test_combinations.TestCombination):
  """Test combination that handles the optional `run_eagerly` parameter.

  When a combination carries a non-None `run_eagerly` value, this class
  supplies a `testing_utils.run_eagerly_scope` context manager built from it.
  """

  def parameter_modifiers(self):
    """Declares `run_eagerly` as an optional test parameter."""
    return [test_combinations.OptionalParameter('run_eagerly')]

  def context_managers(self, kwargs):
    """Returns the context managers for the given combination parameters.

    Args:
      kwargs: Dict of combination parameters. The `run_eagerly` entry, if
        present, is removed from it.

    Returns:
      A one-element list holding a `run_eagerly_scope` when `run_eagerly` was
      given with a non-None value, otherwise an empty list.
    """
    eager_setting = kwargs.pop('run_eagerly', None)
    if eager_setting is None:
      return []
    return [testing_utils.run_eagerly_scope(eager_setting)]
80
81
class KerasModelTypeCombination(test_combinations.TestCombination):
  """Test combination that handles the optional `model_type` parameter.

  The recognized model types are the entries of `KERAS_MODEL_TYPES`
  ('functional', 'subclass', 'sequential').

  Various methods in `testing_utils` to get models will auto-generate a model
  of the currently active Keras model type. This allows unittests to confirm
  the equivalence between different Keras models.
  """

  def parameter_modifiers(self):
    """Declares `model_type` as an optional test parameter."""
    return [test_combinations.OptionalParameter('model_type')]

  def context_managers(self, kwargs):
    """Returns the context managers for the given combination parameters.

    Args:
      kwargs: Dict of combination parameters. The `model_type` entry, if
        present, is removed from it.

    Returns:
      A one-element list holding a `model_type_scope` when `model_type` is one
      of `KERAS_MODEL_TYPES`, otherwise an empty list.
    """
    requested_type = kwargs.pop('model_type', None)
    if requested_type not in KERAS_MODEL_TYPES:
      return []
    return [testing_utils.model_type_scope(requested_type)]
101
102
# Test combinations already bound to `combinations.generate` (read from its
# `keywords`, as exposed by a `functools.partial`).
_defaults = combinations.generate.keywords['test_combinations']

# Keras flavor of `generate`: the defaults above plus the two Keras-specific
# combinations defined in this module.
generate = functools.partial(
    combinations.generate,
    test_combinations=_defaults + (
        KerasModeCombination(),
        KerasModelTypeCombination(),
    ))

# Re-exported from `test_combinations` so callers can use them from this
# module.
combine = test_combinations.combine
times = test_combinations.times
NamedObject = test_combinations.NamedObject
111