1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Utility functions shared between SavedModel saving/loading implementations.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import itertools 21import threading 22import types 23 24from tensorflow.python.eager import context 25from tensorflow.python.keras import backend as K 26from tensorflow.python.keras.engine import base_layer_utils 27from tensorflow.python.keras.utils import control_flow_util 28from tensorflow.python.keras.utils import tf_contextlib 29from tensorflow.python.keras.utils import tf_inspect 30from tensorflow.python.keras.utils.generic_utils import LazyLoader 31from tensorflow.python.util import tf_decorator 32 33 34# pylint:disable=g-inconsistent-quotes 35training_lib = LazyLoader( 36 "training_lib", globals(), 37 "tensorflow.python.keras.engine.training") 38# pylint:enable=g-inconsistent-quotes 39 40 41def use_wrapped_call(layer, call_fn, default_training_value=None, 42 return_method=False): 43 """Creates fn that adds the losses returned by call_fn & returns the outputs. 44 45 Args: 46 layer: A Keras layer object 47 call_fn: tf.function that takes layer inputs (and possibly a training arg), 48 and returns a tuple of (outputs, list of losses). 49 default_training_value: Default value of the training kwarg. If `None`, the 50 default is `K.learning_phase()`. 51 return_method: Whether to return a method bound to the layer. 52 53 Returns: 54 function that calls call_fn and returns the outputs. Losses returned by 55 call_fn are added to the layer losses. 56 """ 57 expects_training_arg = layer_uses_training_bool(layer) 58 if hasattr(call_fn, 'original_call'): # call_fn is a LayerCall object 59 original_call = call_fn.original_call 60 # In Python 3, callable objects are not compatible with inspect.getargspec 61 call_fn = call_fn.__call__ 62 else: 63 original_call = call_fn 64 fn, arg_spec = maybe_add_training_arg( 65 original_call, call_fn, expects_training_arg, default_training_value) 66 67 def return_outputs_and_add_losses(*args, **kwargs): 68 """Returns the outputs from the call_fn, and adds the losses.""" 69 inputs_arg_index = 1 if return_method else 0 70 inputs = args[inputs_arg_index] 71 args = args[inputs_arg_index + 1:] 72 outputs, losses = fn(inputs, *args, **kwargs) 73 layer.add_loss(losses, inputs=inputs) 74 75 # TODO(kathywu): This is a temporary hack. When a network of layers is 76 # revived from SavedModel, only the top-level layer will have losses. This 77 # causes issues in eager mode because the child layers may have graph losses 78 # (thus model.losses returns a mix of Eager and graph tensors). To fix this, 79 # whenever eager losses are added to one layer, add eager losses to all 80 # child layers. This causes `.losses` to only return eager losses. 81 # pylint: disable=protected-access 82 if context.executing_eagerly(): 83 for i in layer._flatten_layers(): 84 if i is not layer: 85 i._eager_losses = [base_layer_utils.REVIVED_LOSS_PLACEHOLDER] 86 # pylint: enable=protected-access 87 return outputs 88 89 decorated = tf_decorator.make_decorator( 90 target=call_fn, 91 decorator_func=return_outputs_and_add_losses, 92 decorator_argspec=arg_spec) 93 94 if return_method: 95 return types.MethodType(decorated, layer) 96 else: 97 return decorated 98 99 100def layer_uses_training_bool(layer): 101 """Returns whether this layer or any of its children uses the training arg.""" 102 if layer._expects_training_arg: # pylint: disable=protected-access 103 return True 104 visited = {layer} 105 to_visit = list_all_layers(layer) 106 while to_visit: 107 layer = to_visit.pop() 108 if layer in visited: 109 continue 110 if getattr(layer, '_expects_training_arg', True): 111 return True 112 visited.add(layer) 113 to_visit.extend(list_all_layers(layer)) 114 return False 115 116 117def list_all_layers(obj): 118 if isinstance(obj, training_lib.Model): 119 # Handle special case of Sequential, which doesn't return 120 # the `Input` layer. 121 return obj.layers 122 else: 123 return list(obj._flatten_layers(include_self=False, recursive=False)) # pylint: disable=protected-access 124 125 126def list_all_layers_and_sublayers(obj): 127 s = set([obj]) 128 s.update(itertools.chain.from_iterable( 129 list_all_layers_and_sublayers(layer) for layer in list_all_layers(obj))) 130 return s 131 132 133def maybe_add_training_arg( 134 original_call, wrapped_call, expects_training_arg, default_training_value): 135 """Decorate call and optionally adds training argument. 136 137 If a layer expects a training argument, this function ensures that 'training' 138 is present in the layer args or kwonly args, with the default training value. 139 140 Args: 141 original_call: Original call function. 142 wrapped_call: Wrapped call function. 143 expects_training_arg: Whether to include 'training' argument. 144 default_training_value: Default value of the training kwarg to include in 145 the arg spec. If `None`, the default is `K.learning_phase()`. 146 147 Returns: 148 Tuple of ( 149 function that calls `wrapped_call` and sets the training arg, 150 Argspec of returned function or `None` if the argspec is unchanged) 151 """ 152 if not expects_training_arg: 153 return wrapped_call, None 154 155 def wrap_with_training_arg(*args, **kwargs): 156 """Wrap the `wrapped_call` function, and set training argument.""" 157 training_arg_index = get_training_arg_index(original_call) 158 training = get_training_arg(training_arg_index, args, kwargs) 159 if training is None: 160 training = default_training_value or K.learning_phase() 161 162 args = list(args) 163 kwargs = kwargs.copy() 164 165 def replace_training_and_call(training): 166 set_training_arg(training, training_arg_index, args, kwargs) 167 return wrapped_call(*args, **kwargs) 168 169 return control_flow_util.smart_cond( 170 training, lambda: replace_training_and_call(True), 171 lambda: replace_training_and_call(False)) 172 173 # Create arg spec for decorated function. If 'training' is not defined in the 174 # args of the original arg spec, then add it to kwonlyargs. 175 arg_spec = tf_inspect.getfullargspec(original_call) 176 defaults = list(arg_spec.defaults) if arg_spec.defaults is not None else [] 177 178 kwonlyargs = arg_spec.kwonlyargs 179 kwonlydefaults = arg_spec.kwonlydefaults or {} 180 # Add training arg if it does not exist, or set the default training value. 181 if 'training' not in arg_spec.args: 182 kwonlyargs.append('training') 183 kwonlydefaults['training'] = default_training_value 184 else: 185 index = arg_spec.args.index('training') 186 training_default_index = len(arg_spec.args) - index 187 if (arg_spec.defaults and 188 len(arg_spec.defaults) >= training_default_index and 189 defaults[-training_default_index] is None): 190 defaults[-training_default_index] = default_training_value 191 192 decorator_argspec = tf_inspect.FullArgSpec( 193 args=arg_spec.args, 194 varargs=arg_spec.varargs, 195 varkw=arg_spec.varkw, 196 defaults=defaults, 197 kwonlyargs=kwonlyargs, 198 kwonlydefaults=kwonlydefaults, 199 annotations=arg_spec.annotations) 200 return wrap_with_training_arg, decorator_argspec 201 202 203def get_training_arg_index(call_fn): 204 """Returns the index of 'training' in the layer call function arguments. 205 206 Args: 207 call_fn: Call function. 208 209 Returns: 210 - n: index of 'training' in the call function arguments. 211 - -1: if 'training' is not found in the arguments, but layer.call accepts 212 variable keyword arguments 213 - None: if layer doesn't expect a training argument. 214 """ 215 arg_list = tf_inspect.getfullargspec(call_fn).args 216 if tf_inspect.ismethod(call_fn): 217 arg_list = arg_list[1:] 218 if 'training' in arg_list: 219 return arg_list.index('training') 220 else: 221 return -1 222 223 224def set_training_arg(training, index, args, kwargs): 225 if index is None: 226 pass 227 elif index >= 0 and len(args) > index: 228 args[index] = training 229 else: 230 kwargs['training'] = training 231 return args, kwargs 232 233 234def get_training_arg(index, args, kwargs): 235 if index is None: 236 return None 237 elif index >= 0 and len(args) > index: 238 return args[index] 239 else: 240 return kwargs.get('training', None) 241 242 243def remove_training_arg(index, args, kwargs): 244 if index is None: 245 pass 246 elif index >= 0 and len(args) > index: 247 args.pop(index) 248 else: 249 kwargs.pop('training', None) 250 251 252class SaveOptionsContext(threading.local): 253 254 def __init__(self): 255 super(SaveOptionsContext, self).__init__() 256 self.save_traces = True 257 258 259_save_options_context = SaveOptionsContext() 260 261 262@tf_contextlib.contextmanager 263def keras_option_scope(save_traces): 264 previous_value = _save_options_context.save_traces 265 try: 266 _save_options_context.save_traces = save_traces 267 yield 268 finally: 269 _save_options_context.save_traces = previous_value 270 271 272def should_save_traces(): 273 return _save_options_context.save_traces 274