# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Code for model cloning, plus model-related API entries."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizer_v1
from tensorflow.python.keras.engine import functional
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine import training_v1
from tensorflow.python.keras.engine.base_layer import AddMetric
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_layer import Input
from tensorflow.python.keras.engine.input_layer import InputLayer
from tensorflow.python.keras.saving import model_config
from tensorflow.python.keras.saving import save
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export


# API entries importable from `keras.models`:
Model = training.Model  # pylint: disable=invalid-name
Sequential = sequential.Sequential  # pylint: disable=invalid-name
Functional = functional.Functional  # pylint: disable=invalid-name
save_model = save.save_model
load_model = save.load_model
model_from_config = model_config.model_from_config
model_from_yaml = model_config.model_from_yaml
model_from_json = model_config.model_from_json


# Callable used to clone a layer with weights preserved.
def share_weights(layer):
  return layer
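

# Example (a minimal sketch): passing `share_weights` as the `clone_function`
# of `clone_model` (defined below) yields a clone whose non-input layers are
# the original layer objects, so their weights are shared with the source
# model:
#
#   clone = clone_model(model, clone_function=share_weights)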


def _clone_layer(layer):
  return layer.__class__.from_config(layer.get_config())


def _insert_ancillary_layers(model, ancillary_layers, metrics_names,
                             new_nodes):
  """Inserts ancillary layers into the model in the proper order."""
  # Sort `AddMetric` layers so they agree with `metrics_names`.
  metric_layers = [
      layer for layer in ancillary_layers if isinstance(layer, AddMetric)
  ]
  metric_layers.sort(key=lambda layer: metrics_names.index(layer.metric_name))
  ancillary_layers = [
      layer for layer in ancillary_layers if not isinstance(layer, AddMetric)
  ] + metric_layers
  model._insert_layers(ancillary_layers, relevant_nodes=list(new_nodes))


def _make_new_nodes(nodes_by_depth, layer_fn, layer_map, tensor_map):
  """Uses the layers in `layer_map` to make new nodes based on `nodes_by_depth`.

  Args:
    nodes_by_depth: Provides structure information to create new nodes.
    layer_fn: Function to clone layers.
    layer_map: Map from layers in `model` to new layers.
    tensor_map: Map from tensors in `model` to newly computed tensors.

  Returns:
    A set of new nodes. `layer_map` and `tensor_map` are updated.
  """
  # Iterate over every node in the reference model, in depth order.
  new_nodes = set()
  depth_keys = list(nodes_by_depth.keys())
  depth_keys.sort(reverse=True)
  for depth in depth_keys:
    nodes = nodes_by_depth[depth]
    for node in nodes:
      # Recover the corresponding layer.
      layer = node.outbound_layer

      # Get or create layer.
      if layer not in layer_map:
        new_layer = layer_fn(layer)
        layer_map[layer] = new_layer
        layer = new_layer
      else:
        # Reuse previously cloned layer.
        layer = layer_map[layer]
        # Don't call InputLayer multiple times.
        if isinstance(layer, InputLayer):
          continue

      # If all previous input tensors are available in tensor_map,
      # then call the (possibly cloned) layer on them.
      if all(
          tensor in tensor_map for tensor in nest.flatten(node.input_tensors)):
        # Call layer.
        args = nest.map_structure(lambda t: tensor_map.get(t, t),
                                  node.call_args)
        kwargs = nest.map_structure(lambda t: tensor_map.get(t, t),
                                    node.call_kwargs)
        output_tensors = layer(*args, **kwargs)

        # Thread-safe way to keep track of what node was created.
        first_output_tensor = nest.flatten(output_tensors)[0]
        new_nodes.add(
            layer._inbound_nodes[first_output_tensor._keras_history.node_index])

        for x, y in zip(
            nest.flatten(node.output_tensors), nest.flatten(output_tensors)):
          tensor_map[x] = y
  return new_nodes


def _clone_functional_model(model, input_tensors=None, layer_fn=_clone_layer):
  """Clone a functional `Model` instance.

  Model cloning is similar to calling a model on new inputs,
  except that it creates new layers (and thus new weights) instead
  of sharing the weights of the existing layers.

  Input layers are always cloned.

  Args:
    model: Instance of `Model`.
    input_tensors: optional list of input tensors
        to build the model upon. If not provided,
        placeholders will be created.
    layer_fn: callable to be applied on non-input layers in the model. By
        default, it clones the layer. Another use case is to preserve the
        layer (so that its weights are shared) instead of cloning it. This is
        required when we create a per-replica copy of the model with
        distribution strategy; we want the weights to be shared but still
        feed inputs separately so we create new input layers.
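        For example, passing the `share_weights` helper defined above as
        `layer_fn` returns each layer unchanged, so the clone shares the
        original layers' weights while fresh `InputLayer`s are still created.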

  Returns:
    An instance of `Model` reproducing the behavior
    of the original model, on top of new input tensors,
    using newly instantiated weights.

  Raises:
    ValueError: in case of invalid `model` argument value or `layer_fn`
        argument value.
  """
  if not isinstance(model, Model):
    raise ValueError('Expected `model` argument '
                     'to be a `Model` instance, got ', model)
  if isinstance(model, Sequential):
    raise ValueError('Expected `model` argument '
                     'to be a functional `Model` instance, '
                     'got a `Sequential` instance instead:', model)
  if not model._is_graph_network:
    raise ValueError('Expected `model` argument '
                     'to be a functional `Model` instance, '
                     'but got a subclassed model instead.')

  new_input_layers = {}  # Cache for created layers.
  if input_tensors is not None:
    # Make sure that all input tensors come from a Keras layer.
    input_tensors = nest.flatten(input_tensors)
    for i, input_tensor in enumerate(input_tensors):
      original_input_layer = model._input_layers[i]

      # Cache input layer. Create a new layer if the tensor is originally not
      # from a Keras layer.
      if not K.is_keras_tensor(input_tensor):
        name = original_input_layer.name
        input_tensor = Input(tensor=input_tensor,
                             name='input_wrapper_for_' + name)
        newly_created_input_layer = input_tensor._keras_history.layer
        new_input_layers[original_input_layer] = newly_created_input_layer
      else:
        new_input_layers[original_input_layer] = original_input_layer

  if not callable(layer_fn):
    raise ValueError('Expected `layer_fn` argument to be a callable.')

  model_configs, created_layers = _clone_layers_and_model_config(
      model, new_input_layers, layer_fn)
  # Reconstruct model from the config, using the cloned layers.
  input_tensors, output_tensors, created_layers = (
      functional.reconstruct_from_config(model_configs,
                                         created_layers=created_layers))
  metrics_names = model.metrics_names
  model = Model(input_tensors, output_tensors, name=model.name)
  # Layers not directly tied to outputs of the Model, such as loss layers
  # created in `add_loss` and `add_metric`.
  ancillary_layers = [
      layer for layer in created_layers.values() if layer not in model.layers
  ]
  # TODO(b/162887610): This may need to adjust the inbound node index if the
  # created layers had already been used to define other models.
  if ancillary_layers:
    new_nodes = nest.flatten([
        layer.inbound_nodes[1:]
        if functional._should_skip_first_node(layer)
        else layer.inbound_nodes for layer in created_layers.values()
    ])
    _insert_ancillary_layers(model, ancillary_layers, metrics_names, new_nodes)
  return model


def _clone_layers_and_model_config(model, input_layers, layer_fn):
  """Clones all layers; returns the model config without serializing layers.

  This function ensures that only the node graph is retrieved when getting the
  model config. The `layer_fn` used to clone layers might not rely on
  `layer.get_config()`; some custom layers do not define `get_config`, and
  trying to retrieve their config would result in errors.

  Args:
    model: A Functional model.
    input_layers: Dictionary mapping input layers in `model` to new input
        layers.
    layer_fn: Function used to clone all non-input layers.

  Returns:
    Model config object, and a dictionary of newly created layers.
  """
  created_layers = {}

  def _copy_layer(layer):
    # Whenever the network config attempts to get the layer serialization,
    # return a dummy dictionary.
    if layer in input_layers:
      created_layers[layer.name] = input_layers[layer]
    elif layer in model._input_layers:
      created_layers[layer.name] = InputLayer(**layer.get_config())
    else:
      created_layers[layer.name] = layer_fn(layer)
    return {}

  config = functional.get_network_config(
      model, serialize_layer_fn=_copy_layer)
  return config, created_layers


def _remove_ancillary_layers(model, layer_map, layers):
  """Removes and returns any ancillary layers from `layers` based on `model`.

  Ancillary layers are part of the model topology but not used to compute the
  model outputs, e.g., layers from `add_loss` and `add_metric`.

  Args:
    model: A Keras Model.
    layer_map: A map from layers in `model` to those in `layers`.
    layers: A list of all layers.

  Returns:
    Two lists of layers: (1) `layers` with the ancillary layers removed, and
    (2) the ancillary layers.
  """
  ancillary_layers = []  # Additional layers for computing losses and metrics.
  if not model._is_graph_network:
    return layers, ancillary_layers

  # Ancillary layers are those with depth < 0.
  depths = [depth for depth in model._nodes_by_depth.keys() if depth < 0]
  depths.sort(reverse=True)  # Order topologically from inputs to outputs.
  for depth in depths:
    for node in model._nodes_by_depth[depth]:
      ancillary_layers.append(layer_map[node.outbound_layer])

  return [l for l in layers if l not in ancillary_layers], ancillary_layers
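

# Example (a minimal sketch, assuming `import tensorflow as tf`): in a
# functional model such as
#
#   inputs = tf.keras.Input(shape=(4,))
#   outputs = tf.keras.layers.Dense(1)(inputs)
#   model = tf.keras.Model(inputs, outputs)
#   model.add_metric(tf.reduce_mean(outputs), name='mean_output')
#
# the `AddMetric` layer created by `add_metric` is ancillary: it is part of
# the graph (at negative depth) but is not needed to compute `model.outputs`.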


def _clone_sequential_model(model, input_tensors=None, layer_fn=_clone_layer):
  """Clone a `Sequential` model instance.

  Model cloning is similar to calling a model on new inputs,
  except that it creates new layers (and thus new weights) instead
  of sharing the weights of the existing layers.

  Args:
    model: Instance of `Sequential`.
    input_tensors: optional list of input tensors
        to build the model upon. If not provided,
        placeholders will be created.
    layer_fn: callable to be applied on non-input layers in the model. By
        default, it clones the layer. Another use case is to preserve the
        layer (so that its weights are shared) instead of cloning it. This is
        required when we create a per-replica copy of the model with
        distribution strategy; we want the weights to be shared but still
        feed inputs separately so we create new input layers.

  Returns:
    An instance of `Sequential` reproducing the behavior
    of the original model, on top of new input tensors,
    using newly instantiated weights.

  Raises:
    ValueError: in case of invalid `model` argument value or `layer_fn`
        argument value.
  """
  if not isinstance(model, Sequential):
    raise ValueError('Expected `model` argument '
                     'to be a `Sequential` model instance, '
                     'but got:', model)

  if not callable(layer_fn):
    raise ValueError('Expected `layer_fn` argument to be a callable.')

  layers = []  # Layers needed to compute the model's outputs.
  layer_map = {}
  # Ensure that all layers are cloned. The model's layers
  # property will exclude the initial InputLayer (if it exists) in the model,
  # resulting in a different Sequential model structure.
  for layer in model._flatten_layers(include_self=False, recursive=False):
    if isinstance(layer, InputLayer) and input_tensors is not None:
      # If input tensors are provided, the original model's InputLayer is
      # overwritten with a different InputLayer.
      continue
    cloned_layer = (
        _clone_layer(layer)
        if isinstance(layer, InputLayer) else layer_fn(layer))
    layers.append(cloned_layer)
    layer_map[layer] = cloned_layer
  layers, ancillary_layers = _remove_ancillary_layers(model, layer_map, layers)

  if input_tensors is None:
    cloned_model = Sequential(layers=layers, name=model.name)
  elif len(generic_utils.to_list(input_tensors)) != 1:
    raise ValueError('To clone a `Sequential` model, we expect '
                     'exactly one tensor '
                     'as part of `input_tensors`.')
  else:
    # Overwrite the original model's input layer.
    if isinstance(input_tensors, tuple):
      input_tensors = list(input_tensors)
    x = generic_utils.to_list(input_tensors)[0]
    if K.is_keras_tensor(x):
      origin_layer = x._keras_history.layer
      if isinstance(origin_layer, InputLayer):
        cloned_model = Sequential(
            layers=[origin_layer] + layers, name=model.name)
      else:
        raise ValueError('Cannot clone a `Sequential` model on top '
                         'of a tensor that comes from a Keras layer '
                         'other than an `InputLayer`. '
                         'Use the functional API instead.')
    else:
      input_tensor = Input(tensor=x, name='input_wrapper_for_' + str(x.name))
      input_layer = input_tensor._keras_history.layer
      cloned_model = Sequential(layers=[input_layer] + layers, name=model.name)

  if not ancillary_layers:
    return cloned_model

  tensor_map = {}  # Maps tensors from `model` to those in `cloned_model`.
  for depth, cloned_nodes in cloned_model._nodes_by_depth.items():
    nodes = model._nodes_by_depth[depth]
    # This should be safe in a Sequential model. In an arbitrary network, you
    # need to sort using the outbound layer of the node as a key.
    for cloned_node, node in zip(cloned_nodes, nodes):
      if isinstance(cloned_node.output_tensors, list):
        for j, output_tensor in enumerate(cloned_node.output_tensors):
          tensor_map[node.output_tensors[j]] = output_tensor
      else:
        tensor_map[node.output_tensors] = cloned_node.output_tensors
  # Ancillary nodes have negative depth.
  new_nodes = _make_new_nodes(
      {
          depth: nodes
          for depth, nodes in model._nodes_by_depth.items()
          if depth < 0
      }, layer_fn, layer_map, tensor_map)
  _insert_ancillary_layers(cloned_model, ancillary_layers, model.metrics_names,
                           new_nodes)
  return cloned_model


@keras_export('keras.models.clone_model')
def clone_model(model, input_tensors=None, clone_function=None):
  """Clone any `Model` instance.

  Model cloning is similar to calling a model on new inputs,
  except that it creates new layers (and thus new weights) instead
  of sharing the weights of the existing layers.

  `clone_model` will not preserve the uniqueness of shared objects within the
  model (e.g. a single variable attached to two distinct layers will be
  restored as two separate variables).

  Args:
    model: Instance of `Model`
        (could be a functional model or a Sequential model).
    input_tensors: optional list of input tensors or InputLayer objects
        to build the model upon. If not provided,
        placeholders will be created.
    clone_function: Callable to be used to clone each layer in the target
        model (except `InputLayer` instances). It takes as argument the layer
        instance to be cloned, and returns the corresponding layer instance to
        be used in the model copy. If unspecified, this callable defaults to
        the following serialization/deserialization function:
        `lambda layer: layer.__class__.from_config(layer.get_config())`.
        By passing a custom callable, you can customize your copy of the
        model, e.g. by wrapping certain layers of interest (you might want to
        replace all `LSTM` instances with equivalent
        `Bidirectional(LSTM(...))` instances, for example).

  Returns:
    An instance of `Model` reproducing the behavior
    of the original model, on top of new input tensors,
    using newly instantiated weights. The cloned model might behave
    differently from the original model if a custom `clone_function`
    modifies the layer.
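
  Example (a minimal sketch; the layer sizes are illustrative, and `tf` is
  assumed to be `import tensorflow as tf`):

  ```python
  model = tf.keras.Sequential([
      tf.keras.Input(shape=(728,)),
      tf.keras.layers.Dense(32, activation='relu'),
      tf.keras.layers.Dense(1, activation='sigmoid'),
  ])
  # Create a copy of the model with freshly initialized weights.
  new_model = tf.keras.models.clone_model(model)
  ```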

  Raises:
    ValueError: in case of invalid `model` argument value.
  """
  with generic_utils.DisableSharedObjectScope():
    if clone_function is None:
      clone_function = _clone_layer

    if isinstance(model, Sequential):
      return _clone_sequential_model(
          model, input_tensors=input_tensors, layer_fn=clone_function)
    else:
      return _clone_functional_model(
          model, input_tensors=input_tensors, layer_fn=clone_function)


# "Clone" a subclassed model by resetting all of its attributes.
def _in_place_subclassed_model_reset(model):
  """Substitute for model cloning that works for subclassed models.

  Subclassed models cannot be cloned because their topology is not
  serializable. To "instantiate" an identical model in a new TF graph, we
  reuse the original model object, but we clear its state.

  After calling this function on a model instance, you can use the model
  instance as if it were a model clone (in particular you can use it in a new
  graph).

  This method clears the state of the input model. It is thus destructive.
  However, the original state can be restored fully by calling
  `_in_place_subclassed_model_state_restoration`.

  Args:
    model: Instance of a Keras model created via subclassing.

  Raises:
    ValueError: In case the model uses a subclassed model as an inner layer.
  """
  assert not model._is_graph_network  # Only makes sense for subclassed models.
  # Select correct base class for new Model.
  version_utils.swap_class(model.__class__, training.Model, training_v1.Model,
                           ops.executing_eagerly_outside_functions())
  # Retrieve all layers tracked by the model as well as their attribute names.
  attributes_cache = {}
  for name in dir(model):
    # Skip attrs that track other trackables.
    if name == 'submodules' or name == '_self_tracked_trackables':
      continue

    try:
      value = getattr(model, name)
    except (AttributeError, ValueError, TypeError):
      continue
    if isinstance(value, Layer):
      attributes_cache[name] = value
      assert value in model.layers
      if hasattr(value, 'layers') and value.layers:
        raise ValueError('We do not support the use of nested layers '
                         'in `model_to_estimator` at this time. Found nested '
                         'layer: %s' % value)
    elif isinstance(
        value, (list, tuple)) and name not in ('layers', '_layers', 'metrics',
                                               '_compile_metric_functions',
                                               '_output_loss_metrics'):
      # Handle case: list/tuple of layers (also tracked by the Network API).
      if value and all(isinstance(val, Layer) for val in value):
        raise ValueError('We do not support the use of list-of-layers '
                         'attributes in subclassed models used with '
                         '`model_to_estimator` at this time. Found list '
                         'of layers in attribute: %s' % name)

  # Replace layers on the model with fresh layers.
  layers_to_names = {value: key for key, value in attributes_cache.items()}
  original_layers = list(
      model._flatten_layers(include_self=False, recursive=False))
  setattr_tracking = model._setattr_tracking
  model._setattr_tracking = False
  model._self_tracked_trackables = []
  for layer in original_layers:  # We preserve layer order.
    config = layer.get_config()
    # This will not work for nested subclassed models used as layers.
    # This would be theoretically possible to support, but would add
    # complexity. Only do it if users complain.
    if isinstance(layer, training.Model) and not layer._is_graph_network:
      raise ValueError('We do not support the use of nested subclassed models '
                       'in `model_to_estimator` at this time. Found nested '
                       'model: %s' % layer)
    fresh_layer = layer.__class__.from_config(config)
    name = layers_to_names[layer]
    setattr(model, name, fresh_layer)
    model._self_tracked_trackables.append(fresh_layer)

  # Cache original model build attributes (in addition to layers).
  if (not hasattr(model, '_original_attributes_cache') or
      model._original_attributes_cache is None):
    if model.built:
      attributes_to_cache = [
          'inputs',
          'outputs',
          'total_loss',
          'optimizer',
          'train_function',
          'test_function',
          'predict_function',
          '_training_endpoints',
          '_collected_trainable_weights',
          '_feed_inputs',
          '_feed_input_names',
          '_feed_input_shapes',
      ]
      for name in attributes_to_cache:
        attributes_cache[name] = getattr(model, name)
  model._original_attributes_cache = attributes_cache
  _reset_build_compile_trackers(model)
  model._setattr_tracking = setattr_tracking
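

# Example (a minimal sketch of the reset/restore round trip;
# `subclassed_model` is a hypothetical subclassed model instance):
#
#   _in_place_subclassed_model_reset(subclassed_model)
#   ...  # Use `subclassed_model` as a "clone", e.g. in a new graph.
#   in_place_subclassed_model_state_restoration(subclassed_model)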


def _reset_build_compile_trackers(model):
  """Reset state trackers for model.

  Note that we do not actually zero out attributes such as optimizer,
  but instead rely on the expectation that all of the attrs will be
  over-written on calling build/compile/etc. This is somewhat fragile,
  insofar as we check elsewhere for the presence of these attributes as
  evidence of having been built/compiled/etc. Pending a better way to do this,
  we reset key attributes here to allow building and compiling.

  Args:
    model: the model that is being reset.
  """
  # Reset build state.
  model.built = False
  model.inputs = None
  model.outputs = None
  # Reset compile state.
  model._is_compiled = False  # pylint:disable=protected-access
  if not ops.executing_eagerly_outside_functions():
    model._v1_compile_was_called = False
  model.optimizer = None


def in_place_subclassed_model_state_restoration(model):
  """Restores the original state of a model after it was "reset".

  This undoes the action of `_in_place_subclassed_model_reset`, which is
  called in `clone_and_build_model` if `in_place_reset` is set to True.

  Args:
    model: Instance of a Keras model created via subclassing, on which
        `_in_place_subclassed_model_reset` was previously called.
  """
  assert not model._is_graph_network
  # Restore layers and build attributes.
  if (hasattr(model, '_original_attributes_cache') and
      model._original_attributes_cache is not None):
    # Models have sticky attribute assignment, so we want to be careful to add
    # back the previous attributes and track Layers by their original names
    # without adding dependencies on "utility" attributes which Models exempt
    # when they're constructed.
    setattr_tracking = model._setattr_tracking
    model._setattr_tracking = False
    model._self_tracked_trackables = []
    for name, value in model._original_attributes_cache.items():
      setattr(model, name, value)
      if isinstance(value, Layer):
        model._self_tracked_trackables.append(value)
    model._original_attributes_cache = None
    model._setattr_tracking = setattr_tracking
  else:
    # Restore to the state of a never-called model.
    _reset_build_compile_trackers(model)


def clone_and_build_model(
    model, input_tensors=None, target_tensors=None, custom_objects=None,
    compile_clone=True, in_place_reset=False, optimizer_iterations=None,
    optimizer_config=None):
  """Clone a `Model` and build/compile it with the same settings used before.

  This function can be run in the same graph or in a separate graph from the
  model. When using a separate graph, `in_place_reset` must be `False`.

  Note that, currently, the clone produced from this function may not work
  with TPU DistributionStrategy. Try at your own risk.

  Args:
    model: `tf.keras.Model` object. Can be Functional, Sequential, or
        sub-classed.
    input_tensors: Optional list or dictionary of input tensors to build the
        model upon. If not provided, placeholders will be created.
    target_tensors: Optional list of target tensors for compiling the model.
        If not provided, placeholders will be created.
    custom_objects: Optional dictionary mapping string names to custom classes
        or functions.
    compile_clone: Boolean, whether to compile the model clone (default
        `True`).
    in_place_reset: Boolean, whether to reset the model in place. Only used if
        the model is a subclassed model. In the case of a subclassed model,
        this argument must be set to `True` (default `False`). To restore the
        original model, use the function
        `in_place_subclassed_model_state_restoration(model)`.
    optimizer_iterations: An iterations variable that will be incremented by
        the optimizer if the clone is compiled. This argument is used when a
        Keras model is cloned into an Estimator model function, because
        Estimators create their own global step variable.
    optimizer_config: Optimizer config dictionary or list of dictionaries
        returned from `get_config()`. This argument should be defined if
        `clone_and_build_model` is called in a different graph or session from
        the original model, and the optimizer is an instance of `OptimizerV2`.

  Returns:
    Clone of the model.
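
  Example (a sketch; assumes `model` is a compiled `tf.keras.Model` whose
  optimizer is an `OptimizerV2` instance):

  ```python
  clone = clone_and_build_model(
      model,
      compile_clone=True,
      optimizer_config=model.optimizer.get_config())
  ```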

  Raises:
    ValueError: Cloning fails in the following cases:
        - cloning a subclassed model with `in_place_reset` set to False.
        - compiling the clone when the original model has not been compiled.
  """
  # Grab optimizer now, as we reset-in-place for subclassed models, but
  # want to maintain access to the original optimizer.
  orig_optimizer = model.optimizer
  if compile_clone and not orig_optimizer:
    raise ValueError(
        'Error when cloning model: compile_clone was set to True, but the '
        'original model has not been compiled.')

  if compile_clone:
    compile_args = model._get_compile_args()  # pylint: disable=protected-access
    # Allows this method to be robust to switching graph and eager classes.
    model._get_compile_args = lambda: compile_args

  with CustomObjectScope(custom_objects or {}):
    if model._is_graph_network:
      clone = clone_model(model, input_tensors=input_tensors)
    elif isinstance(model, Sequential):
      clone = clone_model(model, input_tensors=input_tensors)
      if (not clone._is_graph_network and
          model._build_input_shape is not None):
        if ops.executing_eagerly_outside_functions():
          clone.build(model._build_input_shape)
        else:
          clone._set_inputs(
              K.placeholder(
                  model._build_input_shape, dtype=model.inputs[0].dtype))
    else:
      try:
        # Prefer cloning the model if serialization/deserialization logic is
        # implemented for the subclassed model.
        clone = model.__class__.from_config(model.get_config())
      except NotImplementedError:
        logging.warning('This model is a subclassed model. Please implement '
                        '`get_config` and `from_config` to better support '
                        'cloning the model.')
        if not in_place_reset:
          raise ValueError(
              'This model is a subclassed model. '
              'Such a model cannot be cloned, but there is a workaround where '
              'the model is reset in-place. To use this, please set the '
              'argument `in_place_reset` to `True`. This will reset the '
              'attributes in the original model. To restore the attributes, '
              'call `in_place_subclassed_model_state_restoration(model)`.')
        clone = model
        _in_place_subclassed_model_reset(clone)
      if input_tensors is not None:
        if isinstance(input_tensors, (list, tuple)) and len(input_tensors) == 1:
          input_tensors = input_tensors[0]
        clone._set_inputs(input_tensors)
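
  # Recompile the clone with copies of the original compile arguments, so the
  # clone's optimizer and `Metric` objects are independent of the original
  # model's.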
  if compile_clone:
    if isinstance(orig_optimizer, optimizer_v1.TFOptimizer):
      optimizer = optimizer_v1.TFOptimizer(
          orig_optimizer.optimizer, optimizer_iterations)
      K.track_tf_optimizer(optimizer)
    else:
      if not isinstance(orig_optimizer, (tuple, list)):
        orig_optimizer = [orig_optimizer]
      if optimizer_config is None:
        optimizer = [
            opt.__class__.from_config(opt.get_config())
            for opt in orig_optimizer
        ]
      elif isinstance(optimizer_config, dict):
        optimizer = [orig_optimizer[0].__class__.from_config(optimizer_config)]
      else:
        # optimizer_config is a list of dicts, in the same order as
        # orig_optimizer.
        optimizer = [
            opt.__class__.from_config(opt_config)
            for (opt, opt_config) in zip(orig_optimizer, optimizer_config)
        ]
      if optimizer_iterations is not None:
        for opt in optimizer:
          opt.iterations = optimizer_iterations

    if len(optimizer) == 1:
      optimizer = optimizer[0]

    compile_args['optimizer'] = optimizer
    if target_tensors is not None:
      compile_args['target_tensors'] = target_tensors
    # Ensure Metric objects in new model are separate from existing model.
    compile_args['metrics'] = metrics_module.clone_metrics(
        compile_args['metrics'])
    compile_args['weighted_metrics'] = metrics_module.clone_metrics(
        compile_args['weighted_metrics'])
    clone.compile(**compile_args)

  return clone