# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""A `Network` is a way to compose layers: the topological form of a `Model`.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import json
import os

from six.moves import zip  # pylint: disable=redefined-builtin

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.mixed_precision.experimental import policy
from tensorflow.python.keras.saving import hdf5_format
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.util import nest
from tensorflow.python.util import serialization
from tensorflow.python.util import tf_inspect


# pylint: disable=g-import-not-at-top
try:
  import h5py
except ImportError:
  h5py = None

try:
  import yaml
except ImportError:
  yaml = None
# pylint: enable=g-import-not-at-top


class Network(base_layer.Layer):
  """A `Network` is a composition of layers.

  `Network` is the topological form of a "model". A `Model`
  is simply a `Network` with added training routines.

  Two types of `Networks` exist: Graph Networks and Subclassed Networks. Graph
  networks are used in the Keras Functional and Sequential APIs. Subclassed
  networks are used when a user subclasses the `Model` class. In general,
  more Keras features are supported with Graph Networks than with Subclassed
  Networks, specifically:

  - Model cloning (`keras.models.clone`)
  - Serialization (`model.get_config()/from_config`,
    `model.to_json()/to_yaml()`)
  - Whole-model saving (`model.save()`)

  A Graph Network can be instantiated by passing two arguments to `__init__`.
  The first argument is the `keras.Input` Tensors that represent the inputs
  to the Network. The second argument specifies the output Tensors that
  represent the outputs of this Network. Both arguments can be a nested
  structure of Tensors.

  Example:

  ```
  inputs = {'x1': keras.Input(shape=(10,)), 'x2': keras.Input(shape=(1,))}
  t = keras.layers.Dense(1, activation='relu')(inputs['x1'])
  outputs = keras.layers.Add()([t, inputs['x2']])
  network = Network(inputs, outputs)
  ```

  A Graph Network constructed using the Functional API can also include raw
  TensorFlow functions, with the exception of functions that create Variables
  or assign ops.

  Example:

  ```
  inputs = keras.Input(shape=(10,))
  x = keras.layers.Dense(1)(inputs)
  outputs = tf.nn.relu(x)
  network = Network(inputs, outputs)
  ```

  Subclassed Networks can be instantiated via `name` and (optional) `dynamic`
  keyword arguments. Subclassed Networks keep track of their Layers, and their
  `call` method can be overridden. Subclassed Networks are typically created
  indirectly, by subclassing the `Model` class.

  Example:

  ```
  class MyModel(keras.Model):
    def __init__(self):
      super(MyModel, self).__init__(name='my_model', dynamic=False)

      self.layer1 = keras.layers.Dense(10, activation='relu')

    def call(self, inputs):
      return self.layer1(inputs)
  ```
  """

  def __init__(self, *args, **kwargs):  # pylint: disable=super-init-not-called
    # Signature detection
    if (len(args) == 2 or
        len(args) == 1 and 'outputs' in kwargs or
        'inputs' in kwargs and 'outputs' in kwargs):
      # Graph network
      self._init_graph_network(*args, **kwargs)
    else:
      # Subclassed network
      self._init_subclassed_network(**kwargs)

  # Several Network methods have "no_automatic_dependency_tracking"
  # annotations. Since Network does automatic dependency tracking on attribute
  # assignment, including for common data structures such as lists, by default
  # we'd have quite a few empty dependencies which users don't care about (or
  # would need some way to ignore dependencies automatically, which is
  # confusing when applied to user code). Some attributes, such as _layers,
  # would cause structural issues (_layers being the place where Layers
  # assigned to tracked attributes are stored).
  #
  # Aside from these aesthetic and structural issues, useless dependencies on
  # empty lists shouldn't cause issues; adding or removing them will not break
  # checkpoints, but may cause "all Python objects matched" assertions to fail
  # (in which case less strict assertions may be substituted if necessary).
  @trackable.no_automatic_dependency_tracking
  def _base_init(self, name=None):
    # The following are implemented as property functions:
    # self.trainable_weights
    # self.non_trainable_weights
    # self.input_spec
    # self.losses
    # self.updates

    self._init_set_name(name, zero_based=True)
    self._activity_regularizer = None
    # This acts just like the `trainable` attribute of any layer instance.
    # It does not affect users of the underlying layers, only users of the
    # Network instance.
    self.trainable = True
    self._is_compiled = False
    self._expects_training_arg = False

    # This is True for Sequential networks and Functional networks.
    self._compute_output_and_mask_jointly = False

    self.supports_masking = False
    if not hasattr(self, 'optimizer'):
      # Don't reset optimizer if already set.
      self.optimizer = None

    # Private attributes to implement compatibility with Layer.
    self._trainable_weights = []
    self._non_trainable_weights = []
    self._updates = []  # Used in symbolic mode only.
    self._losses = []
    self._eager_losses = []
    # A list of metric instances corresponding to the symbolic metric tensors
    # added using the `add_metric` API.
    self._metrics = []
    # A dictionary that maps metric names to metric result tensors.
    self._metrics_tensors = {}
    self._scope = None  # Never used.
    self._reuse = None  # Never used.
    if context.executing_eagerly():
      self._graph = None
    else:
      self._graph = ops.get_default_graph()  # Used in symbolic mode only.
    # A Network does not create weights of its own, thus has no dtype.
    self._dtype = None

    # All layers in order of horizontal graph traversal.
    # Entries are unique. Includes input and output layers.
    self._layers = []

    # Used in symbolic mode only, only in conjunction with graph-networks
    self._outbound_nodes = []
    self._inbound_nodes = []

    self._trackable_saver = (
        trackable_utils.saver_with_op_caching(self))

    # Networks do not need to do any casting of inputs or variables, because
    # each of its layers will handle casting through the layer's own
    # implementation. Therefore networks use the 'infer' policy, which does no
    # casting.
    self._mixed_precision_policy = policy.Policy('infer')

  @trackable.no_automatic_dependency_tracking
  def _init_graph_network(self, inputs, outputs, name=None):
    self._call_convention = (base_layer_utils
                             .CallConvention.EXPLICIT_INPUTS_ARGUMENT)
    # Normalize and set self.inputs, self.outputs.
    if isinstance(inputs, list) and len(nest.flatten(inputs)) == 1:
      inputs = inputs[0]
    if isinstance(outputs, list) and len(nest.flatten(outputs)) == 1:
      outputs = outputs[0]
    self._nested_outputs = outputs
    self._nested_inputs = inputs
    self.inputs = nest.flatten(inputs)
    self.outputs = nest.flatten(outputs)

    if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
      base_layer_utils.create_keras_history(self._nested_outputs)

    self._base_init(name=name)
    self._validate_graph_inputs_and_outputs()

    self._compute_previous_mask = (
        'mask' in tf_inspect.getfullargspec(self.call).args or
        hasattr(self, 'compute_mask'))
    # A Network does not create weights of its own, thus it is already
    # built.
    self.built = True
    self._compute_output_and_mask_jointly = True
    self._is_graph_network = True
    self._dynamic = False
    # `_expects_training_arg` is True since the `training` argument is always
    # present in the signature of the `call` method of a graph network.
    self._expects_training_arg = True

    self._input_layers = []
    self._output_layers = []
    self._input_coordinates = []
    self._output_coordinates = []

    # This is for performance optimization when calling the Network on new
    # inputs. Every time the Network is called on a set of input tensors,
    # we compute the output tensors, output masks and output shapes in one
    # pass, then cache them here. When any of these outputs is queried later,
    # we retrieve it from there instead of recomputing it.
    self._output_mask_cache = {}
    self._output_tensor_cache = {}
    self._output_shape_cache = {}

    # Build self._output_layers:
    for x in self.outputs:
      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
      self._output_layers.append(layer)
      self._output_coordinates.append((layer, node_index, tensor_index))

    # Build self._input_layers:
    for x in self.inputs:
      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
      # It's supposed to be an input layer, so only one node
      # and one tensor output.
      assert node_index == 0
      assert tensor_index == 0
      self._input_layers.append(layer)
      self._input_coordinates.append((layer, node_index, tensor_index))

    # Keep track of the network's nodes and layers.
    nodes, nodes_by_depth, layers, layers_by_depth = _map_graph_network(
        self.inputs, self.outputs)
    self._network_nodes = nodes
    self._nodes_by_depth = nodes_by_depth
    self._layers = layers
    self._layers_by_depth = layers_by_depth
    self._layer_call_argspecs = {}
    for layer in self._layers:
      self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)

    self._track_layers(layers)

    # Create the node linking internal inputs to internal outputs.
    base_layer.Node(
        outbound_layer=self,
        inbound_layers=[],
        node_indices=[],
        tensor_indices=[],
        input_tensors=self._nested_inputs,
        output_tensors=self._nested_outputs)

    # Build self.input_names and self.output_names.
    self.input_names = []
    self.output_names = []
    self._feed_input_names = []
    self._feed_inputs = []
    self._feed_input_shapes = []
    for i, layer in enumerate(self._input_layers):
      self.input_names.append(layer.name)
      if layer.is_placeholder:
        self._feed_input_names.append(layer.name)
        self._feed_input_shapes.append(backend.int_shape(self.inputs[i]))
        self._feed_inputs.append(layer.input)
    for layer in self._output_layers:
      self.output_names.append(layer.name)

  @trackable.no_automatic_dependency_tracking
  def _init_subclassed_network(self, name=None, dynamic=False):
    self._base_init(name=name)
    self._is_graph_network = False
    self._dynamic = dynamic
    call_argspec = tf_inspect.getfullargspec(self.call)
    if 'training' in call_argspec.args:
      self._expects_training_arg = True
    else:
      self._expects_training_arg = False
    self._call_convention = self._determine_call_convention(call_argspec)
    self.outputs = []
    self.inputs = []
    self.built = False

  @property
  def dynamic(self):
    if self._is_graph_network:
      return any(layer.dynamic for layer in self.layers)
    return self._dynamic or any(layer.dynamic for layer in self.layers)

  def _determine_call_convention(self, call_argspec):
    """Decides how `self.call()` is invoked. See `CallConvention`."""
    if call_argspec.varargs:
      may_take_single_argument = False
    else:
      try:
        # Note: tf_inspect doesn't raise a TypeError when regular inspect
        # would, so we need to keep in mind that "getcallargs" may have
        # returned something even though we under-specified positional
        # arguments.
        all_args = tf_inspect.getcallargs(self.call, None)
        self_args = set()
        for arg_name, obj in all_args.items():
          if obj is self:
            self_args.add(arg_name)
        may_take_single_argument = True
      except TypeError:
        may_take_single_argument = False
    if may_take_single_argument:
      # A single positional argument (plus "self") is considered equivalent to
      # an "inputs" argument.
      all_positional_args = len(call_argspec.args)
      if call_argspec.defaults is not None:
        all_positional_args -= len(call_argspec.defaults)
      non_self_positional_args = all_positional_args
      for positional_arg_name in call_argspec.args[:all_positional_args]:
        if positional_arg_name in self_args:
          non_self_positional_args -= 1
      if non_self_positional_args == 1:
        if 'inputs' in call_argspec.args[all_positional_args:]:
          raise TypeError(
              "Model.call() takes a single positional argument (to which "
              "inputs are passed by convention) and a separate 'inputs' "
              "argument. Unable to determine which arguments are inputs.")
        return base_layer_utils.CallConvention.SINGLE_POSITIONAL_ARGUMENT
    if 'inputs' in call_argspec.args:
      return base_layer_utils.CallConvention.EXPLICIT_INPUTS_ARGUMENT
    else:
      return base_layer_utils.CallConvention.POSITIONAL_ARGUMENTS_ARE_INPUTS

  def _track_layers(self, layers):
    """Add Trackable dependencies on a list of Layers."""
    weight_layer_index = 0
    for layer_index, layer in enumerate(layers):
      if layer.weights:
        # Keep a separate index for layers which have weights. This allows
        # users to insert Layers without weights anywhere in the network
        # without breaking checkpoints.
        self._track_trackable(
            layer, name='layer_with_weights-%d' % weight_layer_index,
            overwrite=True)
        weight_layer_index += 1
      # Even if it doesn't have weights, we should still track everything in
      # case it has/will have Trackable dependencies.
      self._track_trackable(
          layer, name='layer-%d' % layer_index, overwrite=True)

  def __setattr__(self, name, value):
    if not getattr(self, '_setattr_tracking', True):
      super(Network, self).__setattr__(name, value)
      return

    if all(
        isinstance(v, (base_layer.Layer,
                       data_structures.TrackableDataStructure)) or
        trackable_layer_utils.has_weights(v) for v in nest.flatten(value)):
      try:
        self._is_graph_network
      except AttributeError:
        raise RuntimeError('It looks like you are subclassing `Model` and you '
                           'forgot to call `super(YourClass, self).__init__()`.'
                           ' Always start with this line.')

    super(Network, self).__setattr__(name, value)

    # Keep track of metric instances created in subclassed model/layer.
    # We do this so that we can maintain the correct order of metrics by adding
    # the instance to the `metrics` list as soon as it is created.
    from tensorflow.python.keras import metrics as metrics_module  # pylint: disable=g-import-not-at-top
    if isinstance(value, metrics_module.Metric):
      self._metrics.append(value)

  @property
  def stateful(self):
    return any((hasattr(layer, 'stateful') and layer.stateful)
               for layer in self.layers)

  def reset_states(self):
    for layer in self.layers:
      if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False):
        layer.reset_states()

  @property
  def state_updates(self):
    """Returns the `updates` from all layers that are stateful.

    This is useful for separating training updates and
    state updates, e.g. when we need to update a layer's internal state
    during prediction.

    Returns:
      A list of update ops.
    """
    state_updates = []
    for layer in self.layers:
      if getattr(layer, 'stateful', False):
        if hasattr(layer, 'updates'):
          state_updates += layer.updates
    return state_updates

  def get_weights(self):
    """Retrieves the weights of the model.

    Returns:
      A flat list of Numpy arrays.
    """
    weights = []
    for layer in self.layers:
      weights += layer.weights
    return backend.batch_get_value(weights)

  def set_weights(self, weights):
    """Sets the weights of the model.

    Arguments:
      weights: A list of Numpy arrays with shapes and types matching
        the output of `model.get_weights()`.
    """
    tuples = []
    for layer in self.layers:
      num_param = len(layer.weights)
      layer_weights = weights[:num_param]
      for sw, w in zip(layer.weights, layer_weights):
        tuples.append((sw, w))
      weights = weights[num_param:]
    backend.batch_set_value(tuples)

  def compute_mask(self, inputs, mask):
    if not self._is_graph_network:
      return None

    # TODO(omalleyt): b/123540974 This function is not really safe to call
    # by itself because it will duplicate any updates and losses in graph
    # mode by `call`ing the Layers again.
    output_tensors = self._run_internal_graph(inputs, mask=mask)
    return nest.map_structure(lambda t: t._keras_mask, output_tensors)

  @property
  def layers(self):
    return trackable_layer_utils.filter_empty_layer_containers(
        self._layers)

  def get_layer(self, name=None, index=None):
    """Retrieves a layer based on either its name (unique) or index.

    If `name` and `index` are both provided, `index` will take precedence.
    Indices are based on order of horizontal graph traversal (bottom-up).

    Arguments:
      name: String, name of layer.
      index: Integer, index of layer.
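
    Example (illustrative only; the layer name is arbitrary):

    ```python
    dense = model.get_layer(name='dense_1')  # Lookup by unique name.
    first = model.get_layer(index=0)         # Lookup by traversal index.
    ```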

    Returns:
      A layer instance.

    Raises:
      ValueError: In case of invalid layer name or index.
    """
    # TODO(fchollet): We could build a dictionary based on layer names
    # since they are constant, but we have not done that yet.
    if index is not None:
      if len(self.layers) <= index:
        raise ValueError('Was asked to retrieve layer at index ' + str(index) +
                         ' but model only has ' + str(len(self.layers)) +
                         ' layers.')
      else:
        return self.layers[index]
    else:
      if not name:
        raise ValueError('Provide either a layer name or layer index.')
      for layer in self.layers:
        if layer.name == name:
          return layer
      raise ValueError('No such layer: ' + name)

  def _get_unfiltered_updates(self, check_trainable=True):
    if check_trainable and not self.trainable and not self.stateful:
      return []
    updates = []
    for layer in self.layers:
      updates += layer._get_unfiltered_updates(check_trainable=check_trainable)
    updates += list(self._updates)
    return updates

  @property
  def _unfiltered_losses(self):
    losses = []

    # If any eager losses are present, we assume the model to be part of an
    # eager training loop (either a custom one or the one used when
    # `run_eagerly=True`), and so we always return just the eager losses in
    # that case.
    if self._eager_losses:
      losses.extend(self._eager_losses)
    else:
      losses.extend(self._losses)
    for layer in self.layers:
      if isinstance(layer, Network):
        losses += layer._unfiltered_losses
      else:
        losses += layer.losses
    return losses

  @trackable.no_automatic_dependency_tracking
  def _clear_losses(self):
    """Used every step in eager to reset losses."""
    self._eager_losses = []
    for layer in self.layers:
      layer._clear_losses()

  @property
  def updates(self):
    """Retrieves the network's updates.

    Will only include updates that are either
    unconditional, or conditional on inputs to this model
    (e.g. will not include updates that were created by layers of this model
    outside of the model).

    When the network has no registered inputs, all updates are returned.

    Effectively, `network.updates` behaves like `layer.updates`.

    Concrete example:

    ```python
    bn = keras.layers.BatchNormalization()
    x1 = keras.layers.Input(shape=(10,))
    _ = bn(x1)  # This creates 2 updates.

    x2 = keras.layers.Input(shape=(10,))
    y2 = bn(x2)  # This creates 2 more updates.

    # The BN layer now has 4 updates.
    self.assertEqual(len(bn.updates), 4)

    # Let's create a model from x2 to y2.
    model = keras.models.Model(x2, y2)

    # The model does not list all updates from its underlying layers,
    # but only the updates that are relevant to it. Updates created by layers
    # outside of the model are discarded.
    self.assertEqual(len(model.updates), 2)

    # If you keep calling the model, you append to its updates, just like
    # what happens for a layer.
    x3 = keras.layers.Input(shape=(10,))
    y3 = model(x3)
    self.assertEqual(len(model.updates), 4)

    # But if you call the inner BN layer independently, you don't affect
    # the model's updates.
    x4 = keras.layers.Input(shape=(10,))
    _ = bn(x4)
    self.assertEqual(len(model.updates), 4)
    ```

    Returns:
      A list of update ops.
    """

    updates = self._get_unfiltered_updates(check_trainable=True)

    # `updates` might contain irrelevant updates, so it needs to be filtered
    # with respect to inputs the model has been called on.
    relevant_inputs = []
    for i in range(0, len(self._inbound_nodes)):
      inputs = self.get_input_at(i)
      if isinstance(inputs, list):
        relevant_inputs += inputs
      else:
        relevant_inputs.append(inputs)
    if not relevant_inputs:
      return list(set(updates))

    reachable = tf_utils.get_reachable_from_inputs(relevant_inputs, updates)
    relevant_conditional_updates = [x for x in updates if x in reachable]
    unconditional_updates = [
        x for x in updates if x._unconditional_update]  # pylint: disable=protected-access
    # A layer could be used multiple times in a nested structure,
    # so the updates list must be de-duped.
    return list(set(relevant_conditional_updates + unconditional_updates))

  @property
  def losses(self):
    """Retrieves the network's losses.

    Will only include losses that are either
    unconditional, or conditional on inputs to this model
    (e.g. will not include losses that depend on tensors
    that aren't inputs to this model).

    When the network has no registered inputs, all losses are returned.
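
    Concrete example (an illustrative sketch in the spirit of the `updates`
    example above; `ActivityRegularization` is used here as a convenient
    loss-creating layer):

    ```python
    reg = keras.layers.ActivityRegularization(l1=0.01)
    x1 = keras.layers.Input(shape=(10,))
    _ = reg(x1)  # Creates 1 loss conditional on x1.

    x2 = keras.layers.Input(shape=(10,))
    y2 = reg(x2)  # Creates 1 loss conditional on x2.

    # The model only lists the loss conditional on its own inputs;
    # the loss created by calling `reg` on x1 is discarded.
    model = keras.models.Model(x2, y2)
    assert len(model.losses) == 1
    ```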
645 """ 646 losses = self._unfiltered_losses 647 648 if context.executing_eagerly(): 649 return losses 650 651 # TODO(kaftan/fchollet): Clean this up / make it obsolete. 652 # This is a super ugly, confusing check necessary to 653 # handle the case where we are executing in a function graph in eager mode 654 # but the model was constructed symbolically in a separate graph scope. 655 # We need to capture the losses created in the current graph function, 656 # and filter out the incorrect loss tensors created when symbolically 657 # building the graph. 658 # We have to use this check because the code after it that checks 659 # for reachable inputs only captures the part of the model that was 660 # built symbolically, and captures the wrong tensors from a different 661 # func graph (causing a crash later on when trying to execute the 662 # graph function) 663 with ops.init_scope(): 664 if context.executing_eagerly(): 665 return [loss for loss in losses 666 if loss.graph == ops.get_default_graph()] 667 668 relevant_inputs = [] 669 for i in range(0, len(self._inbound_nodes)): 670 inputs = self.get_input_at(i) 671 if isinstance(inputs, list): 672 relevant_inputs += inputs 673 else: 674 relevant_inputs.append(inputs) 675 if not relevant_inputs: 676 return losses 677 678 reachable = tf_utils.get_reachable_from_inputs(relevant_inputs, losses) 679 relevant_conditional_losses = [x for x in losses if x in reachable] 680 unconditional_losses = [ 681 x for x in losses if x._unconditional_loss] # pylint: disable=protected-access 682 return list(set( 683 relevant_conditional_losses + unconditional_losses + self._losses)) 684 685 @property 686 def trainable_weights(self): 687 return trackable_layer_utils.gather_trainable_weights( 688 trainable=self.trainable, 689 sub_layers=self._layers, 690 extra_variables=self._trainable_weights) 691 692 @property 693 def non_trainable_weights(self): 694 return trackable_layer_utils.gather_non_trainable_weights( 695 trainable=self.trainable, 696 sub_layers=self._layers, 697 extra_variables=self._non_trainable_weights + self._trainable_weights) 698 699 @property 700 def metrics(self): 701 """Returns the network's symbolic metrics. 702 703 Model overrides this function to include the metrics from `compile` API. 704 """ 705 metrics = [] 706 for layer in self.layers: 707 metrics += layer._metrics # pylint: disable=protected-access 708 return metrics + self._metrics 709 710 @property 711 def _all_metrics_tensors(self): 712 """Returns the network's symbolic metric tensors.""" 713 # TODO(psv): Remove this property. 714 metrics_tensors = {} 715 for layer in self.layers: 716 if isinstance(layer, Network): 717 metrics_tensors.update(layer._all_metrics_tensors) 718 else: 719 metrics_tensors.update(layer._metrics_tensors) 720 metrics_tensors.update(self._metrics_tensors) 721 return metrics_tensors 722 723 @property 724 def input_spec(self): 725 """Gets the network's input specs. 726 727 Returns: 728 A list of `InputSpec` instances (one per input to the model) 729 or a single instance if the model has only one input. 730 """ 731 # If subclassed model, can't assume anything. 732 if not self._is_graph_network: 733 return None 734 735 specs = [] 736 for layer in self._input_layers: 737 if layer.input_spec is None: 738 specs.append(None) 739 else: 740 if not isinstance(layer.input_spec, list): 741 raise TypeError('Layer ' + layer.name + 742 ' has an input_spec attribute that ' 743 'is not a list. We expect a list. 

    Args:
      input_shape: Single tuple, TensorShape, or list of shapes, where shapes
        are tuples, integers, or TensorShapes.

    Raises:
      ValueError:
        1. In case of invalid user-provided data (not of type tuple,
           list, or TensorShape).
        2. If the model requires call arguments that are agnostic
           to the input shapes (positional or kwarg in call signature).
        3. If not all layers were properly built.
        4. If float type inputs are not supported within the layers.

      In each of these cases, the user should build their model by calling it
      on real tensor data.
    """
    if self._is_graph_network:
      self.built = True
      return

    # If subclass network
    if input_shape is None:
      raise ValueError('Input shape must be defined when calling build on a '
                       'model subclass network.')
    valid_types = (tuple, list, tensor_shape.TensorShape)
    if not isinstance(input_shape, valid_types):
      raise ValueError('Specified input shape is not one of the valid types. '
                       'Please specify a batch input shape of type tuple or '
                       'list of input shapes. User provided '
                       'input type: {}'.format(type(input_shape)))

    if input_shape and not self.inputs:
      # We create placeholders for the `None`s in the shape and build the model
      # in a Graph. Since tf.Variable is compatible with both eager execution
      # and graph building, the variables created after building the model in
      # a Graph are still valid when executing eagerly.
      if context.executing_eagerly():
        graph = func_graph.FuncGraph('build_graph')
      else:
        graph = backend.get_graph()
      with graph.as_default():
        if isinstance(input_shape, list):
          x = [base_layer_utils.generate_placeholders_from_shape(shape)
               for shape in input_shape]
        else:
          x = base_layer_utils.generate_placeholders_from_shape(input_shape)

        kwargs = {}
        call_signature = tf_inspect.getfullargspec(self.call)
        call_args = call_signature.args
        # Exclude `self`, `inputs`, and any argument with a default value.
        if len(call_args) > 2:
          if call_signature.defaults:
            call_args = call_args[2:-len(call_signature.defaults)]
          else:
            call_args = call_args[2:]
          for arg in call_args:
            if arg == 'training':
              # Case where `training` is a positional arg with no default.
              kwargs['training'] = False
            else:
              # Has invalid call signature with unknown positional arguments.
              raise ValueError(
                  'Currently, you cannot build your model if it has '
                  'positional or keyword arguments that are not '
                  'inputs to the model, but are required for its '
                  '`call` method. Instead, in order to instantiate '
                  'and build your model, `call` your model on real '
                  'tensor data with all expected call arguments.')
        elif len(call_args) < 2:
          # Signature without `inputs`.
          raise ValueError('You can only call `build` on a model if its '
                           '`call` method accepts an `inputs` argument.')
        try:
          self.call(x, **kwargs)
        except (errors.InvalidArgumentError, TypeError):
          raise ValueError('You cannot build your model by calling `build` '
                           'if your layers do not support float type inputs. '
                           'Instead, in order to instantiate and build your '
                           'model, `call` your model on real tensor data (of '
                           'the correct dtype).')
    if self._layers:
      self._track_layers(self._layers)
    self.built = True

  def call(self, inputs, training=None, mask=None):
    """Calls the model on new inputs.

    In this case `call` just reapplies
    all ops in the graph to the new inputs
    (e.g. build a new computational graph from the provided inputs).

    Arguments:
      inputs: A tensor or list of tensors.
      training: Boolean or boolean scalar tensor, indicating whether to run
        the `Network` in training mode or inference mode.
      mask: A mask or list of masks. A mask can be
        either a tensor or None (no mask).

    Returns:
      A tensor if there is a single output, or
      a list of tensors if there is more than one output.
    """
    if not self._is_graph_network:
      raise NotImplementedError('When subclassing the `Model` class, you '
                                'should implement a `call` method.')

    return self._run_internal_graph(inputs, training=training, mask=mask)

  def compute_output_shape(self, input_shape):
    if not self._is_graph_network:
      return super(Network, self).compute_output_shape(input_shape)

    # Convert any shapes in tuple format to TensorShapes.
    input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)

    if len(nest.flatten(input_shape)) != len(nest.flatten(self._input_layers)):
      raise ValueError('Invalid input_shape argument ' + str(input_shape) +
                       ': model has ' + str(len(self._input_layers)) +
                       ' tensor inputs.')

    cache_key = generic_utils.object_list_uid(input_shape)
    if cache_key in self._output_shape_cache:
      # Cache hit. Return shapes as TensorShapes.
      return self._output_shape_cache[cache_key]

    layers_to_output_shapes = {}
    for layer, shape in zip(self._input_layers, nest.flatten(input_shape)):
      # It's an input layer: `compute_output_shape` is the identity,
      # and there is only one node and one tensor.
      shape_key = layer.name + '_0_0'
      layers_to_output_shapes[shape_key] = shape

    depth_keys = list(self._nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    # Iterate over nodes, by depth level.
    if len(depth_keys) > 1:
      for depth in depth_keys:
        nodes = self._nodes_by_depth[depth]
        for node in nodes:
          # This is always a single layer, never a list.
          layer = node.outbound_layer
          if layer in self._input_layers:
            # We've already covered the input layers
            # a few lines above.
            continue
          # Potentially redundant list,
          # same size as node.input_tensors.
          layer_input_shapes = []
          for inbound_layer, node_id, tensor_id, _ in node.iterate_inbound():
            input_layer_key = inbound_layer.name + '_%s_%s' % (node_id,
                                                               tensor_id)
            layer_input_shapes.append(layers_to_output_shapes[input_layer_key])
          layer_input_shapes = nest.pack_sequence_as(node.inbound_layers,
                                                     layer_input_shapes)
          # Layers expect shapes to be tuples for `compute_output_shape`.
          layer_input_shapes = tf_utils.convert_shapes(
              layer_input_shapes, to_tuples=True)
          layer_output_shapes = layer.compute_output_shape(layer_input_shapes)
          # Convert back to TensorShapes.
          layer_output_shapes = tf_utils.convert_shapes(
              layer_output_shapes, to_tuples=False)

          node_index = layer._inbound_nodes.index(node)  # pylint: disable=protected-access
          for j, shape in enumerate(nest.flatten(layer_output_shapes)):
            shape_key = layer.name + '_%s_%s' % (node_index, j)
            layers_to_output_shapes[shape_key] = shape

    # Read final output shapes from layers_to_output_shapes.
    output_shapes = []
    for i in range(len(self._output_layers)):
      layer, node_index, tensor_index = self._output_coordinates[i]
      shape_key = layer.name + '_%s_%s' % (node_index, tensor_index)
      output_shapes.append(layers_to_output_shapes[shape_key])
    output_shapes = nest.pack_sequence_as(self._nested_outputs, output_shapes)
    # Store in cache.
    self._output_shape_cache[cache_key] = output_shapes

    # Return shapes as TensorShapes.
    return output_shapes

  def _run_internal_graph(self, inputs, training=None, mask=None):
    """Computes output tensors for new inputs.

    # Note:
      - Expects `inputs` to be a list (potentially with 1 element).
      - Can be run on non-Keras tensors.

    Arguments:
      inputs: Tensor or nested structure of Tensors.
      training: Boolean learning phase.
      mask: (Optional) Tensor or nested structure of Tensors.

    Returns:
      A nested structure of output tensors, matching `self._nested_outputs`.
    """
    # Note: masking support is relevant mainly for Keras.
    # It cannot be factored out without having to fully reimplement the network
    # calling logic on the Keras side. We choose to incorporate it in
    # Network because 1) it may be useful to fully support it in tf.layers in
    # the future and 2) Keras is a major user of Network. If you don't
    # use masking, it does not interfere with regular behavior at all and you
    # can ignore it.
    inputs = nest.flatten(inputs)
    if mask is None:
      masks = [None for _ in range(len(inputs))]
    else:
      masks = nest.flatten(mask)

    for input_t, mask in zip(inputs, masks):
      input_t._keras_mask = mask

    # Dictionary mapping reference tensors to computed tensors.
    tensor_dict = {}

    for x, y, mask in zip(self.inputs, inputs, masks):
      tensor_dict[str(id(x))] = y

    depth_keys = list(self._nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    # Ignore the InputLayers when computing the graph.
    depth_keys = depth_keys[1:]

    for depth in depth_keys:
      nodes = self._nodes_by_depth[depth]
      for node in nodes:
        # This is always a single layer, never a list.
        layer = node.outbound_layer

        if all(
            str(id(tensor)) in tensor_dict
            for tensor in nest.flatten(node.input_tensors)):

          # Call layer (reapplying ops to new inputs).
          computed_tensors = nest.map_structure(
              lambda t: tensor_dict[str(id(t))], node.input_tensors)

          # Ensure `training` and `mask` arg propagation if applicable.
          kwargs = node.arguments or {}
          argspec = self._layer_call_argspecs[layer].args
          if 'training' in argspec:
            kwargs.setdefault('training', training)
          if 'mask' in argspec:
            computed_masks = nest.map_structure(lambda t: t._keras_mask,
                                                computed_tensors)
            kwargs.setdefault('mask', computed_masks)

          # Compute outputs.
          output_tensors = layer(computed_tensors, **kwargs)

          # Update tensor_dict.
          for x, y in zip(
              nest.flatten(node.output_tensors), nest.flatten(output_tensors)):
            tensor_dict[str(id(x))] = y

    output_tensors = []
    output_shapes = []
    for x in self.outputs:
      assert str(id(x)) in tensor_dict, 'Could not compute output ' + str(x)
      tensor = tensor_dict[str(id(x))]
      output_shapes.append(x.shape)
      output_tensors.append(tensor)

    if output_shapes is not None:
      input_shapes = [x.shape for x in inputs]
      cache_key = generic_utils.object_list_uid(input_shapes)
      self._output_shape_cache[cache_key] = nest.pack_sequence_as(
          self._nested_outputs, output_shapes)

    output_tensors = nest.pack_sequence_as(self._nested_outputs, output_tensors)
    return output_tensors

  def get_config(self):
    if not self._is_graph_network:
      raise NotImplementedError

    config = {
        'name': self.name,
    }
    node_conversion_map = {}
    for layer in self.layers:
      if issubclass(layer.__class__, Network):
        # Networks start with a pre-existing node
        # linking their input to output.
        kept_nodes = 1
      else:
        kept_nodes = 0
      for original_node_index, node in enumerate(layer._inbound_nodes):
        node_key = _make_node_key(layer.name, original_node_index)
        if node_key in self._network_nodes:
          node_conversion_map[node_key] = kept_nodes
          kept_nodes += 1
    layer_configs = []
    for layer in self.layers:  # From the earliest layers on.
      layer_class_name = layer.__class__.__name__
      layer_config = layer.get_config()
      filtered_inbound_nodes = []
      for original_node_index, node in enumerate(layer._inbound_nodes):
        node_key = _make_node_key(layer.name, original_node_index)
        if node_key in self._network_nodes:
          # The node is relevant to the model:
          # add to filtered_inbound_nodes.
          if node.arguments:
            try:
              json.dumps(node.arguments)
              kwargs = node.arguments
            except TypeError:
              logging.warning(
                  'Layer ' + layer.name +
                  ' was passed non-serializable keyword arguments: ' +
                  str(node.arguments) + '. They will not be included '
                  'in the serialized model (and thus will be missing '
                  'at deserialization time).')
              kwargs = {}
          else:
            kwargs = {}
          if node.inbound_layers:
            node_data = []
            for inbound_layer, node_id, tensor_id, _ in node.iterate_inbound():
              node_key = _make_node_key(inbound_layer.name, node_id)
              new_node_index = node_conversion_map.get(node_key, 0)
              node_data.append(
                  tf_utils.ListWrapper(
                      [inbound_layer.name, new_node_index, tensor_id, kwargs]))
            node_data = nest.pack_sequence_as(node.input_tensors, node_data)
            # Convert ListWrapper to list for backwards compatible configs.
            node_data = tf_utils.convert_inner_node_data(node_data)
            filtered_inbound_nodes.append(node_data)
      layer_configs.append({
          'name': layer.name,
          'class_name': layer_class_name,
          'config': layer_config,
          'inbound_nodes': filtered_inbound_nodes,
      })
    config['layers'] = layer_configs

    # Gather info about inputs and outputs.
    model_inputs = []
    for i in range(len(self._input_layers)):
      layer, node_index, tensor_index = self._input_coordinates[i]
      node_key = _make_node_key(layer.name, node_index)
      if node_key not in self._network_nodes:
        continue
      new_node_index = node_conversion_map[node_key]
      model_inputs.append(
          tf_utils.ListWrapper([layer.name, new_node_index, tensor_index]))
    model_inputs = nest.pack_sequence_as(self._nested_inputs, model_inputs)
    # Preserve external Keras compat for Models with single input.
    if not nest.is_sequence(model_inputs):
      model_inputs = [model_inputs]
    model_inputs = tf_utils.convert_inner_node_data(model_inputs)
    config['input_layers'] = model_inputs

    model_outputs = []
    for i in range(len(self._output_layers)):
      layer, node_index, tensor_index = self._output_coordinates[i]
      node_key = _make_node_key(layer.name, node_index)
      if node_key not in self._network_nodes:
        continue
      new_node_index = node_conversion_map[node_key]
      model_outputs.append(
          tf_utils.ListWrapper([layer.name, new_node_index, tensor_index]))
    model_outputs = nest.pack_sequence_as(self._nested_outputs, model_outputs)
    # Preserve external Keras compat for Models with single output.
    if not nest.is_sequence(model_outputs):
      model_outputs = [model_outputs]
    model_outputs = tf_utils.convert_inner_node_data(model_outputs)
    config['output_layers'] = model_outputs
    return copy.deepcopy(config)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Instantiates a Model from its config (output of `get_config()`).

    Arguments:
      config: Model config dictionary.
      custom_objects: Optional dictionary mapping names
        (strings) to custom classes or functions to be
        considered during deserialization.
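
    Example (an illustrative round-trip for a Graph Network):

    ```python
    config = model.get_config()
    clone = keras.models.Model.from_config(config)
    ```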

    Returns:
      A model instance.

    Raises:
      ValueError: In case of improperly formatted config dict.
    """
    # Layer instances created during
    # the graph reconstruction process.
    created_layers = {}

    # Dictionary mapping layer instances to
    # node data that specifies a layer call.
    # It acts as a queue that maintains any unprocessed
    # layer call until it becomes possible to process it
    # (i.e. until the input tensors to the call all exist).
    unprocessed_nodes = {}

    def add_unprocessed_node(layer, node_data):
      if layer not in unprocessed_nodes:
        unprocessed_nodes[layer] = [node_data]
      else:
        unprocessed_nodes[layer].append(node_data)

    def process_node(layer, node_data):
      """Deserialize a node.

      Arguments:
        layer: layer instance.
        node_data: Nested structure of `ListWrapper`.
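
      Each flattened element of `node_data` is expected to hold
      `[inbound_layer_name, inbound_node_index, inbound_tensor_index]`,
      optionally followed by a `kwargs` dict (this note summarizes the
      parsing logic below).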
1174 """ 1175 input_tensors = [] 1176 for input_data in nest.flatten(node_data): 1177 input_data = input_data.as_list() 1178 inbound_layer_name = input_data[0] 1179 inbound_node_index = input_data[1] 1180 inbound_tensor_index = input_data[2] 1181 if len(input_data) == 3: 1182 kwargs = {} 1183 elif len(input_data) == 4: 1184 kwargs = input_data[3] 1185 else: 1186 raise ValueError('Improperly formatted model config.') 1187 1188 inbound_layer = created_layers[inbound_layer_name] 1189 if len(inbound_layer._inbound_nodes) <= inbound_node_index: 1190 add_unprocessed_node(layer, node_data) 1191 return 1192 inbound_node = inbound_layer._inbound_nodes[inbound_node_index] 1193 input_tensors.append( 1194 nest.flatten(inbound_node.output_tensors)[inbound_tensor_index]) 1195 input_tensors = nest.pack_sequence_as(node_data, input_tensors) 1196 # Call layer on its inputs, thus creating the node 1197 # and building the layer if needed. 1198 if input_tensors is not None: 1199 # Preserve compatibility with older configs. 1200 flat_input_tensors = nest.flatten(input_tensors) 1201 if len(flat_input_tensors) == 1: 1202 layer(flat_input_tensors[0], **kwargs) 1203 else: 1204 layer(input_tensors, **kwargs) 1205 1206 def process_layer(layer_data): 1207 """Deserializes a layer, then call it on appropriate inputs. 1208 1209 Arguments: 1210 layer_data: layer config dict. 1211 1212 Raises: 1213 ValueError: In case of improperly formatted `layer_data` dict. 1214 """ 1215 layer_name = layer_data['name'] 1216 1217 # Instantiate layer. 1218 from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top 1219 1220 layer = deserialize_layer(layer_data, custom_objects=custom_objects) 1221 created_layers[layer_name] = layer 1222 1223 # Gather layer inputs and convert to `ListWrapper` objects. 1224 inbound_nodes_data = layer_data['inbound_nodes'] 1225 inbound_nodes_data = tf_utils.convert_inner_node_data( 1226 inbound_nodes_data, wrap=True) 1227 for node_data in inbound_nodes_data: 1228 # We don't process nodes (i.e. make layer calls) 1229 # on the fly because the inbound node may not yet exist, 1230 # in case of layer shared at different topological depths 1231 # (e.g. a model such as A(B(A(B(x))))) 1232 add_unprocessed_node(layer, node_data) 1233 1234 # First, we create all layers and enqueue nodes to be processed 1235 for layer_data in config['layers']: 1236 process_layer(layer_data) 1237 # Then we process nodes in order of layer depth. 1238 # Nodes that cannot yet be processed (if the inbound node 1239 # does not yet exist) are re-enqueued, and the process 1240 # is repeated until all nodes are processed. 
    while unprocessed_nodes:
      for layer_data in config['layers']:
        layer = created_layers[layer_data['name']]
        if layer in unprocessed_nodes:
          for node_data in unprocessed_nodes.pop(layer):
            process_node(layer, node_data)

    name = config.get('name')
    input_tensors = []
    output_tensors = []

    input_layers = tf_utils.convert_inner_node_data(
        config['input_layers'], wrap=True)
    for layer_data in nest.flatten(input_layers):
      layer_name, node_index, tensor_index = layer_data.as_list()
      assert layer_name in created_layers
      layer = created_layers[layer_name]
      layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
      input_tensors.append(nest.flatten(layer_output_tensors)[tensor_index])

    output_layers = tf_utils.convert_inner_node_data(
        config['output_layers'], wrap=True)
    for layer_data in nest.flatten(output_layers):
      layer_name, node_index, tensor_index = layer_data.as_list()
      assert layer_name in created_layers
      layer = created_layers[layer_name]
      layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
      output_tensors.append(nest.flatten(layer_output_tensors)[tensor_index])

    input_tensors = nest.pack_sequence_as(input_layers, input_tensors)
    output_tensors = nest.pack_sequence_as(output_layers, output_tensors)
    return cls(inputs=input_tensors, outputs=output_tensors, name=name)

  def save(self, filepath, overwrite=True, include_optimizer=True):
    """Saves the model to a single HDF5 file.

    The savefile includes:
    - The model architecture, allowing you to re-instantiate the model.
    - The model weights.
    - The state of the optimizer, allowing you to resume training
      exactly where you left off.

    This allows you to save the entirety of the state of a model
    in a single file.

    Saved models can be reinstantiated via `keras.models.load_model`.
    The model returned by `load_model` is a compiled model ready to be used
    (unless the saved model was never compiled in the first place).

    Arguments:
      filepath: String, path to the file to save the weights to.
      overwrite: Whether to silently overwrite any existing file at the
        target location, or provide the user with a manual prompt.
      include_optimizer: If True, save optimizer's state together.

    Example:

    ```python
    from keras.models import load_model

    model.save('my_model.h5')  # creates an HDF5 file 'my_model.h5'
    del model  # deletes the existing model

    # returns a compiled model
    # identical to the previous one
    model = load_model('my_model.h5')
    ```
    """
    if not self._is_graph_network:
      raise NotImplementedError(
          'The `save` method requires the model to be a Functional model or a '
          'Sequential model. It does not work for subclassed models, '
          'because such models are defined via the body of a Python method, '
          'which isn\'t safely serializable. Consider '
          'using `save_weights`, in order to save the weights of the model.')

    from tensorflow.python.keras.models import save_model  # pylint: disable=g-import-not-at-top
    save_model(self, filepath, overwrite, include_optimizer)

  def save_weights(self, filepath, overwrite=True, save_format=None):
    """Saves all layer weights.

    Either saves in HDF5 or in TensorFlow format based on the `save_format`
    argument.
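
    For example (an illustrative sketch; the paths are arbitrary):

    ```python
    model.save_weights('./checkpoints/ckpt')  # TensorFlow format (default).
    model.save_weights('./weights.h5')        # HDF5 format, due to the suffix.
    ```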

    When saving in HDF5 format, the weight file has:
      - `layer_names` (attribute), a list of strings
        (ordered names of model layers).
      - For every layer, a `group` named `layer.name`
        - For every such layer group, a group attribute `weight_names`,
          a list of strings
          (ordered names of the weight tensors of the layer).
        - For every weight in the layer, a dataset
          storing the weight value, named after the weight tensor.

    When saving in TensorFlow format, all objects referenced by the network are
    saved in the same format as `tf.train.Checkpoint`, including any `Layer`
    instances or `Optimizer` instances assigned to object attributes. For
    networks constructed from inputs and outputs using `tf.keras.Model(inputs,
    outputs)`, `Layer` instances used by the network are tracked/saved
    automatically. For user-defined classes which inherit from
    `tf.keras.Model`, `Layer` instances must be assigned to object attributes,
    typically in the constructor. See the documentation of
    `tf.train.Checkpoint` and `tf.keras.Model` for details.

    Arguments:
      filepath: String, path to the file to save the weights to. When saving
        in TensorFlow format, this is the prefix used for checkpoint files
        (multiple files are generated). Note that the '.h5' suffix causes
        weights to be saved in HDF5 format.
      overwrite: Whether to silently overwrite any existing file at the
        target location, or provide the user with a manual prompt.
      save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
        '.keras' will default to HDF5 if `save_format` is `None`. Otherwise
        `None` defaults to 'tf'.

    Raises:
      ImportError: If h5py is not available when attempting to save in HDF5
        format.
      ValueError: For invalid/unknown format arguments.
    """
    filepath_is_h5 = _is_hdf5_filepath(filepath)
    if save_format is None:
      if filepath_is_h5:
        save_format = 'h5'
      else:
        save_format = 'tf'
    else:
      user_format = save_format.lower().strip()
      if user_format in ('tensorflow', 'tf'):
        save_format = 'tf'
      elif user_format in ('hdf5', 'h5', 'keras'):
        save_format = 'h5'
      else:
        raise ValueError(
            'Unknown format "%s". Was expecting one of {"tf", "h5"}.' % (
                save_format,))
    if save_format == 'tf' and filepath_is_h5:
      raise ValueError(
          ('save_weights got save_format="tf"/"tensorflow", but the '
           'filepath ("%s") looks like an HDF5 file. Omit the ".h5"/".keras" '
           'when saving in TensorFlow format.')
          % filepath)
Omit the ".h5"/".keras" ' 1383 'when saving in TensorFlow format.') 1384 % filepath) 1385 1386 if save_format == 'h5' and h5py is None: 1387 raise ImportError( 1388 '`save_weights` requires h5py when saving in hdf5.') 1389 if save_format == 'tf': 1390 check_filepath = filepath + '.index' 1391 else: 1392 check_filepath = filepath 1393 # If file exists and should not be overwritten: 1394 if not overwrite and os.path.isfile(check_filepath): 1395 proceed = ask_to_proceed_with_overwrite(check_filepath) 1396 if not proceed: 1397 return 1398 if save_format == 'h5': 1399 with h5py.File(filepath, 'w') as f: 1400 hdf5_format.save_weights_to_hdf5_group(f, self.layers) 1401 else: 1402 if context.executing_eagerly(): 1403 session = None 1404 else: 1405 session = backend.get_session() 1406 optimizer = getattr(self, 'optimizer', None) 1407 if (optimizer 1408 and not isinstance(optimizer, trackable.Trackable)): 1409 logging.warning( 1410 ('This model was compiled with a Keras optimizer (%s) but is being ' 1411 'saved in TensorFlow format with `save_weights`. The model\'s ' 1412 'weights will be saved, but unlike with TensorFlow optimizers in ' 1413 'the TensorFlow format the optimizer\'s state will not be ' 1414 'saved.\n\nConsider using a TensorFlow optimizer from `tf.train`.') 1415 % (optimizer,)) 1416 self._trackable_saver.save(filepath, session=session) 1417 # Record this checkpoint so it's visible from tf.train.latest_checkpoint. 1418 checkpoint_management.update_checkpoint_state_internal( 1419 save_dir=os.path.dirname(filepath), 1420 model_checkpoint_path=filepath, 1421 save_relative_paths=True, 1422 all_model_checkpoint_paths=[filepath]) 1423 1424 def load_weights(self, filepath, by_name=False): 1425 """Loads all layer weights, either from a TensorFlow or an HDF5 weight file. 1426 1427 If `by_name` is False weights are loaded based on the network's 1428 topology. This means the architecture should be the same as when the weights 1429 were saved. Note that layers that don't have weights are not taken into 1430 account in the topological ordering, so adding or removing layers is fine as 1431 long as they don't have weights. 1432 1433 If `by_name` is True, weights are loaded into layers only if they share the 1434 same name. This is useful for fine-tuning or transfer-learning models where 1435 some of the layers have changed. 1436 1437 Only topological loading (`by_name=False`) is supported when loading weights 1438 from the TensorFlow format. Note that topological loading differs slightly 1439 between TensorFlow and HDF5 formats for user-defined classes inheriting from 1440 `tf.keras.Model`: HDF5 loads based on a flattened list of weights, while the 1441 TensorFlow format loads based on the object-local names of attributes to 1442 which layers are assigned in the `Model`'s constructor. 1443 1444 Arguments: 1445 filepath: String, path to the weights file to load. For weight files in 1446 TensorFlow format, this is the file prefix (the same as was passed 1447 to `save_weights`). 1448 by_name: Boolean, whether to load weights by name or by topological 1449 order. Only topological loading is supported for weight files in 1450 TensorFlow format. 1451 1452 Returns: 1453 When loading a weight file in TensorFlow format, returns the same status 1454 object as `tf.train.Checkpoint.restore`. When graph building, restore 1455 ops are run automatically as soon as the network is built (on first call 1456 for user-defined classes inheriting from `Model`, immediately if it is 1457 already built). 

    Arguments:
      filepath: String, path to the weights file to load. For weight files in
        TensorFlow format, this is the file prefix (the same as was passed
        to `save_weights`).
      by_name: Boolean, whether to load weights by name or by topological
        order. Only topological loading is supported for weight files in
        TensorFlow format.

    Returns:
      When loading a weight file in TensorFlow format, returns the same status
      object as `tf.train.Checkpoint.restore`. When graph building, restore
      ops are run automatically as soon as the network is built (on first call
      for user-defined classes inheriting from `Model`, immediately if it is
      already built).

      When loading weights in HDF5 format, returns `None`.

    Raises:
      ImportError: If h5py is not available and the weight file is in HDF5
        format.
    """
    if _is_hdf5_filepath(filepath):
      save_format = 'h5'
    else:
      try:
        pywrap_tensorflow.NewCheckpointReader(filepath)
        save_format = 'tf'
      except errors_impl.DataLossError:
        # The checkpoint is not readable in TensorFlow format. Try HDF5.
        save_format = 'h5'
    if save_format == 'tf':
      status = self._trackable_saver.restore(filepath)
      if by_name:
        raise NotImplementedError(
            'Weights may only be loaded based on topology into Models when '
            'loading TensorFlow-formatted weights (got by_name=True to '
            'load_weights).')
      if not context.executing_eagerly():
        session = backend.get_session()
        # Restore existing variables (if any) immediately, and set up a
        # streaming restore for any variables created in the future.
        trackable_utils.streaming_restore(status=status, session=session)
      status.assert_nontrivial_match()
      return status
    if h5py is None:
      raise ImportError(
          '`load_weights` requires h5py when loading weights from HDF5.')
    if not self._is_graph_network and not self.built:
      raise NotImplementedError(
          'Unable to load weights saved in HDF5 format into a subclassed '
          'Model which has not created its variables yet. Call the Model '
          'first, then load the weights.')
    with h5py.File(filepath, 'r') as f:
      if 'layer_names' not in f.attrs and 'model_weights' in f:
        f = f['model_weights']
      if by_name:
        hdf5_format.load_weights_from_hdf5_group_by_name(f, self.layers)
      else:
        hdf5_format.load_weights_from_hdf5_group(f, self.layers)

  def _updated_config(self):
    """Util shared between different serialization methods.

    Returns:
      Model config with Keras version information added.
    """
    from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top

    config = self.get_config()
    model_config = {
        'class_name': self.__class__.__name__,
        'config': config,
        'keras_version': keras_version,
        'backend': backend.backend()
    }
    return model_config

  def to_json(self, **kwargs):
    """Returns a JSON string containing the network configuration.

    To load a network from a JSON save file, use
    `keras.models.model_from_json(json_string, custom_objects={})`.

    Arguments:
      **kwargs: Additional keyword arguments
        to be passed to `json.dumps()`.

    Returns:
      A JSON string.
    """
    model_config = self._updated_config()
    return json.dumps(
        model_config, default=serialization.get_json_type, **kwargs)

  def to_yaml(self, **kwargs):
    """Returns a yaml string containing the network configuration.

    To load a network from a yaml save file, use
    `keras.models.model_from_yaml(yaml_string, custom_objects={})`.

    `custom_objects` should be a dictionary mapping
    the names of custom losses / layers / etc to the corresponding
    functions / classes.
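
    Example (an illustrative round-trip; requires `pyyaml` to be installed):

    ```python
    yaml_string = model.to_yaml()
    fresh_model = keras.models.model_from_yaml(yaml_string)
    ```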
1557 """ 1558 if yaml is None: 1559 raise ImportError( 1560 'Requires yaml module installed (`pip install pyyaml`).') 1561 return yaml.dump(self._updated_config(), **kwargs) 1562 1563 def summary(self, line_length=None, positions=None, print_fn=None): 1564 """Prints a string summary of the network. 1565 1566 Arguments: 1567 line_length: Total length of printed lines 1568 (e.g. set this to adapt the display to different 1569 terminal window sizes). 1570 positions: Relative or absolute positions of log elements 1571 in each line. If not provided, 1572 defaults to `[.33, .55, .67, 1.]`. 1573 print_fn: Print function to use. Defaults to `print`. 1574 It will be called on each line of the summary. 1575 You can set it to a custom function 1576 in order to capture the string summary. 1577 1578 Raises: 1579 ValueError: if `summary()` is called before the model is built. 1580 """ 1581 if not self.built: 1582 raise ValueError('This model has not yet been built. ' 1583 'Build the model first by calling `build()` or calling ' 1584 '`fit()` with some data, or specify ' 1585 'an `input_shape` argument in the first layer(s) for ' 1586 'automatic build.') 1587 layer_utils.print_summary(self, 1588 line_length=line_length, 1589 positions=positions, 1590 print_fn=print_fn) 1591 1592 def _validate_graph_inputs_and_outputs(self): 1593 """Validates the inputs and outputs of a Graph Network.""" 1594 # Check for redundancy in inputs. 1595 if len(set(self.inputs)) != len(self.inputs): 1596 raise ValueError('The list of inputs passed to the model ' 1597 'is redundant. ' 1598 'All inputs should only appear once.' 1599 ' Found: ' + str(self.inputs)) 1600 1601 for x in self.inputs: 1602 # Check that x has appropriate `_keras_history` metadata. 1603 if not hasattr(x, '_keras_history'): 1604 cls_name = self.__class__.__name__ 1605 raise ValueError('Input tensors to a ' + cls_name + ' ' + 1606 'must come from `tf.keras.Input`. ' 1607 'Received: ' + str(x) + 1608 ' (missing previous layer metadata).') 1609 # Check that x is an input tensor. 1610 # pylint: disable=protected-access 1611 layer, _, _ = x._keras_history 1612 if len(layer._inbound_nodes) > 1 or ( 1613 layer._inbound_nodes and layer._inbound_nodes[0].inbound_layers): 1614 cls_name = self.__class__.__name__ 1615 logging.warning(cls_name + ' inputs must come from ' 1616 '`tf.keras.Input` (thus holding past layer metadata), ' 1617 'they cannot be the output of ' 1618 'a previous non-Input layer. ' 1619 'Here, a tensor specified as ' 1620 'input to "' + self.name + '" was not an Input tensor, ' 1621 'it was generated by layer ' + layer.name + '.\n' 1622 'Note that input tensors are ' 1623 'instantiated via `tensor = tf.keras.Input(shape)`.\n' 1624 'The tensor that caused the issue was: ' + str(x.name)) 1625 1626 # Check compatibility of batch sizes of Input Layers. 1627 input_batch_sizes = [ 1628 training_utils.get_static_batch_size(x._keras_history[0]) 1629 for x in self.inputs 1630 ] 1631 consistent_batch_size = None 1632 for batch_size in input_batch_sizes: 1633 if batch_size is not None: 1634 if (consistent_batch_size is not None and 1635 batch_size != consistent_batch_size): 1636 raise ValueError('The specified batch sizes of the Input Layers' 1637 ' are incompatible. 
    for x in self.outputs:
      if not hasattr(x, '_keras_history'):
        cls_name = self.__class__.__name__
        raise ValueError('Output tensors to a ' + cls_name + ' must be '
                         'the output of a TensorFlow `Layer` '
                         '(thus holding past layer metadata). Found: ' + str(x))


def _is_hdf5_filepath(filepath):
  return (filepath.endswith('.h5') or filepath.endswith('.keras') or
          filepath.endswith('.hdf5'))


def _make_node_key(layer_name, node_index):
  return layer_name + '_ib-' + str(node_index)


def _map_graph_network(inputs, outputs):
  """Validates a network's topology and gathers its layers and nodes.

  Arguments:
    inputs: List of input tensors.
    outputs: List of output tensors.

  Returns:
    A tuple `(nodes, nodes_by_depth, layers, layers_by_depth)`.
    - nodes: list of Node instances.
    - nodes_by_depth: dict mapping ints (depth) to lists of node instances.
    - layers: list of Layer instances.
    - layers_by_depth: dict mapping ints (depth) to lists of layer instances.

  Raises:
    ValueError: In case the network is not valid (e.g. disconnected graph).
  """
  # network_nodes: set of nodes included in the graph of layers
  # (not all nodes included in the layers are relevant to the current graph).
  network_nodes = set()  # ids of all nodes relevant to the Network
  nodes_depths = {}  # dict {node: depth value}
  layers_depths = {}  # dict {layer: depth value}
  layer_indices = {}  # dict {layer: index in traversal}
  nodes_in_decreasing_depth = []

  def build_map(tensor,
                finished_nodes,
                nodes_in_progress,
                layer,
                node_index,
                tensor_index):
    """Builds a map of the graph of layers.

    This recursively updates the map `layer_indices`,
    the list `nodes_in_decreasing_depth` and the set `network_nodes`.

    Arguments:
      tensor: Some tensor in a graph.
      finished_nodes: Set of nodes whose subgraphs have been traversed
        completely. Useful to prevent duplicated work.
      nodes_in_progress: Set of nodes that are currently active on the
        recursion stack. Useful to detect cycles.
      layer: Layer from which `tensor` comes. If not provided,
        will be obtained from `tensor._keras_history`.
      node_index: Node index from which `tensor` comes.
      tensor_index: Tensor index from which `tensor` comes.

    Raises:
      ValueError: if a cycle is detected.
    """
    node = layer._inbound_nodes[node_index]  # pylint: disable=protected-access

    # Prevent cycles.
    if node in nodes_in_progress:
      raise ValueError('The tensor ' + str(tensor) + ' at layer "' +
                       layer.name + '" is part of a cycle.')

    # Don't repeat work for shared subgraphs.
    if node in finished_nodes:
      return

    node_key = _make_node_key(layer.name, node_index)
    # Update network_nodes.
    network_nodes.add(node_key)

    # Store the traversal order for layer sorting.
    if layer not in layer_indices:
      layer_indices[layer] = len(layer_indices)

    nodes_in_progress.add(node)

    # Propagate to all previous tensors connected to this node.
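    # This is a depth-first, post-order traversal: `node` is appended to
    # `nodes_in_decreasing_depth` only after every node it depends on has been
    # appended, so the list runs from the input side of the graph (largest
    # depth) toward the outputs (depth 0).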
    for layer, node_index, tensor_index, tensor in node.iterate_inbound():
      build_map(tensor, finished_nodes, nodes_in_progress, layer, node_index,
                tensor_index)

    finished_nodes.add(node)
    nodes_in_progress.remove(node)
    nodes_in_decreasing_depth.append(node)

  finished_nodes = set()
  nodes_in_progress = set()
  for x in outputs:
    layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
    build_map(x, finished_nodes, nodes_in_progress,
              layer=layer,
              node_index=node_index,
              tensor_index=tensor_index)

  for node in reversed(nodes_in_decreasing_depth):
    # If the depth is not set, the node has no outbound nodes (depth 0).
    depth = nodes_depths.setdefault(node, 0)

    # Update the depth of the corresponding layer.
    previous_depth = layers_depths.get(node.outbound_layer, 0)
    # If we've seen this layer before at a higher depth,
    # we should use that depth instead of the node depth.
    # This is necessary for shared layers that have inputs at different
    # depth levels in the graph.
    depth = max(depth, previous_depth)
    layers_depths[node.outbound_layer] = depth
    nodes_depths[node] = depth

    # Update the depth of inbound nodes.
    # The "depth" of a node is the max of the depths
    # of all the nodes it feeds into, plus one.
    for inbound_layer, node_index, _, _ in node.iterate_inbound():
      inbound_node = inbound_layer._inbound_nodes[node_index]  # pylint: disable=protected-access
      previous_depth = nodes_depths.get(inbound_node, 0)
      nodes_depths[inbound_node] = max(depth + 1, previous_depth)

  # Build a dict {depth: list of nodes with this depth}.
  nodes_by_depth = {}
  for node, depth in nodes_depths.items():
    if depth not in nodes_by_depth:
      nodes_by_depth[depth] = []
    nodes_by_depth[depth].append(node)

  # Build a dict {depth: list of layers with this depth}.
  layers_by_depth = {}
  for layer, depth in layers_depths.items():
    if depth not in layers_by_depth:
      layers_by_depth[depth] = []
    layers_by_depth[depth].append(layer)

  # Get sorted list of layer depths.
  depth_keys = list(layers_by_depth.keys())
  depth_keys.sort(reverse=True)

  # Set self.layers and self._layers_by_depth.
  layers = []
  for depth in depth_keys:
    layers_for_depth = layers_by_depth[depth]
    # Network.layers needs to have a deterministic order:
    # here we order them by traversal order.
    layers_for_depth.sort(key=lambda x: layer_indices[x])
    layers.extend(layers_for_depth)

  # Get sorted list of node depths.
  depth_keys = list(nodes_by_depth.keys())
  depth_keys.sort(reverse=True)

  # Check that all tensors required are computable.
  # computable_tensors: all tensors in the graph
  # that can be computed from the inputs provided.
  computable_tensors = []
  for x in inputs:
    computable_tensors.append(x)

  layers_with_complete_input = []  # To provide a better error msg.
  for depth in depth_keys:
    for node in nodes_by_depth[depth]:
      layer = node.outbound_layer
      if layer:
        for x in nest.flatten(node.input_tensors):
          if x not in computable_tensors:
            raise ValueError('Graph disconnected: '
                             'cannot obtain value for tensor ' + str(x) +
                             ' at layer "' + layer.name + '". '
                             'The following previous layers '
                             'were accessed without issue: ' +
                             str(layers_with_complete_input))
        for x in nest.flatten(node.output_tensors):
          computable_tensors.append(x)
        layers_with_complete_input.append(layer.name)

  # Ensure name uniqueness, which will be crucial for serialization
  # (since serialized nodes refer to layers by their name).
  all_names = [layer.name for layer in layers]
  for name in all_names:
    if all_names.count(name) != 1:
      raise ValueError('The name "' + name + '" is used ' +
                       str(all_names.count(name)) + ' times in the model. '
                       'All layer names should be unique.')
  return network_nodes, nodes_by_depth, layers, layers_by_depth
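

# For illustration, consider a small hypothetical chain built with the
# Functional API:
#
#   inputs = tf.keras.Input(shape=(4,))
#   x = tf.keras.layers.Dense(8)(inputs)
#   outputs = tf.keras.layers.Dense(1)(x)
#
# `_map_graph_network([inputs], [outputs])` walks backwards from `outputs`:
# the output Dense node gets depth 0, the first Dense node depth 1, and the
# InputLayer node depth 2, so the returned `layers` list is ordered
# deterministically from inputs to outputs.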