1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=protected-access 16"""Code for model cloning, plus model-related API entries. 17""" 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22from tensorflow.python.keras import backend as K 23from tensorflow.python.keras import metrics as metrics_module 24from tensorflow.python.keras import optimizers 25from tensorflow.python.keras.engine import sequential 26from tensorflow.python.keras.engine import training 27from tensorflow.python.keras.engine.base_layer import Layer 28from tensorflow.python.keras.engine.input_layer import Input 29from tensorflow.python.keras.engine.input_layer import InputLayer 30from tensorflow.python.keras.engine.network import Network 31from tensorflow.python.keras.saving import hdf5_format 32from tensorflow.python.keras.saving import model_config 33from tensorflow.python.keras.utils import generic_utils 34from tensorflow.python.keras.utils.generic_utils import CustomObjectScope 35from tensorflow.python.util import nest 36from tensorflow.python.util.tf_export import keras_export 37 38 39# API entries importable from `keras.models`: 40Model = training.Model # pylint: disable=invalid-name 41Sequential = sequential.Sequential # pylint: disable=invalid-name 42save_model = hdf5_format.save_model 43load_model = hdf5_format.load_model 44model_from_config = model_config.model_from_config 45model_from_yaml = model_config.model_from_yaml 46model_from_json = model_config.model_from_json 47 48 49def _clone_layer(layer): 50 return layer.__class__.from_config(layer.get_config()) 51 52 53def _clone_functional_model(model, input_tensors=None, share_weights=False): 54 """Clone a functional `Model` instance. 55 56 Model cloning is similar to calling a model on new inputs, 57 except that it creates new layers (and thus new weights) instead 58 of sharing the weights of the existing layers. 59 60 Arguments: 61 model: Instance of `Model`. 62 input_tensors: optional list of input tensors 63 to build the model upon. If not provided, 64 placeholders will be created. 65 share_weights: flag to enable sharing of non-input layers between the 66 cloned and original model. Note this still clones the input layers. 67 This is required when we create a per-replica copy of the model with 68 distribution strategy; we want the weights to be shared but still 69 feed inputs separately so we create new input layers. 70 71 Returns: 72 An instance of `Model` reproducing the behavior 73 of the original model, on top of new inputs tensors, 74 using newly instantiated weights. 75 76 Raises: 77 ValueError: in case of invalid `model` argument value. 78 """ 79 if not isinstance(model, Model): 80 raise ValueError('Expected `model` argument ' 81 'to be a `Model` instance, got ', model) 82 if isinstance(model, Sequential): 83 raise ValueError('Expected `model` argument ' 84 'to be a functional `Model` instance, ' 85 'got a `Sequential` instance instead:', model) 86 87 layer_map = {} # Cache for created layers. 88 tensor_map = {} # Map {reference_tensor: corresponding_tensor} 89 if input_tensors is None: 90 # Create placeholders to build the model on top of. 91 input_tensors = [] 92 for layer in model._input_layers: 93 input_tensor = Input( 94 batch_shape=layer._batch_input_shape, 95 dtype=layer.dtype, 96 sparse=layer.sparse, 97 name=layer.name) 98 input_tensors.append(input_tensor) 99 # Cache newly created input layer. 100 newly_created_input_layer = input_tensor._keras_history[0] 101 layer_map[layer] = newly_created_input_layer 102 else: 103 # Make sure that all input tensors come from a Keras layer. 104 # If tensor comes from an input layer: cache the input layer. 105 input_tensors = nest.flatten(input_tensors) 106 input_tensors_ = [] 107 for i in range(len(input_tensors)): 108 input_tensor = input_tensors[i] 109 if not K.is_keras_tensor(input_tensor): 110 original_input_layer = model._input_layers[i] 111 name = original_input_layer.name 112 input_tensor = Input(tensor=input_tensor, 113 name='input_wrapper_for_' + name) 114 115 input_tensors_.append(input_tensor) 116 # Cache newly created input layer. 117 newly_created_input_layer = input_tensor._keras_history[0] 118 layer_map[original_input_layer] = newly_created_input_layer 119 else: 120 input_tensors_.append(input_tensor) 121 input_tensors = input_tensors_ 122 123 for x, y in zip(model.inputs, input_tensors): 124 tensor_map[x] = y 125 126 # Iterated over every node in the reference model, in depth order. 127 depth_keys = list(model._nodes_by_depth.keys()) 128 depth_keys.sort(reverse=True) 129 for depth in depth_keys: 130 nodes = model._nodes_by_depth[depth] 131 for node in nodes: 132 # Recover the corresponding layer. 133 layer = node.outbound_layer 134 135 # Get or create layer. 136 if layer not in layer_map: 137 if not share_weights: 138 # Clone layer. 139 new_layer = _clone_layer(layer) 140 layer_map[layer] = new_layer 141 layer = new_layer 142 else: 143 # Reuse previously cloned layer. 144 layer = layer_map[layer] 145 # Don't call InputLayer multiple times. 146 if isinstance(layer, InputLayer): 147 continue 148 149 # If all previous input tensors are available in tensor_map, 150 # then call node.inbound_layer on them. 151 if all( 152 tensor in tensor_map for tensor in nest.flatten(node.input_tensors)): 153 computed_tensors = nest.map_structure(lambda t: tensor_map[t], 154 node.input_tensors) 155 # Call layer. 156 kwargs = node.arguments or {} 157 output_tensors = layer(computed_tensors, **kwargs) 158 159 for x, y in zip( 160 nest.flatten(node.output_tensors), nest.flatten(output_tensors)): 161 tensor_map[x] = y 162 163 # Check that we did compute the model outputs, 164 # then instantiate a new model from inputs and outputs. 165 output_tensors = [] 166 for x in model.outputs: 167 assert x in tensor_map, 'Could not compute output ' + str(x) 168 output_tensors.append(tensor_map[x]) 169 170 input_tensors = nest.pack_sequence_as(model._nested_inputs, input_tensors) 171 output_tensors = nest.pack_sequence_as(model._nested_outputs, output_tensors) 172 return Model(input_tensors, output_tensors, name=model.name) 173 174 175def _clone_sequential_model(model, input_tensors=None, share_weights=False): 176 """Clone a `Sequential` model instance. 177 178 Model cloning is similar to calling a model on new inputs, 179 except that it creates new layers (and thus new weights) instead 180 of sharing the weights of the existing layers. 181 182 Arguments: 183 model: Instance of `Sequential`. 184 input_tensors: optional list of input tensors 185 to build the model upon. If not provided, 186 placeholders will be created. 187 share_weights: flag to enable sharing of non-input layers between the 188 cloned and original model. Note this still clones the input layers. 189 This is required when we create a per-replica copy of the model with 190 distribution strategy; we want the weights to be shared but still 191 feed inputs separately so we create new input layers. 192 193 Returns: 194 An instance of `Sequential` reproducing the behavior 195 of the original model, on top of new inputs tensors, 196 using newly instantiated weights. 197 198 Raises: 199 ValueError: in case of invalid `model` argument value. 200 """ 201 if not isinstance(model, Sequential): 202 raise ValueError('Expected `model` argument ' 203 'to be a `Sequential` model instance, ' 204 'but got:', model) 205 206 # Use model._layers to ensure that all layers are cloned. The model's layers 207 # property will exclude the initial InputLayer (if it exists) in the model, 208 # resulting in a different Sequential model structure. 209 if input_tensors is None: 210 if share_weights: 211 # In preserve weights case we still want the input layers to be cloned. 212 layers = [] 213 for layer in model._layers: 214 if isinstance(layer, InputLayer): 215 layers.append(_clone_layer(layer)) 216 else: 217 layers.append(layer) 218 else: 219 layers = [_clone_layer(layer) for layer in model._layers] 220 return Sequential(layers=layers, name=model.name) 221 else: 222 # If input tensors are provided, the original model's InputLayer is 223 # overwritten with a different InputLayer. 224 layers = [ 225 layer for layer in model._layers if not isinstance(layer, InputLayer)] 226 if not share_weights: 227 layers = [_clone_layer(layer) for layer in layers] 228 if len(generic_utils.to_list(input_tensors)) != 1: 229 raise ValueError('To clone a `Sequential` model, we expect ' 230 ' at most one tensor ' 231 'as part of `input_tensors`.') 232 233 if isinstance(input_tensors, tuple): 234 input_tensors = list(input_tensors) 235 x = generic_utils.to_list(input_tensors)[0] 236 if K.is_keras_tensor(x): 237 origin_layer = x._keras_history[0] 238 if isinstance(origin_layer, InputLayer): 239 return Sequential(layers=[origin_layer] + layers, name=model.name) 240 else: 241 raise ValueError('Cannot clone a `Sequential` model on top ' 242 'of a tensor that comes from a Keras layer ' 243 'other than an `InputLayer`. ' 244 'Use the functional API instead.') 245 input_tensor = Input(tensor=x, name='input_wrapper_for_' + str(x.name)) 246 input_layer = input_tensor._keras_history[0] 247 return Sequential(layers=[input_layer] + layers, name=model.name) 248 249 250@keras_export('keras.models.clone_model') 251def clone_model(model, input_tensors=None): 252 """Clone any `Model` instance. 253 254 Model cloning is similar to calling a model on new inputs, 255 except that it creates new layers (and thus new weights) instead 256 of sharing the weights of the existing layers. 257 258 Arguments: 259 model: Instance of `Model` 260 (could be a functional model or a Sequential model). 261 input_tensors: optional list of input tensors or InputLayer objects 262 to build the model upon. If not provided, 263 placeholders will be created. 264 265 Returns: 266 An instance of `Model` reproducing the behavior 267 of the original model, on top of new inputs tensors, 268 using newly instantiated weights. 269 270 Raises: 271 ValueError: in case of invalid `model` argument value. 272 """ 273 if isinstance(model, Sequential): 274 return _clone_sequential_model(model, input_tensors=input_tensors) 275 else: 276 return _clone_functional_model(model, input_tensors=input_tensors) 277 278 279# "Clone" a subclassed model by reseting all of the attributes. 280def _in_place_subclassed_model_reset(model): 281 """Substitute for model cloning that works for subclassed models. 282 283 Subclassed models cannot be cloned because their topology is not serializable. 284 To "instantiate" an identical model in a new TF graph, we reuse the original 285 model object, but we clear its state. 286 287 After calling this function on a model instance, you can use the model 288 instance as if it were a model clone (in particular you can use it in a new 289 graph). 290 291 This method clears the state of the input model. It is thus destructive. 292 However the original state can be restored fully by calling 293 `_in_place_subclassed_model_state_restoration`. 294 295 Args: 296 model: Instance of a Keras model created via subclassing. 297 298 Raises: 299 ValueError: In case the model uses a subclassed model as inner layer. 300 """ 301 assert not model._is_graph_network # Only makes sense for subclassed networks 302 # Retrieve all layers tracked by the model as well as their attribute names 303 attributes_cache = {} 304 for name in dir(model): 305 try: 306 value = getattr(model, name) 307 except (AttributeError, ValueError, TypeError): 308 continue 309 if isinstance(value, Layer): 310 attributes_cache[name] = value 311 assert value in model.layers 312 if hasattr(value, 'layers') and value.layers: 313 raise ValueError('We do not support the use of nested layers ' 314 'in `model_to_estimator` at this time. Found nested ' 315 'layer: %s' % value) 316 elif isinstance( 317 value, (list, tuple)) and name not in ('layers', '_layers', 'metrics', 318 '_compile_metric_functions', 319 '_output_loss_metrics'): 320 # Handle case: list/tuple of layers (also tracked by the Network API). 321 if value and all(isinstance(val, Layer) for val in value): 322 raise ValueError('We do not support the use of list-of-layers ' 323 'attributes in subclassed models used with ' 324 '`model_to_estimator` at this time. Found list ' 325 'model: %s' % name) 326 327 # Replace layers on the model with fresh layers 328 layers_to_names = {value: key for key, value in attributes_cache.items()} 329 original_layers = model._layers[:] 330 setattr_tracking = model._setattr_tracking 331 model._setattr_tracking = False 332 model._layers = [] 333 for layer in original_layers: # We preserve layer order. 334 config = layer.get_config() 335 # This will not work for nested subclassed models used as layers. 336 # This would be theoretically possible to support, but would add complexity. 337 # Only do it if users complain. 338 if isinstance(layer, Network) and not layer._is_graph_network: 339 raise ValueError('We do not support the use of nested subclassed models ' 340 'in `model_to_estimator` at this time. Found nested ' 341 'model: %s' % layer) 342 fresh_layer = layer.__class__.from_config(config) 343 name = layers_to_names[layer] 344 setattr(model, name, fresh_layer) 345 model._layers.append(fresh_layer) 346 347 # Cache original model build attributes (in addition to layers) 348 if (not hasattr(model, '_original_attributes_cache') or 349 model._original_attributes_cache is None): 350 if model.built: 351 attributes_to_cache = [ 352 'inputs', 353 'outputs', 354 '_feed_outputs', 355 '_feed_output_names', 356 '_feed_output_shapes', 357 '_feed_loss_fns', 358 'loss_weights_list', 359 'targets', 360 '_feed_targets', 361 'sample_weight_modes', 362 'total_loss', 363 'sample_weights', 364 '_feed_sample_weights', 365 'train_function', 366 'test_function', 367 'predict_function', 368 '_collected_trainable_weights', 369 '_feed_inputs', 370 '_feed_input_names', 371 '_feed_input_shapes', 372 'optimizer', 373 ] 374 for name in attributes_to_cache: 375 attributes_cache[name] = getattr(model, name) 376 model._original_attributes_cache = attributes_cache 377 _reset_build_compile_trackers(model) 378 model._setattr_tracking = setattr_tracking 379 380 381def _reset_build_compile_trackers(model): 382 """Reset state trackers for model. 383 384 Note that we do not actually zero out attributes such as optimizer, 385 but instead rely on the expectation that all of the attrs will be 386 over-written on calling build/compile/etc. This is somewhat fragile, 387 insofar as we check elsewhere for the presence of these attributes as 388 evidence of having been built/compiled/etc. Pending a better way to do this, 389 we reset key attributes here to allow building and compiling. 390 391 Args: 392 model: the model that is being reset 393 """ 394 # Reset build state 395 model.built = False 396 model.inputs = None 397 model.outputs = None 398 # Reset compile state 399 model._is_compiled = False # pylint:disable=protected-access 400 model.optimizer = None 401 402 403def in_place_subclassed_model_state_restoration(model): 404 """Restores the original state of a model after it was "reset". 405 406 This undoes this action of `_in_place_subclassed_model_reset`, which is called 407 in `clone_and_build_model` if `in_place_reset` is set to True. 408 409 Args: 410 model: Instance of a Keras model created via subclassing, on which 411 `_in_place_subclassed_model_reset` was previously called. 412 """ 413 assert not model._is_graph_network 414 # Restore layers and build attributes 415 if (hasattr(model, '_original_attributes_cache') and 416 model._original_attributes_cache is not None): 417 # Models have sticky attribute assignment, so we want to be careful to add 418 # back the previous attributes and track Layers by their original names 419 # without adding dependencies on "utility" attributes which Models exempt 420 # when they're constructed. 421 setattr_tracking = model._setattr_tracking 422 model._setattr_tracking = False 423 model._layers = [] 424 for name, value in model._original_attributes_cache.items(): 425 setattr(model, name, value) 426 if isinstance(value, Layer): 427 model._layers.append(value) 428 model._original_attributes_cache = None 429 model._setattr_tracking = setattr_tracking 430 else: 431 # Restore to the state of a never-called model. 432 _reset_build_compile_trackers(model) 433 434 435def clone_and_build_model( 436 model, input_tensors=None, target_tensors=None, custom_objects=None, 437 compile_clone=True, in_place_reset=False, optimizer_iterations=None): 438 """Clone a `Model` and build/compile it with the same settings used before. 439 440 This function can be be run in the same graph or in a separate graph from the 441 model. When using a separate graph, `in_place_reset` must be `False`. 442 443 Note that, currently, the clone produced from this function may not work with 444 TPU DistributionStrategy. Try at your own risk. 445 446 Args: 447 model: `tf.keras.Model` object. Can be Functional, Sequential, or 448 sub-classed. 449 input_tensors: Optional list of input tensors to build the model upon. If 450 not provided, placeholders will be created. 451 target_tensors: Optional list of target tensors for compiling the model. If 452 not provided, placeholders will be created. 453 custom_objects: Optional dictionary mapping string names to custom classes 454 or functions. 455 compile_clone: Boolean, whether to compile model clone (default `True`). 456 in_place_reset: Boolean, whether to reset the model in place. Only used if 457 the model is a subclassed model. In the case of a subclassed model, 458 this argument must be set to `True` (default `False`). To restore the 459 original model, use the function 460 `in_place_subclassed_model_state_restoration(model)`. 461 optimizer_iterations: An iterations variable that will be incremented by the 462 optimizer if the clone is compiled. This argument is used when a Keras 463 model is cloned into an Estimator model function, because Estimators 464 create their own global step variable. 465 466 Returns: 467 Clone of the model. 468 469 Raises: 470 ValueError: Cloning fails in the following cases 471 - cloning a subclassed model with `in_place_reset` set to False. 472 - compiling the clone when the original model has not been compiled. 473 """ 474 # Grab optimizer now, as we reset-in-place for subclassed models, but 475 # want to maintain access to the original optimizer. 476 orig_optimizer = model.optimizer 477 if compile_clone and not orig_optimizer: 478 raise ValueError( 479 'Error when cloning model: compile_clone was set to True, but the ' 480 'original model has not been compiled.') 481 482 if model._is_graph_network or isinstance(model, Sequential): 483 if custom_objects: 484 with CustomObjectScope(custom_objects): 485 clone = clone_model(model, input_tensors=input_tensors) 486 else: 487 clone = clone_model(model, input_tensors=input_tensors) 488 489 if all([isinstance(clone, Sequential), 490 not clone._is_graph_network, 491 getattr(model, '_build_input_shape', None) is not None]): 492 # Set model inputs to build the model and add input/output properties. 493 # TODO(kathywu): Add multiple placeholders to handle edge case where 494 # sequential model has multiple inputs. 495 clone._set_inputs( 496 K.placeholder(model._build_input_shape, dtype=model.inputs[0].dtype)) 497 else: 498 if not in_place_reset: 499 raise ValueError( 500 'This model is a subclassed model. ' 501 'Such a model cannot be cloned, but there is a workaround where ' 502 'the model is reset in-place. To use this, please set the argument ' 503 '`in_place_reset` to `True`. This will reset the attributes in the ' 504 'original model. To restore the attributes, call ' 505 '`in_place_subclassed_model_state_restoration(model)`.') 506 clone = model 507 _in_place_subclassed_model_reset(clone) 508 if input_tensors is not None: 509 if isinstance(input_tensors, (list, tuple)) and len(input_tensors) == 1: 510 input_tensors = input_tensors[0] 511 clone._set_inputs(input_tensors) 512 513 if compile_clone: 514 if isinstance(orig_optimizer, optimizers.TFOptimizer): 515 optimizer = optimizers.TFOptimizer( 516 orig_optimizer.optimizer, optimizer_iterations) 517 K.track_tf_optimizer(optimizer) 518 else: 519 optimizer_config = orig_optimizer.get_config() 520 optimizer = orig_optimizer.__class__.from_config(optimizer_config) 521 if optimizer_iterations is not None: 522 optimizer.iterations = optimizer_iterations 523 524 clone.compile( 525 optimizer, 526 model.loss, 527 metrics=metrics_module.clone_metrics(model._compile_metrics), 528 loss_weights=model.loss_weights, 529 sample_weight_mode=model.sample_weight_mode, 530 weighted_metrics=metrics_module.clone_metrics( 531 model._compile_weighted_metrics), 532 target_tensors=target_tensors) 533 534 return clone 535