# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Code for model cloning, plus model-related API entries."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizer_v1
from tensorflow.python.keras.engine import functional
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine import training_v1
from tensorflow.python.keras.engine.base_layer import AddMetric
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_layer import Input
from tensorflow.python.keras.engine.input_layer import InputLayer
from tensorflow.python.keras.saving import model_config
from tensorflow.python.keras.saving import save
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export


# API entries importable from `keras.models`:
Model = training.Model  # pylint: disable=invalid-name
Sequential = sequential.Sequential  # pylint: disable=invalid-name
Functional = functional.Functional  # pylint: disable=invalid-name
save_model = save.save_model
load_model = save.load_model
model_from_config = model_config.model_from_config
model_from_yaml = model_config.model_from_yaml
model_from_json = model_config.model_from_json
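
# Example (illustrative): these aliases back the public `tf.keras.models`
# namespace, e.g.
#   from tensorflow.keras.models import load_model
#   model = load_model('my_model.h5')  # The path is hypothetical.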


# Callable used to clone a layer with weights preserved.
def share_weights(layer):
  return layer
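
# Example (illustrative): pass `share_weights` as the `clone_function` to
# `clone_model` to rebuild the model graph while reusing the original layers,
# so their variables are shared rather than re-created:
#   shared_clone = clone_model(model, clone_function=share_weights)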


def _clone_layer(layer):
  return layer.__class__.from_config(layer.get_config())


def _insert_ancillary_layers(model, ancillary_layers, metrics_names, new_nodes):
  """Inserts ancillary layers into the model with the proper order."""
  # Sort `AddMetric` layers so they agree with metrics_names.
  metric_layers = [
      layer for layer in ancillary_layers if isinstance(layer, AddMetric)
  ]
  metric_layers.sort(key=lambda layer: metrics_names.index(layer.metric_name))
  ancillary_layers = [
      layer for layer in ancillary_layers if not isinstance(layer, AddMetric)
  ] + metric_layers
  model._insert_layers(ancillary_layers, relevant_nodes=list(new_nodes))


def _make_new_nodes(nodes_by_depth, layer_fn, layer_map, tensor_map):
  """Uses the layers in `layer_map` to make new nodes based on `nodes_by_depth`.

  Args:
    nodes_by_depth: Provides structure information to create new nodes.
    layer_fn: Function to clone layers.
    layer_map: Map from layers in `model` to new layers.
    tensor_map: Map from tensors in `model` to newly computed tensors.

  Returns:
    A set of new nodes. `layer_map` and `tensor_map` are updated.
  """
  # Iterate over every node in the reference model, in depth order.
  new_nodes = set()
  depth_keys = list(nodes_by_depth.keys())
  depth_keys.sort(reverse=True)
  for depth in depth_keys:
    nodes = nodes_by_depth[depth]
    for node in nodes:
      # Recover the corresponding layer.
      layer = node.outbound_layer

      # Get or create layer.
      if layer not in layer_map:
        new_layer = layer_fn(layer)
        layer_map[layer] = new_layer
        layer = new_layer
      else:
        # Reuse previously cloned layer.
        layer = layer_map[layer]
        # Don't call InputLayer multiple times.
        if isinstance(layer, InputLayer):
          continue

      # If all of the node's input tensors are available in tensor_map,
      # then call the (cloned) layer on them.
      if all(
          tensor in tensor_map for tensor in nest.flatten(node.input_tensors)):
        # Call layer.
        args = nest.map_structure(lambda t: tensor_map.get(t, t),
                                  node.call_args)
        kwargs = nest.map_structure(lambda t: tensor_map.get(t, t),
                                    node.call_kwargs)
        output_tensors = layer(*args, **kwargs)

        # Thread-safe way to keep track of what node was created.
        first_output_tensor = nest.flatten(output_tensors)[0]
        new_nodes.add(
            layer._inbound_nodes[first_output_tensor._keras_history.node_index])

        for x, y in zip(
            nest.flatten(node.output_tensors), nest.flatten(output_tensors)):
          tensor_map[x] = y
  return new_nodes


def _clone_functional_model(model, input_tensors=None, layer_fn=_clone_layer):
  """Clone a functional `Model` instance.

  Model cloning is similar to calling a model on new inputs,
  except that it creates new layers (and thus new weights) instead
  of sharing the weights of the existing layers.

  Input layers are always cloned.

  Args:
      model: Instance of `Model`.
      input_tensors: optional list of input tensors
          to build the model upon. If not provided,
          placeholders will be created.
      layer_fn: callable to be applied on non-input layers in the model. By
          default it clones the layer. Another example is to preserve the layer
          to share the weights. This is required when we create a per-replica
          copy of the model with distribution strategy; we want the weights to
          be shared but still feed inputs separately so we create new input
          layers.

  Returns:
      An instance of `Model` reproducing the behavior
      of the original model, on top of new input tensors,
      using newly instantiated weights.

  Raises:
      ValueError: in case of invalid `model` argument value or `layer_fn`
      argument value.
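
  Example (a minimal sketch, assuming `import tensorflow as tf`; the shapes
  and layers are illustrative):

  ```python
  inputs = tf.keras.Input(shape=(4,))
  outputs = tf.keras.layers.Dense(2)(inputs)
  model = tf.keras.Model(inputs, outputs)
  # No `input_tensors` are passed, so new placeholders are created and the
  # clone gets freshly initialized weights.
  clone = _clone_functional_model(model)
  ```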
  """
  if not isinstance(model, Model):
    raise ValueError('Expected `model` argument '
                     'to be a `Model` instance, got ', model)
  if isinstance(model, Sequential):
    raise ValueError('Expected `model` argument '
                     'to be a functional `Model` instance, '
                     'got a `Sequential` instance instead:', model)
  if not model._is_graph_network:
    raise ValueError('Expected `model` argument '
                     'to be a functional `Model` instance, '
                     'but got a subclass model instead.')

  new_input_layers = {}  # Cache for created layers.
  if input_tensors is not None:
    # Make sure that all input tensors come from a Keras layer.
    input_tensors = nest.flatten(input_tensors)
    for i, input_tensor in enumerate(input_tensors):
      original_input_layer = model._input_layers[i]

      # Cache input layer. Create a new layer if the tensor is originally not
      # from a Keras layer.
      if not K.is_keras_tensor(input_tensor):
        name = original_input_layer.name
        input_tensor = Input(tensor=input_tensor,
                             name='input_wrapper_for_' + name)
        newly_created_input_layer = input_tensor._keras_history.layer
        new_input_layers[original_input_layer] = newly_created_input_layer
      else:
        new_input_layers[original_input_layer] = original_input_layer

  if not callable(layer_fn):
    raise ValueError('Expected `layer_fn` argument to be a callable.')

  model_configs, created_layers = _clone_layers_and_model_config(
      model, new_input_layers, layer_fn)
  # Reconstruct model from the config, using the cloned layers.
  input_tensors, output_tensors, created_layers = (
      functional.reconstruct_from_config(model_configs,
                                         created_layers=created_layers))
  metrics_names = model.metrics_names
  model = Model(input_tensors, output_tensors, name=model.name)
  # Layers not directly tied to outputs of the Model, such as loss layers
  # created in `add_loss` and `add_metric`.
  ancillary_layers = [
      layer for layer in created_layers.values() if layer not in model.layers
  ]
  # TODO(b/162887610): This may need to adjust the inbound node index if the
  # created layers had already been used to define other models.
  if ancillary_layers:
    new_nodes = nest.flatten([
        layer.inbound_nodes[1:]
        if functional._should_skip_first_node(layer)
        else layer.inbound_nodes for layer in created_layers.values()
    ])
    _insert_ancillary_layers(model, ancillary_layers, metrics_names, new_nodes)
  return model


def _clone_layers_and_model_config(model, input_layers, layer_fn):
  """Clones all layers, and returns the model config without serializing layers.

  This function ensures that only the node graph is retrieved when getting the
  model config. The `layer_fn` used to clone layers might not rely on
  `layer.get_config()`, and some custom layers do not define `get_config`;
  trying to retrieve their config would result in errors.

  Args:
    model: A Functional model.
    input_layers: Dictionary mapping input layers in `model` to new input
      layers.
    layer_fn: Function used to clone all non-input layers.

  Returns:
    Model config object, and a dictionary of newly created layers.
  """
  created_layers = {}
  def _copy_layer(layer):
    # Whenever the network config attempts to get the layer serialization,
    # return a dummy dictionary.
    if layer in input_layers:
      created_layers[layer.name] = input_layers[layer]
    elif layer in model._input_layers:
      created_layers[layer.name] = InputLayer(**layer.get_config())
    else:
      created_layers[layer.name] = layer_fn(layer)
    return {}

  config = functional.get_network_config(
      model, serialize_layer_fn=_copy_layer)
  return config, created_layers


def _remove_ancillary_layers(model, layer_map, layers):
  """Removes and returns any ancillary layers from `layers` based on `model`.

  Ancillary layers are part of the model topology but not used to compute the
  model outputs, e.g., layers from `add_loss` and `add_metric`.

  Args:
    model: A Keras Model.
    layer_map: A map from layers in `model` to those in `layers`.
    layers: A list of all layers.

  Returns:
    Two lists of layers: (1) `layers` with the ancillary layers removed, and (2)
    the ancillary layers.
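
  Example (illustrative, assuming `import tensorflow as tf`): a layer created
  by `add_loss` is ancillary; it sits in the topology at negative depth but
  does not feed the model outputs.

  ```python
  inputs = tf.keras.Input(shape=(1,))
  outputs = tf.keras.layers.Dense(1)(inputs)
  model = tf.keras.Model(inputs, outputs)
  # The layer wrapping this loss computation is part of the graph but does
  # not produce the model's outputs.
  model.add_loss(tf.reduce_mean(outputs))
  ```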
  """
  ancillary_layers = []  # Additional layers for computing losses and metrics.
  if not model._is_graph_network:
    return layers, ancillary_layers

  # Ancillary layers are those with depth < 0.
  depths = [depth for depth in model._nodes_by_depth.keys() if depth < 0]
  depths.sort(reverse=True)  # Order topologically from inputs to outputs.
  for depth in depths:
    for node in model._nodes_by_depth[depth]:
      ancillary_layers.append(layer_map[node.outbound_layer])

  return [l for l in layers if l not in ancillary_layers], ancillary_layers


def _clone_sequential_model(model, input_tensors=None, layer_fn=_clone_layer):
  """Clone a `Sequential` model instance.

  Model cloning is similar to calling a model on new inputs,
  except that it creates new layers (and thus new weights) instead
  of sharing the weights of the existing layers.

  Args:
      model: Instance of `Sequential`.
      input_tensors: optional list of input tensors
          to build the model upon. If not provided,
          placeholders will be created.
      layer_fn: callable to be applied on non-input layers in the model. By
          default it clones the layer. Another example is to preserve the layer
          to share the weights. This is required when we create a per-replica
          copy of the model with distribution strategy; we want the weights to
          be shared but still feed inputs separately so we create new input
          layers.

  Returns:
      An instance of `Sequential` reproducing the behavior
      of the original model, on top of new input tensors,
      using newly instantiated weights.

  Raises:
      ValueError: in case of invalid `model` argument value or `layer_fn`
      argument value.
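
  Example (a minimal sketch, assuming `import tensorflow as tf`):

  ```python
  model = tf.keras.Sequential(
      [tf.keras.layers.Dense(2, input_shape=(3,))])
  # The clone has the same architecture but freshly initialized weights.
  clone = _clone_sequential_model(model)
  ```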
  """
  if not isinstance(model, Sequential):
    raise ValueError('Expected `model` argument '
                     'to be a `Sequential` model instance, '
                     'but got:', model)

  if not callable(layer_fn):
    raise ValueError('Expected `layer_fn` argument to be a callable.')

  layers = []  # Layers needed to compute the model's outputs.
  layer_map = {}
  # Ensure that all layers are cloned. The model's `layers` property excludes
  # the initial InputLayer (if it exists), which would otherwise result in a
  # different Sequential model structure.
  for layer in model._flatten_layers(include_self=False, recursive=False):
    if isinstance(layer, InputLayer) and input_tensors is not None:
      # If input tensors are provided, the original model's InputLayer is
      # overwritten with a different InputLayer.
      continue
    cloned_layer = (
        _clone_layer(layer)
        if isinstance(layer, InputLayer) else layer_fn(layer))
    layers.append(cloned_layer)
    layer_map[layer] = cloned_layer
  layers, ancillary_layers = _remove_ancillary_layers(model, layer_map, layers)

  if input_tensors is None:
    cloned_model = Sequential(layers=layers, name=model.name)
  elif len(generic_utils.to_list(input_tensors)) != 1:
    raise ValueError('To clone a `Sequential` model, we expect '
                     'exactly one tensor as part of `input_tensors`.')
  else:
    # Overwrite the original model's input layer.
    if isinstance(input_tensors, tuple):
      input_tensors = list(input_tensors)
    x = generic_utils.to_list(input_tensors)[0]
    if K.is_keras_tensor(x):
      origin_layer = x._keras_history.layer
      if isinstance(origin_layer, InputLayer):
        cloned_model = Sequential(
            layers=[origin_layer] + layers, name=model.name)
      else:
        raise ValueError('Cannot clone a `Sequential` model on top '
                         'of a tensor that comes from a Keras layer '
                         'other than an `InputLayer`. '
                         'Use the functional API instead.')
    else:
      input_tensor = Input(tensor=x, name='input_wrapper_for_' + str(x.name))
      input_layer = input_tensor._keras_history.layer
      cloned_model = Sequential(layers=[input_layer] + layers, name=model.name)

  if not ancillary_layers:
    return cloned_model

  tensor_map = {}  # Maps tensors from `model` to those in `cloned_model`.
  for depth, cloned_nodes in cloned_model._nodes_by_depth.items():
    nodes = model._nodes_by_depth[depth]
    # This should be safe in a Sequential model. In an arbitrary network, you
    # need to sort using the outbound layer of the node as a key.
    for cloned_node, node in zip(cloned_nodes, nodes):
      if isinstance(cloned_node.output_tensors, list):
        for j, output_tensor in enumerate(cloned_node.output_tensors):
          tensor_map[node.output_tensors[j]] = output_tensor
      else:
        tensor_map[node.output_tensors] = cloned_node.output_tensors
  # Ancillary nodes have negative depth.
  new_nodes = _make_new_nodes(
      {
          depth: nodes
          for depth, nodes in model._nodes_by_depth.items()
          if depth < 0
      }, layer_fn, layer_map, tensor_map)
  _insert_ancillary_layers(cloned_model, ancillary_layers, model.metrics_names,
                           new_nodes)
  return cloned_model


@keras_export('keras.models.clone_model')
def clone_model(model, input_tensors=None, clone_function=None):
  """Clone any `Model` instance.

  Model cloning is similar to calling a model on new inputs,
  except that it creates new layers (and thus new weights) instead
  of sharing the weights of the existing layers.

  `clone_model` will not preserve the uniqueness of shared objects within the
  model (e.g. a single variable attached to two distinct layers will be
  restored as two separate variables).

  Args:
      model: Instance of `Model`
          (could be a functional model or a Sequential model).
      input_tensors: optional list of input tensors or InputLayer objects
          to build the model upon. If not provided,
          placeholders will be created.
      clone_function: Callable to be used to clone each layer in the target
          model (except `InputLayer` instances). It takes as argument the layer
          instance to be cloned, and returns the corresponding layer instance to
          be used in the model copy. If unspecified, this callable defaults to
          the following serialization/deserialization function:
          `lambda layer: layer.__class__.from_config(layer.get_config())`.
          By passing a custom callable, you can customize your copy of the
          model, e.g. by wrapping certain layers of interest (you might want to
          replace all `LSTM` instances with equivalent
          `Bidirectional(LSTM(...))` instances, for example).

  Returns:
      An instance of `Model` reproducing the behavior
      of the original model, on top of new input tensors,
      using newly instantiated weights. The cloned model might behave
      differently from the original model if a custom clone_function
      modifies the layer.

  Raises:
      ValueError: in case of invalid `model` argument value.
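
  Example (a minimal sketch, assuming `import tensorflow as tf`; shapes and
  layer choices are illustrative):

  ```python
  # Create a small test model.
  model = tf.keras.Sequential([
      tf.keras.Input(shape=(728,)),
      tf.keras.layers.Dense(32, activation='relu'),
      tf.keras.layers.Dense(1, activation='sigmoid'),
  ])
  # Create a copy of the test model, with freshly initialized weights.
  new_model = tf.keras.models.clone_model(model)
  ```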
  """
  with generic_utils.DisableSharedObjectScope():
    if clone_function is None:
      clone_function = _clone_layer

    if isinstance(model, Sequential):
      return _clone_sequential_model(
          model, input_tensors=input_tensors, layer_fn=clone_function)
    else:
      return _clone_functional_model(
          model, input_tensors=input_tensors, layer_fn=clone_function)


# "Clone" a subclassed model by resetting all of the attributes.
def _in_place_subclassed_model_reset(model):
  """Substitute for model cloning that works for subclassed models.

  Subclassed models cannot be cloned because their topology is not serializable.
  To "instantiate" an identical model in a new TF graph, we reuse the original
  model object, but we clear its state.

  After calling this function on a model instance, you can use the model
  instance as if it were a model clone (in particular you can use it in a new
  graph).

  This method clears the state of the input model. It is thus destructive.
  However, the original state can be restored fully by calling
  `in_place_subclassed_model_state_restoration`.

  Args:
    model: Instance of a Keras model created via subclassing.

  Raises:
    ValueError: In case the model uses a subclassed model as an inner layer.
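
  Example (illustrative sketch of the reset/restore workflow):

  ```python
  _in_place_subclassed_model_reset(model)  # Clears layers and build state.
  # ... use `model` as if it were a fresh clone, e.g. in a new graph ...
  in_place_subclassed_model_state_restoration(model)  # Undo the reset.
  ```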
  """
  assert not model._is_graph_network  # Only makes sense for subclassed networks
  # Select correct base class for new Model.
  version_utils.swap_class(model.__class__, training.Model, training_v1.Model,
                           ops.executing_eagerly_outside_functions())
  # Retrieve all layers tracked by the model as well as their attribute names
  attributes_cache = {}
  for name in dir(model):
    # Skip attrs that track other trackables.
    if name == 'submodules' or name == '_self_tracked_trackables':
      continue

    try:
      value = getattr(model, name)
    except (AttributeError, ValueError, TypeError):
      continue
    if isinstance(value, Layer):
      attributes_cache[name] = value
      assert value in model.layers
      if hasattr(value, 'layers') and value.layers:
        raise ValueError('We do not support the use of nested layers '
                         'in `model_to_estimator` at this time. Found nested '
                         'layer: %s' % value)
    elif isinstance(
        value, (list, tuple)) and name not in ('layers', '_layers', 'metrics',
                                               '_compile_metric_functions',
                                               '_output_loss_metrics'):
      # Handle case: list/tuple of layers (also tracked by the Network API).
      if value and all(isinstance(val, Layer) for val in value):
        raise ValueError('We do not support the use of list-of-layers '
                         'attributes in subclassed models used with '
                         '`model_to_estimator` at this time. Found list '
                         'model: %s' % name)

  # Replace layers on the model with fresh layers
  layers_to_names = {value: key for key, value in attributes_cache.items()}
  original_layers = list(
      model._flatten_layers(include_self=False, recursive=False))
  setattr_tracking = model._setattr_tracking
  model._setattr_tracking = False
  model._self_tracked_trackables = []
  for layer in original_layers:  # We preserve layer order.
    config = layer.get_config()
    # This will not work for nested subclassed models used as layers.
    # This would be theoretically possible to support, but would add complexity.
    # Only do it if users complain.
    if isinstance(layer, training.Model) and not layer._is_graph_network:
      raise ValueError('We do not support the use of nested subclassed models '
                       'in `model_to_estimator` at this time. Found nested '
                       'model: %s' % layer)
    fresh_layer = layer.__class__.from_config(config)
    name = layers_to_names[layer]
    setattr(model, name, fresh_layer)
    model._self_tracked_trackables.append(fresh_layer)

  # Cache original model build attributes (in addition to layers)
  if (not hasattr(model, '_original_attributes_cache') or
      model._original_attributes_cache is None):
    if model.built:
      attributes_to_cache = [
          'inputs',
          'outputs',
          'total_loss',
          'optimizer',
          'train_function',
          'test_function',
          'predict_function',
          '_training_endpoints',
          '_collected_trainable_weights',
          '_feed_inputs',
          '_feed_input_names',
          '_feed_input_shapes',
      ]
      for name in attributes_to_cache:
        attributes_cache[name] = getattr(model, name)
  model._original_attributes_cache = attributes_cache
  _reset_build_compile_trackers(model)
  model._setattr_tracking = setattr_tracking


def _reset_build_compile_trackers(model):
  """Reset state trackers for model.

  Note that we do not actually zero out attributes such as optimizer,
  but instead rely on the expectation that all of the attrs will be
  over-written on calling build/compile/etc. This is somewhat fragile,
  insofar as we check elsewhere for the presence of these attributes as
  evidence of having been built/compiled/etc. Pending a better way to do this,
  we reset key attributes here to allow building and compiling.

  Args:
    model: the model that is being reset
  """
  # Reset build state
  model.built = False
  model.inputs = None
  model.outputs = None
  # Reset compile state
  model._is_compiled = False  # pylint:disable=protected-access
  if not ops.executing_eagerly_outside_functions():
    model._v1_compile_was_called = False
  model.optimizer = None


def in_place_subclassed_model_state_restoration(model):
  """Restores the original state of a model after it was "reset".

  This undoes the action of `_in_place_subclassed_model_reset`, which is
  called in `clone_and_build_model` if `in_place_reset` is set to True.

  Args:
    model: Instance of a Keras model created via subclassing, on which
      `_in_place_subclassed_model_reset` was previously called.
  """
  assert not model._is_graph_network
  # Restore layers and build attributes
  if (hasattr(model, '_original_attributes_cache') and
      model._original_attributes_cache is not None):
    # Models have sticky attribute assignment, so we want to be careful to add
    # back the previous attributes and track Layers by their original names
    # without adding dependencies on "utility" attributes which Models exempt
    # when they're constructed.
    setattr_tracking = model._setattr_tracking
    model._setattr_tracking = False
    model._self_tracked_trackables = []
    for name, value in model._original_attributes_cache.items():
      setattr(model, name, value)
      if isinstance(value, Layer):
        model._self_tracked_trackables.append(value)
    model._original_attributes_cache = None
    model._setattr_tracking = setattr_tracking
  else:
    # Restore to the state of a never-called model.
    _reset_build_compile_trackers(model)


def clone_and_build_model(
    model, input_tensors=None, target_tensors=None, custom_objects=None,
    compile_clone=True, in_place_reset=False, optimizer_iterations=None,
    optimizer_config=None):
  """Clone a `Model` and build/compile it with the same settings used before.

  This function can be run in the same graph or in a separate graph from the
  model. When using a separate graph, `in_place_reset` must be `False`.

  Note that, currently, the clone produced from this function may not work with
  TPU DistributionStrategy. Try at your own risk.

  Args:
    model: `tf.keras.Model` object. Can be Functional, Sequential, or
      sub-classed.
    input_tensors: Optional list or dictionary of input tensors to build the
      model upon. If not provided, placeholders will be created.
    target_tensors: Optional list of target tensors for compiling the model. If
      not provided, placeholders will be created.
    custom_objects: Optional dictionary mapping string names to custom classes
      or functions.
    compile_clone: Boolean, whether to compile model clone (default `True`).
    in_place_reset: Boolean, whether to reset the model in place. Only used if
      the model is a subclassed model. In the case of a subclassed model,
      this argument must be set to `True` (default `False`). To restore the
      original model, use the function
      `in_place_subclassed_model_state_restoration(model)`.
    optimizer_iterations: An iterations variable that will be incremented by the
      optimizer if the clone is compiled. This argument is used when a Keras
      model is cloned into an Estimator model function, because Estimators
      create their own global step variable.
    optimizer_config: Optimizer config dictionary or list of config
      dictionaries returned from `get_config()`. This argument should be
      defined if `clone_and_build_model` is called in a different graph or
      session from the original model, and the optimizer is an instance of
      `OptimizerV2`.

  Returns:
    Clone of the model.

  Raises:
    ValueError: Cloning fails in the following cases
      - cloning a subclassed model with `in_place_reset` set to False.
      - compiling the clone when the original model has not been compiled.
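
  Example (a minimal sketch, assuming `import tensorflow as tf`; the model
  and compile settings are illustrative):

  ```python
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
  model.compile(optimizer='sgd', loss='mse')
  # The clone is built and compiled with the same settings as `model`.
  clone = clone_and_build_model(model, compile_clone=True)
  ```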
  """
  # Grab optimizer now, as we reset-in-place for subclassed models, but
  # want to maintain access to the original optimizer.
  orig_optimizer = model.optimizer
  if compile_clone and not orig_optimizer:
    raise ValueError(
        'Error when cloning model: compile_clone was set to True, but the '
        'original model has not been compiled.')

  if compile_clone:
    compile_args = model._get_compile_args()  # pylint: disable=protected-access
    # Allows this method to be robust to switching graph and eager classes.
    model._get_compile_args = lambda: compile_args

  with CustomObjectScope(custom_objects or {}):
    if model._is_graph_network:
      clone = clone_model(model, input_tensors=input_tensors)
    elif isinstance(model, Sequential):
      clone = clone_model(model, input_tensors=input_tensors)
      if (not clone._is_graph_network and model._build_input_shape is not None):
        if ops.executing_eagerly_outside_functions():
          clone.build(model._build_input_shape)
        else:
          clone._set_inputs(
              K.placeholder(
                  model._build_input_shape, dtype=model.inputs[0].dtype))
    else:
      try:
        # Prefer cloning the model if serialization/deserialization logic is
        # implemented for the subclassed model.
        clone = model.__class__.from_config(model.get_config())
      except NotImplementedError:
        logging.warning('This model is a subclassed model. Please implement '
                        '`get_config` and `from_config` to better support '
                        'cloning the model.')
        if not in_place_reset:
          raise ValueError(
              'This model is a subclassed model. '
              'Such a model cannot be cloned, but there is a workaround where '
              'the model is reset in-place. To use this, please set the '
              'argument `in_place_reset` to `True`. This will reset the '
              'attributes in the original model. To restore the attributes, '
              'call `in_place_subclassed_model_state_restoration(model)`.')
        clone = model
        _in_place_subclassed_model_reset(clone)
      if input_tensors is not None:
        if isinstance(input_tensors, (list, tuple)) and len(input_tensors) == 1:
          input_tensors = input_tensors[0]
        clone._set_inputs(input_tensors)

  if compile_clone:
    if isinstance(orig_optimizer, optimizer_v1.TFOptimizer):
      optimizer = optimizer_v1.TFOptimizer(
          orig_optimizer.optimizer, optimizer_iterations)
      K.track_tf_optimizer(optimizer)
    else:
      if not isinstance(orig_optimizer, (tuple, list)):
        orig_optimizer = [orig_optimizer]
      if optimizer_config is None:
        optimizer = [
            opt.__class__.from_config(opt.get_config())
            for opt in orig_optimizer
        ]
      elif isinstance(optimizer_config, dict):
        optimizer = [orig_optimizer[0].__class__.from_config(optimizer_config)]
      else:
        # `optimizer_config` is a list of dicts, in the same order as
        # `orig_optimizer`.
        optimizer = [
            opt.__class__.from_config(opt_config)
            for (opt, opt_config) in zip(orig_optimizer, optimizer_config)
        ]
      if optimizer_iterations is not None:
        for opt in optimizer:
          opt.iterations = optimizer_iterations

      if len(optimizer) == 1:
        optimizer = optimizer[0]

    compile_args['optimizer'] = optimizer
    if target_tensors is not None:
      compile_args['target_tensors'] = target_tensors
    # Ensure Metric objects in new model are separate from existing model.
    compile_args['metrics'] = metrics_module.clone_metrics(
        compile_args['metrics'])
    compile_args['weighted_metrics'] = metrics_module.clone_metrics(
        compile_args['weighted_metrics'])
    clone.compile(**compile_args)

  return clone