1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""This module customizes `test_combinations` for `tf.keras` related tests.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22 23from tensorflow.python import tf2 24from tensorflow.python.framework import combinations 25from tensorflow.python.framework import test_combinations 26from tensorflow.python.keras import testing_utils 27 28KERAS_MODEL_TYPES = ['functional', 'subclass', 'sequential'] 29 30 31def keras_mode_combinations(mode=None, run_eagerly=None): 32 """Returns the default test combinations for tf.keras tests. 33 34 Note that if tf2 is enabled, then v1 session test will be skipped. 35 36 Args: 37 mode: List of modes to run the tests. The valid options are 'graph' and 38 'eager'. Default to ['graph', 'eager'] if not specified. If a empty list 39 is provide, then the test will run under the context based on tf's 40 version, eg graph for v1 and eager for v2. 41 run_eagerly: List of `run_eagerly` value to be run with the tests. 42 Default to [True, False] if not specified. Note that for `graph` mode, 43 run_eagerly value will only be False. 44 45 Returns: 46 A list contains all the combinations to be used to generate test cases. 47 """ 48 if mode is None: 49 mode = ['eager'] if tf2.enabled() else ['graph', 'eager'] 50 if run_eagerly is None: 51 run_eagerly = [True, False] 52 result = [] 53 if 'eager' in mode: 54 result += combinations.combine(mode=['eager'], run_eagerly=run_eagerly) 55 if 'graph' in mode: 56 result += combinations.combine(mode=['graph'], run_eagerly=[False]) 57 return result 58 59 60def keras_model_type_combinations(): 61 return combinations.combine(model_type=KERAS_MODEL_TYPES) 62 63 64class KerasModeCombination(test_combinations.TestCombination): 65 """Combination for Keras test mode. 66 67 It by default includes v1_session, v2_eager and v2_tf_function. 68 """ 69 70 def context_managers(self, kwargs): 71 run_eagerly = kwargs.pop('run_eagerly', None) 72 73 if run_eagerly is not None: 74 return [testing_utils.run_eagerly_scope(run_eagerly)] 75 else: 76 return [] 77 78 def parameter_modifiers(self): 79 return [test_combinations.OptionalParameter('run_eagerly')] 80 81 82class KerasModelTypeCombination(test_combinations.TestCombination): 83 """Combination for Keras model types when doing model test. 84 85 It by default includes 'functional', 'subclass', 'sequential'. 86 87 Various methods in `testing_utils` to get models will auto-generate a model 88 of the currently active Keras model type. This allows unittests to confirm 89 the equivalence between different Keras models. 90 """ 91 92 def context_managers(self, kwargs): 93 model_type = kwargs.pop('model_type', None) 94 if model_type in KERAS_MODEL_TYPES: 95 return [testing_utils.model_type_scope(model_type)] 96 else: 97 return [] 98 99 def parameter_modifiers(self): 100 return [test_combinations.OptionalParameter('model_type')] 101 102 103_defaults = combinations.generate.keywords['test_combinations'] 104generate = functools.partial( 105 combinations.generate, 106 test_combinations=_defaults + 107 (KerasModeCombination(), KerasModelTypeCombination())) 108combine = test_combinations.combine 109times = test_combinations.times 110NamedObject = test_combinations.NamedObject 111