1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Keras SavedModel serialization. 16 17TODO (kathywu): Move to layer_serialization.py. Some model-specific logic should 18go to model_serialization.py. 19""" 20from __future__ import absolute_import 21from __future__ import division 22from __future__ import print_function 23 24import functools 25import weakref 26 27from tensorflow.python.eager import def_function 28from tensorflow.python.framework import tensor_shape 29from tensorflow.python.framework import tensor_spec 30from tensorflow.python.keras import backend as K 31from tensorflow.python.keras.engine import base_layer_utils 32from tensorflow.python.keras.engine import input_spec 33from tensorflow.python.keras.mixed_precision import autocast_variable 34from tensorflow.python.keras.saving import saving_utils 35from tensorflow.python.keras.saving.saved_model import constants 36from tensorflow.python.keras.saving.saved_model import load as keras_load 37from tensorflow.python.keras.saving.saved_model import serialized_attributes 38from tensorflow.python.keras.saving.saved_model import utils 39from tensorflow.python.keras.utils import tf_inspect 40from tensorflow.python.keras.utils import version_utils 41from tensorflow.python.keras.utils.generic_utils import LazyLoader 42from tensorflow.python.platform import tf_logging as logging 43from tensorflow.python.training.tracking import base as trackable 44from tensorflow.python.training.tracking import data_structures 45from tensorflow.python.util import nest 46from tensorflow.python.util import tf_decorator 47 48# To avoid circular dependencies between keras/engine and keras/saving, 49# code in keras/saving must delay imports. 50 51# TODO(b/134426265): Switch back to single-quotes to match the rest of the file 52# once the issue with copybara is fixed. 53# pylint:disable=g-inconsistent-quotes 54base_layer = LazyLoader( 55 "base_layer", globals(), 56 "tensorflow.python.keras.engine.base_layer") 57metrics = LazyLoader("metrics", globals(), 58 "tensorflow.python.keras.metrics") 59input_layer = LazyLoader( 60 "input_layer", globals(), 61 "tensorflow.python.keras.engine.input_layer") 62training_lib = LazyLoader( 63 "training_lib", globals(), 64 "tensorflow.python.keras.engine.training") 65sequential_lib = LazyLoader( 66 "sequential_lib", globals(), 67 "tensorflow.python.keras.engine.sequential") 68# pylint:enable=g-inconsistent-quotes 69 70 71def should_skip_serialization(layer): 72 """Skip serializing extra objects and functions if layer inputs aren't set.""" 73 saved_model_input_spec_set = (isinstance(layer, training_lib.Model) and 74 layer._saved_model_inputs_spec is not None) # pylint: disable=protected-access 75 if not layer.built and not saved_model_input_spec_set: 76 logging.warning('Skipping full serialization of Keras layer {}, because ' 77 'it is not built.'.format(layer)) 78 return True 79 return False 80 81 82def wrap_layer_objects(layer, serialization_cache): 83 """Returns extra trackable objects to attach to the serialized layer. 84 85 Args: 86 layer: Keras Layer object. 87 serialization_cache: Dictionary shared between all objects during 88 serialization. 89 90 Returns: 91 A dictionary containing all checkpointable objects from a 92 SerializedAttributes object. See LayerAttributes and ModelAttributes for 93 entire list of objects 94 """ 95 # Wrap all regularization losses as tf.functions. 96 # First, generate list of all regularization losses in this layer and 97 # sublayers. 98 all_losses = layer._callable_losses[:] # pylint: disable=protected-access 99 for child_layer in utils.list_all_layers(layer): 100 all_losses.extend(child_layer._callable_losses) # pylint: disable=protected-access 101 # Next, wrap all loss functions as tf.functions. Use the serialization cache 102 # to store already-wrapped functions. 103 keras_loss_cache = serialization_cache.setdefault('keras_losses', {}) 104 wrapped_loss_functions = [] 105 for loss_fn in all_losses: 106 if loss_fn in keras_loss_cache: 107 wrapped_loss_functions.append(keras_loss_cache[loss_fn]) 108 else: 109 wrapped_loss = _wrap_unconditional_loss(loss_fn, len(keras_loss_cache)) 110 keras_loss_cache[loss_fn] = wrapped_loss 111 wrapped_loss_functions.append(wrapped_loss) 112 wrapped_layer_losses = [keras_loss_cache[fn] 113 for fn in layer._callable_losses[:]] # pylint: disable=protected-access 114 115 layer_metrics = data_structures._DictWrapper( # pylint: disable=protected-access 116 {m.name: m for m in layer._metrics}) # pylint: disable=protected-access 117 return dict( 118 variables=data_structures.ListWrapper(layer.variables), 119 trainable_variables=data_structures.ListWrapper( 120 layer.trainable_variables), 121 non_trainable_variables=data_structures.ListWrapper( 122 layer.non_trainable_variables), 123 layers=data_structures.ListWrapper(utils.list_all_layers(layer)), 124 metrics=data_structures.ListWrapper(layer.metrics), 125 regularization_losses=data_structures.ListWrapper( 126 wrapped_loss_functions), 127 layer_regularization_losses=data_structures.ListWrapper( 128 wrapped_layer_losses), 129 layer_metrics=layer_metrics) 130 # pylint: disable=protected-access 131 132 133def wrap_layer_functions(layer, serialization_cache): 134 """Returns dict of wrapped layer call function and losses in tf.functions. 135 136 Args: 137 layer: Keras Layer object. 138 serialization_cache: Dictionary shared between all objects during 139 serialization. 140 141 Returns: 142 A dictionary containing all keras tf.functions to serialize. See 143 LayerAttributes and ModelAttributes for the list of all attributes. 144 """ 145 # Since Sequential models may be modified in place using model.add() or 146 # model.pop(), don't use saved functions. 147 if (isinstance(layer, keras_load.RevivedLayer) and 148 not isinstance(layer, sequential_lib.Sequential)): 149 return {fn_name: getattr(layer.keras_api, fn_name, None) 150 for fn_name in serialized_attributes.LayerAttributes.all_functions} 151 152 # Reset the losses of the layer and its children. The call function in each 153 # child layer is replaced with tf.functions. 154 original_fns = _replace_child_layer_functions(layer, serialization_cache) 155 original_losses = _reset_layer_losses(layer) 156 157 # Wrap all the layer call and activity regularizer functions. 158 159 # Use LayerCallCollection to ensure that all layer call functions (__call__, 160 # call with losses) are traced with the same inputs. 161 call_collection = LayerCallCollection(layer) 162 call_fn_with_losses = call_collection.add_function( 163 _wrap_call_and_conditional_losses(layer), 164 '{}_layer_call_and_return_conditional_losses'.format(layer.name)) 165 call_fn = call_collection.add_function( 166 _extract_outputs_from_fn(layer, call_fn_with_losses), 167 '{}_layer_call_fn'.format(layer.name)) 168 169 fns = {'call_and_return_conditional_losses': call_fn_with_losses, 170 '__call__': call_fn} 171 172 if layer._activity_regularizer is not None: # pylint: disable=protected-access 173 fns['activity_regularizer_fn'] = _wrap_activity_regularizer(layer) 174 fns['call_and_return_all_conditional_losses'] = ( 175 call_collection.add_function( 176 _append_activity_regularizer_loss(layer, 177 call_fn_with_losses, 178 fns['activity_regularizer_fn']), 179 '{}_layer_call_and_return_all_conditional_losses'.format(layer.name) 180 )) 181 else: 182 fns['activity_regularizer_fn'] = None 183 fns['call_and_return_all_conditional_losses'] = call_fn_with_losses 184 185 # Manually trigger traces before restoring the overwritten functions. The 186 # functions are traced within the layer call context to ensure that layer 187 # functions (e.g. add_loss) behave as though running in graph mode. 188 with base_layer_utils.call_context().enter( 189 layer, inputs=None, build_graph=True, training=None, saving=True): 190 for fn in fns.values(): 191 if fn is not None and fn.input_signature is not None: 192 fn.get_concrete_function() 193 194 # Restore overwritten functions and losses 195 _restore_child_layer_functions(original_fns) 196 _restore_layer_losses(original_losses) 197 198 return fns 199 200 201def default_save_signature(layer): 202 original_losses = _reset_layer_losses(layer) 203 fn = saving_utils.trace_model_call(layer) 204 fn.get_concrete_function() 205 _restore_layer_losses(original_losses) 206 return fn 207 208 209def _replace_child_layer_functions(layer, serialization_cache): 210 """Replaces functions in the children layers with wrapped tf.functions. 211 212 This step allows functions from parent layers to reference the wrapped 213 functions from their children layers instead of retracing the ops. 214 215 This function also resets all losses stored in the layer. These are stored in 216 the returned dictionary. Use `_restore_child_layer_functions` to restore 217 the original attributes. 218 219 Args: 220 layer: Keras Layer object. 221 serialization_cache: Dictionary shared between all objects during 222 serialization. 223 224 Returns: 225 Dictionary mapping layer objects -> original functions and losses: 226 { Child layer 1: { 227 'losses': Original losses, 228 'call': Original call function 229 '_activity_regularizer': Original activity regularizer}, 230 Child layer 2: ... 231 } 232 """ 233 # pylint: disable=protected-access 234 original_fns = {} 235 236 def replace_layer_functions(child_layer, serialized_fns): 237 """Replaces layer call and activity regularizer with wrapped functions.""" 238 original_fns[child_layer] = { 239 'call': child_layer.call, 240 '_activity_regularizer': child_layer._activity_regularizer 241 } 242 with trackable.no_automatic_dependency_tracking_scope(child_layer): 243 try: 244 child_layer._activity_regularizer = serialized_fns.get( 245 'activity_regularizer_fn') 246 except AttributeError: 247 # Some layers have an unsettable activity regularizer. 248 pass 249 child_layer.call = utils.use_wrapped_call( 250 child_layer, 251 serialized_fns['call_and_return_conditional_losses'], 252 default_training_value=False) 253 254 def replace_metric_functions(child_layer, serialized_fns): 255 """Replaces metric functions with wrapped functions.""" 256 original_fns[child_layer] = { 257 '__call__': child_layer.__call__, 258 'result': child_layer.result, 259 'update_state': child_layer.update_state 260 } 261 with trackable.no_automatic_dependency_tracking_scope(child_layer): 262 child_layer.__call__ = serialized_fns['__call__'] 263 child_layer.result = serialized_fns['result'] 264 child_layer.update_state = serialized_fns['update_state'] 265 266 for child_layer in utils.list_all_layers(layer): 267 if isinstance(child_layer, input_layer.InputLayer): 268 continue 269 270 if child_layer not in serialization_cache[constants.KERAS_CACHE_KEY]: 271 serialized_functions = ( 272 child_layer._trackable_saved_model_saver._get_serialized_attributes( 273 serialization_cache).functions) 274 else: 275 serialized_functions = ( 276 serialization_cache[constants.KERAS_CACHE_KEY][child_layer].functions) 277 if not serialized_functions: 278 # This indicates either: 279 # - circular dependency, which means the current layer's functions 280 # should be wrapped first. 281 # - Child layer's inputs are not defined, so its functions have not been 282 # wrapped. In this case, no replacement is necessary so move on to the 283 # next child. 284 continue 285 286 if isinstance(child_layer, metrics.Metric): 287 replace_metric_functions(child_layer, serialized_functions) 288 else: 289 replace_layer_functions(child_layer, serialized_functions) 290 291 return original_fns 292 # pylint: enable=protected-access 293 294 295def _restore_child_layer_functions(original_fns): 296 """Restores attributes replaced with `_replace_child_layer_functions`.""" 297 for child_layer, fns in original_fns.items(): 298 with trackable.no_automatic_dependency_tracking_scope(child_layer): 299 for fn_name, fn in fns.items(): 300 try: 301 setattr(child_layer, fn_name, fn) # pylint: disable=protected-access 302 except AttributeError: 303 pass # In the case of _activity_regularizer, setting the attribute 304 # may be disallowed. 305 306 307# pylint: disable=protected-access 308def _reset_layer_losses(parent_layer): 309 """Resets losses of layer and its sublayers, and returns original losses.""" 310 losses_dict = {} 311 for layer in utils.list_all_layers_and_sublayers(parent_layer): 312 losses_dict[layer] = {'losses': layer._losses[:], 313 'eager_losses': layer._eager_losses[:]} 314 with trackable.no_automatic_dependency_tracking_scope(layer): 315 layer._losses = [] 316 layer._eager_losses = [] 317 return losses_dict 318 319 320def _restore_layer_losses(losses_dict): 321 for layer in losses_dict: 322 with trackable.no_automatic_dependency_tracking_scope(layer): 323 layer._losses = losses_dict[layer]['losses'] 324 layer._eager_losses = losses_dict[layer]['eager_losses'] 325# pylint: enable=protected-access 326 327 328class LayerCallCollection(object): 329 """Groups wrapped layer call functions. 330 331 This is used to ensure that all layer call functions are traced with the same 332 inputs- 333 - call 334 - call_and_return_conditional_losses 335 - call_and_return_all_conditional_losses 336 """ 337 338 def __init__(self, layer): 339 self.layer = layer 340 341 self.layer_call_method = _get_layer_call_method(layer) 342 self._expects_training_arg = utils.layer_uses_training_bool(layer) 343 self._training_arg_index = utils.get_training_arg_index( 344 self.layer_call_method) 345 346 # If the layer call function has kwargs, then the traced function cannot 347 # have an input signature. 348 arg_spec = tf_inspect.getfullargspec(self.layer_call_method) 349 self._has_kwargs = bool(self._expects_training_arg or 350 arg_spec.defaults or 351 arg_spec.kwonlyargs or 352 arg_spec.varkw) 353 354 self._input_signature = self._generate_input_signature(layer) 355 self._functions = weakref.WeakValueDictionary() 356 # Bool indicating whether this object is currently tracing the layer call 357 # functions. 358 self.tracing = False 359 360 # Get the input argument name from the args. 361 args = arg_spec.args 362 if tf_inspect.ismethod(self.layer_call_method): 363 args = args[1:] 364 self._input_arg_name = args[0] if args else 'inputs' 365 366 def _generate_input_signature(self, layer): 367 """Inspects layer object and returns the inferred input signature. 368 369 Args: 370 layer: Layer object. 371 372 Returns: 373 List of possibly nested TensorSpecs of the layer call function inputs. 374 The list does not contain the `training` argument. 375 """ 376 if (isinstance(layer.call, def_function.Function) and 377 layer.call.input_signature is not None): 378 return layer.call.input_signature 379 elif isinstance(layer, training_lib.Model): 380 return saving_utils.model_input_signature(layer) 381 elif (layer.input_spec is not None and 382 layer._use_input_spec_as_call_signature): # pylint: disable=protected-access 383 384 def to_tensor_spec_or_none(x): 385 spec = input_spec.to_tensor_spec(x, layer._compute_dtype) # pylint: disable=protected-access 386 # If the shape is too general (e.g. multiple dimensions are allowed), 387 # return None so that separate functions can be generated for each 388 # inferred input signature. 389 # TODO(b/134962016): currently partial signatures are not supported. 390 if spec.shape == tensor_shape.TensorShape(None): 391 return None 392 return spec 393 input_signature = [nest.map_structure( 394 to_tensor_spec_or_none, layer.input_spec)] 395 396 return input_signature 397 else: 398 return None 399 400 def add_trace(self, *args, **kwargs): 401 """Traces all functions with the same args and kwargs. 402 403 Args: 404 *args: Positional args passed to the original function. 405 **kwargs: Keyword args passed to the original function. 406 """ 407 args = list(args) 408 kwargs = kwargs.copy() 409 self.tracing = True 410 for fn in self._functions.values(): 411 # TODO(kathywu): Replace arguments with broader shapes defined in the 412 # input signature. 413 if self._expects_training_arg: 414 def trace_with_training(value, fn=fn): 415 utils.set_training_arg(value, self._training_arg_index, args, kwargs) 416 with K.deprecated_internal_learning_phase_scope(value): 417 fn.get_concrete_function(*args, **kwargs) 418 419 trace_with_training(True) 420 trace_with_training(False) 421 else: 422 fn.get_concrete_function(*args, **kwargs) 423 self.tracing = False 424 425 @property 426 def fn_input_signature(self): 427 """Returns input signature for the wrapped layer call function.""" 428 if self._has_kwargs: 429 # Input signatures may only describe tensor arguments and kwargs are not 430 # supported. 431 return None 432 if None in nest.flatten(self._input_signature): 433 # TODO(b/134962016): If input signature cannot be partially defined. 434 return None 435 return self._input_signature 436 437 def training_arg_was_passed(self, args, kwargs): 438 if not self.layer._expects_training_arg and self._expects_training_arg: # pylint: disable=protected-access 439 return (utils.get_training_arg(self._training_arg_index, args, kwargs) 440 is not None) 441 else: 442 return self.layer._call_arg_was_passed( # pylint: disable=protected-access 443 'training', args, kwargs, inputs_in_args=True) 444 445 def get_training_arg_value(self, args, kwargs): 446 if not self.layer._expects_training_arg and self._expects_training_arg: # pylint: disable=protected-access 447 return utils.get_training_arg(self._training_arg_index, args, kwargs) 448 else: 449 return self.layer._get_call_arg_value( # pylint: disable=protected-access 450 'training', args, kwargs, inputs_in_args=True) 451 452 def get_input_arg_value(self, args, kwargs): 453 return self.layer._get_call_arg_value( # pylint: disable=protected-access 454 self._input_arg_name, args, kwargs, inputs_in_args=True) 455 456 def _maybe_wrap_with_training_arg(self, call_fn): 457 """Wraps call function with added training argument if necessary.""" 458 if not self.layer._expects_training_arg and self._expects_training_arg: # pylint: disable=protected-access 459 # Add training arg to wrapper function. 460 arg_spec = tf_inspect.getfullargspec(call_fn) 461 args = arg_spec.args + ['training'] 462 defaults = list(arg_spec.defaults or []) 463 defaults.append(False) 464 new_arg_spec = tf_inspect.FullArgSpec( 465 args=args, 466 varargs=arg_spec.varargs, 467 varkw=arg_spec.varkw, 468 defaults=defaults, 469 kwonlyargs=arg_spec.kwonlyargs, 470 kwonlydefaults=arg_spec.kwonlydefaults, 471 annotations=arg_spec.annotations) 472 473 # Set new training arg index 474 self._training_arg_index = len(args) - 1 475 if tf_inspect.ismethod(call_fn): 476 self._training_arg_index -= 1 477 478 def wrap_with_training_arg(*args, **kwargs): 479 # Remove the training value, since the original call_fn does not expect 480 # a training arg. Instead, the training value will be propagated using 481 # the call context created in LayerCall. 482 args = list(args) 483 kwargs = kwargs.copy() 484 utils.remove_training_arg(self._training_arg_index, args, kwargs) 485 return call_fn(*args, **kwargs) 486 487 return tf_decorator.make_decorator( 488 target=call_fn, 489 decorator_func=wrap_with_training_arg, 490 decorator_argspec=new_arg_spec) 491 492 return call_fn 493 494 def add_function(self, call_fn, name): 495 """Adds a layer call function to the collection.""" 496 self._functions[name] = fn = LayerCall( 497 self, self._maybe_wrap_with_training_arg(call_fn), name, 498 input_signature=self.fn_input_signature) 499 500 if (None not in nest.flatten(self._input_signature) and 501 self._has_kwargs): 502 # Manually add traces for layers that have keyword arguments and have 503 # a fully defined input signature. 504 self.add_trace(*self._input_signature) 505 return fn 506 507 508def layer_call_wrapper(call_collection, method): 509 """Ensures layer losses are kept the same, and runs method in call context.""" 510 def wrapper(*args, **kwargs): 511 """Calls method within call context.""" 512 layer = call_collection.layer 513 training = None 514 inputs = call_collection.get_input_arg_value(args, kwargs) 515 # pylint: disable=protected-access 516 if (args or kwargs) and call_collection.training_arg_was_passed( 517 args, kwargs): 518 training = call_collection.get_training_arg_value(args, kwargs) 519 # pylint: enable=protected-access 520 original_losses = _reset_layer_losses(layer) 521 with base_layer_utils.call_context().enter( 522 layer, inputs=inputs, build_graph=False, training=training, 523 saving=True): 524 with autocast_variable.enable_auto_cast_variables( 525 layer._compute_dtype_object): # pylint: disable=protected-access 526 ret = method(*args, **kwargs) 527 _restore_layer_losses(original_losses) 528 return ret 529 return tf_decorator.make_decorator(target=method, decorator_func=wrapper) 530 531 532class LayerCall(def_function.Function): 533 """Function that triggers traces of other functions in the same collection.""" 534 535 def __init__(self, call_collection, python_function, *args, **kwargs): 536 self.call_collection = call_collection 537 self.original_call = call_collection.layer_call_method 538 python_function = layer_call_wrapper(call_collection, python_function) 539 super(LayerCall, self).__init__(python_function, *args, **kwargs) 540 541 def __call__(self, *args, **kwargs): 542 if not self.call_collection.tracing: 543 self.call_collection.add_trace(*args, **kwargs) 544 return super(LayerCall, self).__call__(*args, **kwargs) 545 546 def get_concrete_function(self, *args, **kwargs): 547 if not self.call_collection.tracing: 548 self.call_collection.add_trace(*args, **kwargs) 549 return super(LayerCall, self).get_concrete_function(*args, **kwargs) 550 551 552def _wrap_call_and_conditional_losses(layer): 553 """Wraps call function that returns a tuple of (outputs, losses). 554 555 The losses returned are conditional on the inputs passed to the call function. 556 Unconditional losses (e.g. weight regularizeration) are wrapped separately. 557 558 Args: 559 layer: a Keras layer object 560 561 Returns: 562 python call function that returns outputs and conditional losses -- excludes 563 activity regularizer 564 """ 565 # Create function that generates both outputs and losses 566 layer_call = _get_layer_call_method(layer) 567 def call_and_return_conditional_losses(inputs, *args, **kwargs): 568 """Returns layer (call_output, conditional losses) tuple.""" 569 call_output = layer_call(inputs, *args, **kwargs) 570 if version_utils.is_v1_layer_or_model(layer): 571 conditional_losses = layer.get_losses_for(inputs) 572 else: 573 conditional_losses = [ 574 l for l in layer.losses if not hasattr(l, '_unconditional_loss') 575 ] 576 return call_output, conditional_losses 577 578 return _create_call_fn_decorator(layer, call_and_return_conditional_losses) 579 580 581def _extract_outputs_from_fn(layer, call_and_return_conditional_losses): 582 """Returns a function that returns only call function outputs.""" 583 if isinstance(layer, keras_load.RevivedLayer): 584 return layer.keras_api.__call__ # pylint: disable=protected-access 585 def call(inputs, *args, **kwargs): 586 return call_and_return_conditional_losses(inputs, *args, **kwargs)[0] 587 return _create_call_fn_decorator(layer, call) 588 589 590def _append_activity_regularizer_loss( 591 layer, call_fn_with_losses, activity_regularizer_fn): 592 """Appends activity regularizer loss to losses returned by the wrapped fn.""" 593 def fn(inputs, *args, **kwargs): 594 outputs, losses = call_fn_with_losses(inputs, *args, **kwargs) 595 losses.append(activity_regularizer_fn(outputs)) 596 return outputs, losses 597 return _create_call_fn_decorator(layer, fn) 598 599 600def _create_call_fn_decorator(layer, wrapped_call): 601 call_fn = _get_layer_call_method(layer) 602 fn, arg_spec = utils.maybe_add_training_arg( 603 call_fn, wrapped_call, layer._expects_training_arg, # pylint: disable=protected-access 604 default_training_value=False) 605 return tf_decorator.make_decorator( 606 target=call_fn, 607 decorator_func=fn, 608 decorator_argspec=arg_spec) 609 610 611def _wrap_unconditional_loss(loss_fn, index): 612 """Wraps callable/unconditional loss, returning a serializable function.""" 613 # Extract original loss function from partial function 614 fn = loss_fn.args[0] if isinstance(loss_fn, functools.partial) else loss_fn 615 if isinstance(fn, def_function.Function): 616 return fn 617 else: 618 return def_function.Function( 619 fn, 'loss_fn_{}'.format(index), input_signature=[]) 620 621 622def _wrap_activity_regularizer(layer): 623 """Wraps the activity regularizer.""" 624 # pylint: disable=protected-access 625 if isinstance(layer._activity_regularizer, def_function.Function): 626 return layer._activity_regularizer 627 return def_function.Function( 628 layer._activity_regularizer, 629 '{}_activity_regularizer'.format(layer.name), 630 input_signature=[ 631 tensor_spec.TensorSpec(None, layer._compute_dtype or K.floatx()) 632 ]) 633 # pylint: enable=protected-access 634 635 636def _get_layer_call_method(layer): 637 if isinstance(layer.call, (def_function.Function)): 638 return layer.call.python_function 639 return layer.call 640