• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras SavedModel serialization.
16
17TODO (kathywu): Move to layer_serialization.py. Some model-specific logic should
18go to model_serialization.py.
19"""
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import functools
25import weakref
26
27from tensorflow.python.eager import def_function
28from tensorflow.python.framework import tensor_shape
29from tensorflow.python.framework import tensor_spec
30from tensorflow.python.keras import backend as K
31from tensorflow.python.keras.engine import base_layer_utils
32from tensorflow.python.keras.engine import input_spec
33from tensorflow.python.keras.mixed_precision import autocast_variable
34from tensorflow.python.keras.saving import saving_utils
35from tensorflow.python.keras.saving.saved_model import constants
36from tensorflow.python.keras.saving.saved_model import load as keras_load
37from tensorflow.python.keras.saving.saved_model import serialized_attributes
38from tensorflow.python.keras.saving.saved_model import utils
39from tensorflow.python.keras.utils import tf_inspect
40from tensorflow.python.keras.utils import version_utils
41from tensorflow.python.keras.utils.generic_utils import LazyLoader
42from tensorflow.python.platform import tf_logging as logging
43from tensorflow.python.training.tracking import base as trackable
44from tensorflow.python.training.tracking import data_structures
45from tensorflow.python.util import nest
46from tensorflow.python.util import tf_decorator
47
# To avoid circular dependencies between keras/engine and keras/saving,
# code in keras/saving must delay imports.

# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.
# pylint:disable=g-inconsistent-quotes
# NOTE: LazyLoader defers the actual module import until first attribute
# access, which is what breaks the keras/engine <-> keras/saving import cycle.
base_layer = LazyLoader(
    "base_layer", globals(),
    "tensorflow.python.keras.engine.base_layer")
metrics = LazyLoader("metrics", globals(),
                     "tensorflow.python.keras.metrics")
input_layer = LazyLoader(
    "input_layer", globals(),
    "tensorflow.python.keras.engine.input_layer")
training_lib = LazyLoader(
    "training_lib", globals(),
    "tensorflow.python.keras.engine.training")
sequential_lib = LazyLoader(
    "sequential_lib", globals(),
    "tensorflow.python.keras.engine.sequential")
# pylint:enable=g-inconsistent-quotes
69
70
def should_skip_serialization(layer):
  """Skip serializing extra objects and functions if layer inputs aren't set."""
  saved_model_input_spec_set = (
      isinstance(layer, training_lib.Model) and
      layer._saved_model_inputs_spec is not None)  # pylint: disable=protected-access
  # A layer can be fully serialized once it is built, or (for models) once the
  # saved-model input spec has been recorded.
  if layer.built or saved_model_input_spec_set:
    return False
  logging.warning('Skipping full serialization of Keras layer {}, because '
                  'it is not built.'.format(layer))
  return True
80
81
def wrap_layer_objects(layer, serialization_cache):
  """Returns extra trackable objects to attach to the serialized layer.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization.

  Returns:
    A dictionary containing all checkpointable objects from a
    SerializedAttributes object. See LayerAttributes and ModelAttributes for
    entire list of objects
  """
  # pylint: disable=protected-access
  # Collect every regularization loss from this layer and all of its
  # sublayers, then wrap each as a tf.function. The shared serialization cache
  # guarantees each loss callable is wrapped exactly once per save.
  all_losses = list(layer._callable_losses)
  for child_layer in utils.list_all_layers(layer):
    all_losses.extend(child_layer._callable_losses)

  keras_loss_cache = serialization_cache.setdefault('keras_losses', {})
  wrapped_loss_functions = []
  for loss_fn in all_losses:
    wrapped = keras_loss_cache.get(loss_fn)
    if wrapped is None:
      wrapped = _wrap_unconditional_loss(loss_fn, len(keras_loss_cache))
      keras_loss_cache[loss_fn] = wrapped
    wrapped_loss_functions.append(wrapped)
  # Losses attached directly to this layer (as opposed to sublayers).
  wrapped_layer_losses = [
      keras_loss_cache[fn] for fn in layer._callable_losses]

  layer_metrics = data_structures._DictWrapper(
      {m.name: m for m in layer._metrics})
  return dict(
      variables=data_structures.ListWrapper(layer.variables),
      trainable_variables=data_structures.ListWrapper(
          layer.trainable_variables),
      non_trainable_variables=data_structures.ListWrapper(
          layer.non_trainable_variables),
      layers=data_structures.ListWrapper(utils.list_all_layers(layer)),
      metrics=data_structures.ListWrapper(layer.metrics),
      regularization_losses=data_structures.ListWrapper(
          wrapped_loss_functions),
      layer_regularization_losses=data_structures.ListWrapper(
          wrapped_layer_losses),
      layer_metrics=layer_metrics)
  # pylint: enable=protected-access
131
132
def wrap_layer_functions(layer, serialization_cache):
  """Returns dict of wrapped layer call function and losses in tf.functions.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization.

  Returns:
    A dictionary containing all keras tf.functions to serialize. See
    LayerAttributes and ModelAttributes for the list of all attributes.
  """
  # Since Sequential models may be modified in place using model.add() or
  # model.pop(), don't use saved functions.
  if (isinstance(layer, keras_load.RevivedLayer) and
      not isinstance(layer, sequential_lib.Sequential)):
    # Revived layers already carry their serialized functions on `keras_api`;
    # missing attributes fall back to None.
    return {fn_name: getattr(layer.keras_api, fn_name, None)
            for fn_name in serialized_attributes.LayerAttributes.all_functions}

  # Reset the losses of the layer and its children. The call function in each
  # child layer is replaced with tf.functions.
  original_fns = _replace_child_layer_functions(layer, serialization_cache)
  original_losses = _reset_layer_losses(layer)

  # Wrap all the layer call and activity regularizer functions.

  # Use LayerCallCollection to ensure that all layer call functions (__call__,
  # call with losses) are traced with the same inputs.
  call_collection = LayerCallCollection(layer)
  call_fn_with_losses = call_collection.add_function(
      _wrap_call_and_conditional_losses(layer),
      '{}_layer_call_and_return_conditional_losses'.format(layer.name))
  call_fn = call_collection.add_function(
      _extract_outputs_from_fn(layer, call_fn_with_losses),
      '{}_layer_call_fn'.format(layer.name))

  fns = {'call_and_return_conditional_losses': call_fn_with_losses,
         '__call__': call_fn}

  if layer._activity_regularizer is not None:  # pylint: disable=protected-access
    fns['activity_regularizer_fn'] = _wrap_activity_regularizer(layer)
    # "all conditional losses" = conditional losses plus the activity
    # regularizer loss computed from the call outputs.
    fns['call_and_return_all_conditional_losses'] = (
        call_collection.add_function(
            _append_activity_regularizer_loss(layer,
                                              call_fn_with_losses,
                                              fns['activity_regularizer_fn']),
            '{}_layer_call_and_return_all_conditional_losses'.format(layer.name)
            ))
  else:
    fns['activity_regularizer_fn'] = None
    fns['call_and_return_all_conditional_losses'] = call_fn_with_losses

  # Manually trigger traces before restoring the overwritten functions. The
  # functions are traced within the layer call context to ensure that layer
  # functions (e.g. add_loss) behave as though running in graph mode.
  with base_layer_utils.call_context().enter(
      layer, inputs=None, build_graph=True, training=None, saving=True):
    for fn in fns.values():
      # Only functions with a fully-defined input signature can be traced
      # here without concrete example inputs.
      if fn is not None and fn.input_signature is not None:
        fn.get_concrete_function()

  # Restore overwritten functions and losses
  _restore_child_layer_functions(original_fns)
  _restore_layer_losses(original_losses)

  return fns
199
200
def default_save_signature(layer):
  """Traces the model call function and returns it as the default signature."""
  # Clear losses before tracing and restore them afterwards so that tracing
  # does not leave stale loss tensors attached to the layer.
  original_losses = _reset_layer_losses(layer)
  traced_fn = saving_utils.trace_model_call(layer)
  traced_fn.get_concrete_function()
  _restore_layer_losses(original_losses)
  return traced_fn
207
208
def _replace_child_layer_functions(layer, serialization_cache):
  """Replaces functions in the children layers with wrapped tf.functions.

  This step allows functions from parent layers to reference the wrapped
  functions from their children layers instead of retracing the ops.

  This function also resets all losses stored in the layer. These are stored in
  the returned dictionary. Use `_restore_child_layer_functions` to restore
  the original attributes.

  Args:
    layer: Keras Layer object.
    serialization_cache: Dictionary shared between all objects during
      serialization.

  Returns:
    Dictionary mapping layer objects -> original functions and losses:
      { Child layer 1: {
          'losses': Original losses,
          'call': Original call function
          '_activity_regularizer': Original activity regularizer},
        Child layer 2: ...
      }
  """
  # pylint: disable=protected-access
  original_fns = {}

  def replace_layer_functions(child_layer, serialized_fns):
    """Replaces layer call and activity regularizer with wrapped functions."""
    original_fns[child_layer] = {
        'call': child_layer.call,
        '_activity_regularizer': child_layer._activity_regularizer
    }
    # The no-tracking scope keeps the replacement attributes from being picked
    # up as new checkpoint dependencies of the child layer.
    with trackable.no_automatic_dependency_tracking_scope(child_layer):
      try:
        child_layer._activity_regularizer = serialized_fns.get(
            'activity_regularizer_fn')
      except AttributeError:
        # Some layers have an unsettable activity regularizer.
        pass
      child_layer.call = utils.use_wrapped_call(
          child_layer,
          serialized_fns['call_and_return_conditional_losses'],
          default_training_value=False)

  def replace_metric_functions(child_layer, serialized_fns):
    """Replaces metric functions with wrapped functions."""
    original_fns[child_layer] = {
        '__call__': child_layer.__call__,
        'result': child_layer.result,
        'update_state': child_layer.update_state
    }
    with trackable.no_automatic_dependency_tracking_scope(child_layer):
      child_layer.__call__ = serialized_fns['__call__']
      child_layer.result = serialized_fns['result']
      child_layer.update_state = serialized_fns['update_state']

  for child_layer in utils.list_all_layers(layer):
    # Input layers have no call function to replace.
    if isinstance(child_layer, input_layer.InputLayer):
      continue

    if child_layer not in serialization_cache[constants.KERAS_CACHE_KEY]:
      # Not yet serialized: serialize the child now, which populates the cache
      # and yields its wrapped functions.
      serialized_functions = (
          child_layer._trackable_saved_model_saver._get_serialized_attributes(
              serialization_cache).functions)
    else:
      serialized_functions = (
          serialization_cache[constants.KERAS_CACHE_KEY][child_layer].functions)
    if not serialized_functions:
      # This indicates either:
      #   - circular dependency, which means the current layer's functions
      #     should be wrapped first.
      #   - Child layer's inputs are not defined, so its functions have not been
      #     wrapped. In this case, no replacement is necessary so move on to the
      #     next child.
      continue

    if isinstance(child_layer, metrics.Metric):
      replace_metric_functions(child_layer, serialized_functions)
    else:
      replace_layer_functions(child_layer, serialized_functions)

  return original_fns
  # pylint: enable=protected-access
293
294
def _restore_child_layer_functions(original_fns):
  """Restores attributes replaced with `_replace_child_layer_functions`."""
  for child_layer, attributes in original_fns.items():
    with trackable.no_automatic_dependency_tracking_scope(child_layer):
      for attr_name, original_value in attributes.items():
        try:
          setattr(child_layer, attr_name, original_value)  # pylint: disable=protected-access
        except AttributeError:
          # In the case of _activity_regularizer, setting the attribute may be
          # disallowed; the replacement step tolerated this too.
          pass
305
306
# pylint: disable=protected-access
def _reset_layer_losses(parent_layer):
  """Resets losses of layer and its sublayers, and returns original losses."""
  losses_dict = {}
  for layer in utils.list_all_layers_and_sublayers(parent_layer):
    # Snapshot copies of both loss lists before clearing them in place.
    losses_dict[layer] = {
        'losses': list(layer._losses),
        'eager_losses': list(layer._eager_losses),
    }
    with trackable.no_automatic_dependency_tracking_scope(layer):
      layer._losses = []
      layer._eager_losses = []
  return losses_dict
318
319
def _restore_layer_losses(losses_dict):
  """Restores losses previously captured by `_reset_layer_losses`."""
  for layer, saved in losses_dict.items():
    with trackable.no_automatic_dependency_tracking_scope(layer):
      layer._losses = saved['losses']
      layer._eager_losses = saved['eager_losses']
# pylint: enable=protected-access
326
327
class LayerCallCollection(object):
  """Groups wrapped layer call functions.

  This is used to ensure that all layer call functions are traced with the same
  inputs-
    - call
    - call_and_return_conditional_losses
    - call_and_return_all_conditional_losses
  """

  def __init__(self, layer):
    self.layer = layer

    self.layer_call_method = _get_layer_call_method(layer)
    self._expects_training_arg = utils.layer_uses_training_bool(layer)
    self._training_arg_index = utils.get_training_arg_index(
        self.layer_call_method)

    # If the layer call function has kwargs, then the traced function cannot
    # have an input signature.
    arg_spec = tf_inspect.getfullargspec(self.layer_call_method)
    self._has_kwargs = bool(self._expects_training_arg or
                            arg_spec.defaults or
                            arg_spec.kwonlyargs or
                            arg_spec.varkw)

    self._input_signature = self._generate_input_signature(layer)
    # Weak values let traced functions be garbage collected once no other
    # reference to them remains.
    self._functions = weakref.WeakValueDictionary()
    # Bool indicating whether this object is currently tracing the layer call
    # functions.
    self.tracing = False

    # Get the input argument name from the args.
    args = arg_spec.args
    if tf_inspect.ismethod(self.layer_call_method):
      args = args[1:]  # Drop `self`.
    self._input_arg_name = args[0] if args else 'inputs'

  def _generate_input_signature(self, layer):
    """Inspects layer object and returns the inferred input signature.

    Args:
      layer: Layer object.

    Returns:
      List of possibly nested TensorSpecs of the layer call function inputs.
      The list does not contain the `training` argument.
    """
    if (isinstance(layer.call, def_function.Function) and
        layer.call.input_signature is not None):
      # The call is already a tf.function with an explicit signature.
      return layer.call.input_signature
    elif isinstance(layer, training_lib.Model):
      return saving_utils.model_input_signature(layer)
    elif (layer.input_spec is not None and
          layer._use_input_spec_as_call_signature):  # pylint: disable=protected-access

      def to_tensor_spec_or_none(x):
        spec = input_spec.to_tensor_spec(x, layer._compute_dtype)  # pylint: disable=protected-access
        # If the shape is too general (e.g. multiple dimensions are allowed),
        # return None so that separate functions can be generated for each
        # inferred input signature.
        # TODO(b/134962016): currently partial signatures are not supported.
        if spec.shape == tensor_shape.TensorShape(None):
          return None
        return spec
      input_signature = [nest.map_structure(
          to_tensor_spec_or_none, layer.input_spec)]

      return input_signature
    else:
      return None

  def add_trace(self, *args, **kwargs):
    """Traces all functions with the same args and kwargs.

    Args:
      *args: Positional args passed to the original function.
      **kwargs: Keyword args passed to the original function.
    """
    # Copy so the training argument can be set/removed without mutating the
    # caller's argument containers.
    args = list(args)
    kwargs = kwargs.copy()
    self.tracing = True
    for fn in self._functions.values():
      # TODO(kathywu): Replace arguments with broader shapes defined in the
      # input signature.
      if self._expects_training_arg:
        def trace_with_training(value, fn=fn):
          # `fn=fn` binds the current loop value (avoids late-binding bug).
          utils.set_training_arg(value, self._training_arg_index, args, kwargs)
          with K.deprecated_internal_learning_phase_scope(value):
            fn.get_concrete_function(*args, **kwargs)

        # Trace both training modes so each is available after loading.
        trace_with_training(True)
        trace_with_training(False)
      else:
        fn.get_concrete_function(*args, **kwargs)
    self.tracing = False

  @property
  def fn_input_signature(self):
    """Returns input signature for the wrapped layer call function."""
    if self._has_kwargs:
      # Input signatures may only describe tensor arguments and kwargs are not
      # supported.
      return None
    if None in nest.flatten(self._input_signature):
      # TODO(b/134962016): If input signature cannot be partially defined.
      return None
    return self._input_signature

  def training_arg_was_passed(self, args, kwargs):
    # True when the caller explicitly supplied a training value, whether the
    # training arg comes from the wrapper or from the layer's own signature.
    if not self.layer._expects_training_arg and self._expects_training_arg:  # pylint: disable=protected-access
      return (utils.get_training_arg(self._training_arg_index, args, kwargs)
              is not None)
    else:
      return self.layer._call_arg_was_passed(  # pylint: disable=protected-access
          'training', args, kwargs, inputs_in_args=True)

  def get_training_arg_value(self, args, kwargs):
    # Reads the training value from the wrapper-added argument, or from the
    # layer's call signature when the layer itself expects `training`.
    if not self.layer._expects_training_arg and self._expects_training_arg:  # pylint: disable=protected-access
      return utils.get_training_arg(self._training_arg_index, args, kwargs)
    else:
      return self.layer._get_call_arg_value(  # pylint: disable=protected-access
          'training', args, kwargs, inputs_in_args=True)

  def get_input_arg_value(self, args, kwargs):
    # Extracts the value passed for the call function's first (input) arg.
    return self.layer._get_call_arg_value(  # pylint: disable=protected-access
        self._input_arg_name, args, kwargs, inputs_in_args=True)

  def _maybe_wrap_with_training_arg(self, call_fn):
    """Wraps call function with added training argument if necessary."""
    if not self.layer._expects_training_arg and self._expects_training_arg:  # pylint: disable=protected-access
      # Add training arg to wrapper function.
      arg_spec = tf_inspect.getfullargspec(call_fn)
      args = arg_spec.args + ['training']
      defaults = list(arg_spec.defaults or [])
      defaults.append(False)
      new_arg_spec = tf_inspect.FullArgSpec(
          args=args,
          varargs=arg_spec.varargs,
          varkw=arg_spec.varkw,
          defaults=defaults,
          kwonlyargs=arg_spec.kwonlyargs,
          kwonlydefaults=arg_spec.kwonlydefaults,
          annotations=arg_spec.annotations)

      # Set new training arg index
      self._training_arg_index = len(args) - 1
      if tf_inspect.ismethod(call_fn):
        self._training_arg_index -= 1

      def wrap_with_training_arg(*args, **kwargs):
        # Remove the training value, since the original call_fn does not expect
        # a training arg. Instead, the training value will be propagated using
        # the call context created in LayerCall.
        args = list(args)
        kwargs = kwargs.copy()
        utils.remove_training_arg(self._training_arg_index, args, kwargs)
        return call_fn(*args, **kwargs)

      return tf_decorator.make_decorator(
          target=call_fn,
          decorator_func=wrap_with_training_arg,
          decorator_argspec=new_arg_spec)

    return call_fn

  def add_function(self, call_fn, name):
    """Adds a layer call function to the collection."""
    self._functions[name] = fn = LayerCall(
        self, self._maybe_wrap_with_training_arg(call_fn), name,
        input_signature=self.fn_input_signature)

    if (None not in nest.flatten(self._input_signature) and
        self._has_kwargs):
      # Manually add traces for layers that have keyword arguments and have
      # a fully defined input signature.
      self.add_trace(*self._input_signature)
    return fn
506
507
def layer_call_wrapper(call_collection, method):
  """Ensures layer losses are kept the same, and runs method in call context."""

  def wrapper(*args, **kwargs):
    """Calls method within call context."""
    layer = call_collection.layer
    inputs = call_collection.get_input_arg_value(args, kwargs)
    # pylint: disable=protected-access
    training = None
    if (args or kwargs) and call_collection.training_arg_was_passed(
        args, kwargs):
      training = call_collection.get_training_arg_value(args, kwargs)
    # pylint: enable=protected-access
    # Clear the layer's losses for the duration of the call so only the losses
    # created by this invocation are observed, then restore them.
    original_losses = _reset_layer_losses(layer)
    with base_layer_utils.call_context().enter(
        layer, inputs=inputs, build_graph=False, training=training,
        saving=True):
      with autocast_variable.enable_auto_cast_variables(
          layer._compute_dtype_object):  # pylint: disable=protected-access
        outputs = method(*args, **kwargs)
    _restore_layer_losses(original_losses)
    return outputs

  return tf_decorator.make_decorator(target=method, decorator_func=wrapper)
530
531
class LayerCall(def_function.Function):
  """Function that triggers traces of other functions in the same collection."""

  def __init__(self, call_collection, python_function, *args, **kwargs):
    self.call_collection = call_collection
    self.original_call = call_collection.layer_call_method
    wrapped = layer_call_wrapper(call_collection, python_function)
    super(LayerCall, self).__init__(wrapped, *args, **kwargs)

  def _trace_siblings(self, args, kwargs):
    # Skip when the collection is already tracing, otherwise add_trace would
    # recurse back into this function.
    if not self.call_collection.tracing:
      self.call_collection.add_trace(*args, **kwargs)

  def __call__(self, *args, **kwargs):
    self._trace_siblings(args, kwargs)
    return super(LayerCall, self).__call__(*args, **kwargs)

  def get_concrete_function(self, *args, **kwargs):
    self._trace_siblings(args, kwargs)
    return super(LayerCall, self).get_concrete_function(*args, **kwargs)
550
551
def _wrap_call_and_conditional_losses(layer):
  """Wraps call function that returns a tuple of (outputs, losses).

  The losses returned are conditional on the inputs passed to the call function.
  Unconditional losses (e.g. weight regularization) are wrapped separately.

  Args:
    layer: a Keras layer object

  Returns:
    python call function that returns outputs and conditional losses -- excludes
    activity regularizer
  """
  layer_call = _get_layer_call_method(layer)

  def call_and_return_conditional_losses(inputs, *args, **kwargs):
    """Returns layer (call_output, conditional losses) tuple."""
    call_output = layer_call(inputs, *args, **kwargs)
    if version_utils.is_v1_layer_or_model(layer):
      conditional_losses = layer.get_losses_for(inputs)
    else:
      # V2 unconditional losses carry an `_unconditional_loss` marker; every
      # other loss depends on the call inputs.
      conditional_losses = [
          loss for loss in layer.losses
          if not hasattr(loss, '_unconditional_loss')
      ]
    return call_output, conditional_losses

  return _create_call_fn_decorator(layer, call_and_return_conditional_losses)
579
580
def _extract_outputs_from_fn(layer, call_and_return_conditional_losses):
  """Returns a function that returns only call function outputs."""
  if isinstance(layer, keras_load.RevivedLayer):
    return layer.keras_api.__call__  # pylint: disable=protected-access

  def call(inputs, *args, **kwargs):
    # Discard the conditional losses; keep only the call outputs.
    outputs, _ = call_and_return_conditional_losses(inputs, *args, **kwargs)
    return outputs

  return _create_call_fn_decorator(layer, call)
588
589
def _append_activity_regularizer_loss(
    layer, call_fn_with_losses, activity_regularizer_fn):
  """Appends activity regularizer loss to losses returned by the wrapped fn."""

  def fn(inputs, *args, **kwargs):
    # Run the wrapped call, then tack the activity-regularizer loss (computed
    # from the outputs) onto the conditional losses list.
    outputs, losses = call_fn_with_losses(inputs, *args, **kwargs)
    losses.append(activity_regularizer_fn(outputs))
    return outputs, losses

  return _create_call_fn_decorator(layer, fn)
598
599
def _create_call_fn_decorator(layer, wrapped_call):
  """Gives `wrapped_call` the layer call signature (adding `training` if set)."""
  call_fn = _get_layer_call_method(layer)
  decorated, arg_spec = utils.maybe_add_training_arg(
      call_fn, wrapped_call, layer._expects_training_arg,  # pylint: disable=protected-access
      default_training_value=False)
  return tf_decorator.make_decorator(
      target=call_fn, decorator_func=decorated, decorator_argspec=arg_spec)
609
610
def _wrap_unconditional_loss(loss_fn, index):
  """Wraps callable/unconditional loss, returning a serializable function."""
  # Losses may be stored as a functools.partial; the original callable is its
  # first positional argument.
  if isinstance(loss_fn, functools.partial):
    fn = loss_fn.args[0]
  else:
    fn = loss_fn
  if isinstance(fn, def_function.Function):
    return fn
  return def_function.Function(
      fn, 'loss_fn_{}'.format(index), input_signature=[])
620
621
def _wrap_activity_regularizer(layer):
  """Wraps the activity regularizer."""
  # pylint: disable=protected-access
  regularizer = layer._activity_regularizer
  if isinstance(regularizer, def_function.Function):
    return regularizer
  return def_function.Function(
      regularizer,
      '{}_activity_regularizer'.format(layer.name),
      input_signature=[
          tensor_spec.TensorSpec(None, layer._compute_dtype or K.floatx())
      ])
  # pylint: enable=protected-access
634
635
def _get_layer_call_method(layer):
  """Returns the layer's call, unwrapping a tf.function to its Python function."""
  call = layer.call
  if isinstance(call, def_function.Function):
    return call.python_function
  return call
640