# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""A `Network` is a way to compose layers: the topological form of a `Model`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import itertools
import json
import os

import numpy as np
import six
from six.moves import zip  # pylint: disable=redefined-builtin

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_layer as input_layer_module
from tensorflow.python.keras.engine import node as node_module
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.saving import hdf5_format
from tensorflow.python.keras.saving import save
from tensorflow.python.keras.saving.saved_model import network_serialization
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.util import nest
from tensorflow.python.util import serialization
from tensorflow.python.util import tf_inspect


# pylint: disable=g-import-not-at-top
try:
  import h5py
except ImportError:
  h5py = None

try:
  import yaml
except ImportError:
  yaml = None
# pylint: enable=g-import-not-at-top


class Network(base_layer.Layer):
  """A `Network` is a composition of layers.

  `Network` is the topological form of a "model". A `Model`
  is simply a `Network` with added training routines.

  Two types of `Networks` exist: Graph Networks and Subclass Networks.
  Graph networks are used in the Keras Functional and Sequential APIs.
  Subclassed networks are used when a user subclasses the `Model` class. In
  general, more Keras features are supported with Graph Networks than with
  Subclassed Networks, specifically:

  - Model cloning (`keras.models.clone`)
  - Serialization (`model.get_config()/from_config`,
    `model.to_json()/to_yaml()`)
  - Whole-model saving (`model.save()`)

  A Graph Network can be instantiated by passing two arguments to `__init__`.
  The first argument is the `keras.Input` Tensors that represent the inputs
  to the Network. The second argument specifies the output Tensors that
  represent the outputs of this Network. Both arguments can be a nested
  structure of Tensors.

  Example:

  ```
  inputs = {'x1': keras.Input(shape=(10,)), 'x2': keras.Input(shape=(1,))}
  t = keras.layers.Dense(1, activation='relu')(inputs['x1'])
  outputs = keras.layers.Add()([t, inputs['x2']])
  network = Network(inputs, outputs)
  ```

  A Graph Network constructed using the Functional API can also include raw
  TensorFlow functions, with the exception of functions that create Variables
  or assign ops.

  Example:

  ```
  inputs = keras.Input(shape=(10,))
  x = keras.layers.Dense(1)(inputs)
  outputs = tf.nn.relu(x)
  network = Network(inputs, outputs)
  ```

  Subclassed Networks can be instantiated via `name` and (optional) `dynamic`
  keyword arguments. Subclassed Networks keep track of their Layers, and their
  `call` method can be overridden. Subclassed Networks are typically created
  indirectly, by subclassing the `Model` class.

  Example:

  ```
  class MyModel(keras.Model):
    def __init__(self):
      super(MyModel, self).__init__(name='my_model', dynamic=False)

      self.layer1 = keras.layers.Dense(10, activation='relu')

    def call(self, inputs):
      return self.layer1(inputs)
  ```

  Allowed args in `super().__init__`:
    name: String name of the model.
    dynamic: (Subclassed models only) Set this to `True` if your model should
      only be run eagerly, and should not be used to generate a static
      computation graph. This attribute is automatically set for Functional API
      models.
    trainable: Boolean, whether the model's variables should be trainable.
    dtype: (Subclassed models only) Default dtype of the model's weights
      (default of `None` means use the type of the first input). This attribute
      has no effect on Functional API models, which do not have weights of
      their own.
  """

  # See tf.Module for the usage of this property.
  # The key of _layer_call_argspecs is a layer. tf.Module._flatten will fail to
  # flatten the key since it is trying to convert Trackable/Layer to a string.
  _TF_MODULE_IGNORED_PROPERTIES = frozenset(itertools.chain(
      ('_layer_call_argspecs', '_compiled_trainable_state'),
      base_layer.Layer._TF_MODULE_IGNORED_PROPERTIES
  ))

  def __init__(self, *args, **kwargs):  # pylint: disable=super-init-not-called
    # Signature detection
    if (len(args) == 2 or
        len(args) == 1 and 'outputs' in kwargs or
        'inputs' in kwargs and 'outputs' in kwargs):
      # Graph network
      self._init_graph_network(*args, **kwargs)
    else:
      # Subclassed network
      self._init_subclassed_network(**kwargs)

    tf_utils.assert_no_legacy_layers(self.layers)

  # Several Network methods have "no_automatic_dependency_tracking"
  # annotations.
  # Since Network does automatic dependency tracking on attribute assignment,
  # including for common data structures such as lists, by default we'd have
  # quite a few empty dependencies which users don't care about (or would need
  # some way to ignore dependencies automatically, which is confusing when
  # applied to user code). Some attributes, such as _layers, would cause
  # structural issues (_layers being the place where Layers assigned to tracked
  # attributes are stored).
  #
  # Aside from these aesthetic and structural issues, useless dependencies on
  # empty lists shouldn't cause issues; adding or removing them will not break
  # checkpoints, but may cause "all Python objects matched" assertions to fail
  # (in which case less strict assertions may be substituted if necessary).
  @trackable.no_automatic_dependency_tracking
  def _base_init(self, name=None, **kwargs):
    # The following are implemented as property functions:
    # self.trainable_weights
    # self.non_trainable_weights
    # self.input_spec
    # self.losses
    # self.updates

    generic_utils.validate_kwargs(kwargs, {'trainable', 'dtype', 'dynamic',
                                           'autocast'})

    super(Network, self).__init__(name=name, **kwargs)

    self._is_compiled = False

    # This is True for Sequential networks and Functional networks.
    self._compute_output_and_mask_jointly = False

    if not hasattr(self, 'optimizer'):
      # Don't reset optimizer if already set.
      self.optimizer = None

    self._scope = None  # Never used.
    self._reuse = None  # Never used.
    if context.executing_eagerly():
      self._graph = None
    else:
      self._graph = ops.get_default_graph()  # Used in symbolic mode only.

    self._trackable_saver = (
        trackable_utils.saver_with_op_caching(self))

  @trackable.no_automatic_dependency_tracking
  def _init_graph_network(self, inputs, outputs, name=None, **kwargs):
    generic_utils.validate_kwargs(
        kwargs, {'trainable'},
        'Functional models may only specify `name` and `trainable` keyword '
        'arguments during initialization. Got an unexpected argument:')
    # Normalize and set self.inputs, self.outputs.
    if isinstance(inputs, list) and len(nest.flatten(inputs)) == 1:
      inputs = inputs[0]
    if isinstance(outputs, list) and len(nest.flatten(outputs)) == 1:
      outputs = outputs[0]
    self._nested_outputs = outputs
    self._nested_inputs = inputs
    self.inputs = nest.flatten(inputs)
    self.outputs = nest.flatten(outputs)

    if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
      base_layer_utils.create_keras_history(self._nested_outputs)

    self._base_init(name=name, **kwargs)
    self._validate_graph_inputs_and_outputs()

    # A Network does not create weights of its own, thus it is already built.
    self.built = True
    self._compute_output_and_mask_jointly = True
    self._is_graph_network = True
    # `_expects_training_arg` is True since the `training` argument is always
    # present in the signature of the `call` method of a graph network.
    self._expects_training_arg = True
    self._expects_mask_arg = True
    # A graph network does not autocast inputs, as its layers will cast them
    # instead.
    self._autocast = False

    self._input_layers = []
    self._output_layers = []
    self._input_coordinates = []
    self._output_coordinates = []

    self._supports_ragged_inputs = None

    # This is for performance optimization when calling the Network on new
    # inputs. Every time the Network is called on a set of input tensors,
    # we compute the output tensors, output masks and output shapes in one
    # pass, then cache them here. When any of these outputs is queried later,
    # we retrieve it from there instead of recomputing it.
    self._output_mask_cache = {}
    self._output_tensor_cache = {}
    self._output_shape_cache = {}

    # Build self._output_layers:
    for x in self.outputs:
      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
      self._output_layers.append(layer)
      self._output_coordinates.append((layer, node_index, tensor_index))

    # Build self._input_layers:
    for x in self.inputs:
      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
      # It's supposed to be an input layer, so only one node
      # and one tensor output.
      assert node_index == 0
      assert tensor_index == 0
      self._input_layers.append(layer)
      self._input_coordinates.append((layer, node_index, tensor_index))

    # Keep track of the network's nodes and layers.
    nodes, nodes_by_depth, layers, _ = _map_graph_network(
        self.inputs, self.outputs)
    self._network_nodes = nodes
    self._nodes_by_depth = nodes_by_depth
    self._layers = layers
    self._layer_call_argspecs = {}
    for layer in self._layers:
      self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)
      layer._attribute_sentinel.add_parent(self._attribute_sentinel)

    # Create the node linking internal inputs to internal outputs.
    node_module.Node(
        outbound_layer=self,
        inbound_layers=[],
        node_indices=[],
        tensor_indices=[],
        input_tensors=self._nested_inputs,
        output_tensors=self._nested_outputs)

    # Build self.input_names and self.output_names.
    self._set_output_names()
    self.input_names = []
    self._feed_input_names = []
    self._feed_inputs = []
    self._feed_input_shapes = []
    for layer in self._input_layers:
      self.input_names.append(layer.name)
      if layer.is_placeholder:
        self._feed_input_names.append(layer.name)
        # Use batch_input_shape here because non-eager composite tensors may
        # not have a shape attribute that's meaningful (sparse, for instance,
        # has a tensor that's non-constant and needs to be fed). This means
        # that input layers that create placeholders will need to have the
        # batch_input_shape attr to allow for input shape validation.
        self._feed_input_shapes.append(layer._batch_input_shape)
        self._feed_inputs.append(layer.input)

    self._compute_tensor_usage_count()

  def _set_output_names(self):
    """Assigns unique names to the Network's outputs.

    Output layers with multiple output tensors would otherwise lead to
    duplicate names in self.output_names.
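
    For example, if the output layers were three layers all named 'dense'
    (hypothetical names, for illustration), the resulting `self.output_names`
    would be `['dense', 'dense_1', 'dense_2']`.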
333 """ 334 uniquified = [] 335 output_names = set() 336 prefix_count = {} 337 for layer in self._output_layers: 338 proposal = layer.name 339 while proposal in output_names: 340 existing_count = prefix_count.get(layer.name, 1) 341 proposal = '{}_{}'.format(layer.name, existing_count) 342 prefix_count[layer.name] = existing_count + 1 343 output_names.add(proposal) 344 uniquified.append(proposal) 345 self.output_names = uniquified 346 347 @trackable.no_automatic_dependency_tracking 348 def _init_subclassed_network(self, name=None, **kwargs): 349 self._base_init(name=name, **kwargs) 350 self._is_graph_network = False 351 self._init_call_fn_args() 352 self._autocast = kwargs.get('autocast', 353 base_layer_utils.v2_dtype_behavior_enabled()) 354 self._supports_ragged_inputs = None 355 self.outputs = [] 356 self.inputs = [] 357 self.built = False 358 359 @property 360 @trackable_layer_utils.cache_recursive_attribute('dynamic') 361 def dynamic(self): 362 if self._is_graph_network: 363 return any(layer.dynamic for layer in self.layers) 364 return self._dynamic or any(layer.dynamic for layer in self.layers) 365 366 @property 367 def _layer_checkpoint_dependencies(self): 368 """Dictionary of layer dependencies to be included in the checkpoint.""" 369 # Use getattr because this function can be called from __setattr__, at which 370 # point the _is_graph_network attribute has not been created. 371 if (not getattr(self, '_is_graph_network', False) and 372 base_layer_utils.is_subclassed(self)): 373 return {} # Only add layer dependencies for graph networks 374 375 weight_layer_index = 0 376 377 dependencies = {} 378 for layer_index, layer in enumerate(self.layers): 379 try: 380 if layer.weights: 381 # Keep a separate index for layers which have weights. This allows 382 # users to insert Layers without weights anywhere in the network 383 # without breaking checkpoints. 384 dependencies['layer_with_weights-%d' % weight_layer_index] = layer 385 weight_layer_index += 1 386 except ValueError: 387 # The layer might have weights, but may not be built yet. We just treat 388 # it as layer without weight. 389 pass 390 391 # Even if it doesn't have weights, we should still track everything in 392 # case it has/will have Trackable dependencies. 
      dependencies['layer-%d' % layer_index] = layer
    return dependencies

  @property
  def _checkpoint_dependencies(self):
    dependencies = [
        trackable.TrackableReference(name=name, ref=layer)
        for name, layer in self._layer_checkpoint_dependencies.items()]
    dependencies.extend(super(Network, self)._checkpoint_dependencies)
    return dependencies

  def _lookup_dependency(self, name):
    layer_dependencies = self._layer_checkpoint_dependencies
    if name in layer_dependencies:
      return layer_dependencies[name]
    return super(Network, self)._lookup_dependency(name)

  def _handle_deferred_layer_dependencies(self, layers):
    """Handles layer checkpoint dependencies that are added after init."""
    layer_checkpoint_dependencies = self._layer_checkpoint_dependencies
    layer_to_name = {v: k for k, v in layer_checkpoint_dependencies.items()}
    for layer in layers:
      if layer in layer_to_name:
        self._handle_deferred_dependencies(name=layer_to_name[layer],
                                           trackable=layer)

  def __setattr__(self, name, value):
    if not getattr(self, '_self_setattr_tracking', True):
      super(Network, self).__setattr__(name, value)
      return

    if all(
        isinstance(v, (base_layer.Layer,
                       data_structures.TrackableDataStructure)) or
        trackable_layer_utils.has_weights(v) for v in nest.flatten(value)):
      try:
        self._is_graph_network
      except AttributeError:
        # six.raise_from suppresses the original AttributeError from being
        # raised.
        six.raise_from(
            RuntimeError('It looks like you are subclassing `Model` and you '
                         'forgot to call `super(YourClass, self).__init__()`.'
                         ' Always start with this line.'), None)

    super(Network, self).__setattr__(name, value)

    # Keep track of metric instances created in subclassed model/layer.
    # We do this so that we can maintain the correct order of metrics by adding
    # the instance to the `metrics` list as soon as it is created.
    from tensorflow.python.keras import metrics as metrics_module  # pylint: disable=g-import-not-at-top
    if isinstance(value, metrics_module.Metric):
      self._metrics.append(value)

  @property
  @trackable_layer_utils.cache_recursive_attribute('stateful')
  def stateful(self):
    return any(getattr(layer, 'stateful', False) for layer in self.layers)

  def reset_states(self):
    for layer in self.layers:
      if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False):
        layer.reset_states()

  @property
  def state_updates(self):
    """Returns the `updates` from all layers that are stateful.

    This is useful for separating training updates and
    state updates, e.g. when we need to update a layer's internal state
    during prediction.

    Returns:
      A list of update ops.
    """
    state_updates = []
    for layer in self.layers:
      if getattr(layer, 'stateful', False):
        if hasattr(layer, 'updates'):
          state_updates += layer.updates
    return state_updates

  @property
  def weights(self):
    """Returns the list of all layer variables/weights.

    Returns:
      A list of variables.
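
    Example (a minimal sketch; the weight count depends on the architecture):

    ```python
    model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
    len(model.weights)  # 2: the Dense layer's kernel and bias variables.
    ```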
480 """ 481 return self._dedup_weights(self._undeduplicated_weights) 482 483 @property 484 def _undeduplicated_weights(self): 485 """Returns the undeduplicated list of all layer variables/weights.""" 486 self._assert_weights_created() 487 weights = [] 488 for layer in self._layers: 489 weights += layer.weights 490 weights += (self._trainable_weights + self._non_trainable_weights) 491 return weights 492 493 @property 494 @tracking.cached_per_instance 495 def _should_compute_mask(self): 496 return self._is_graph_network and super(Network, self)._should_compute_mask 497 498 def compute_mask(self, inputs, mask): 499 if not self._is_graph_network: 500 return None 501 502 # TODO(omalleyt): b/123540974 This function is not really safe to call 503 # by itself because it will duplicate any updates and losses in graph 504 # mode by `call`ing the Layers again. 505 output_tensors = self._run_internal_graph(inputs, mask=mask) 506 return nest.map_structure(lambda t: t._keras_mask, output_tensors) 507 508 @property 509 def layers(self): 510 return list( 511 trackable_layer_utils.filter_empty_layer_containers(self._layers)) 512 513 def get_layer(self, name=None, index=None): 514 """Retrieves a layer based on either its name (unique) or index. 515 516 If `name` and `index` are both provided, `index` will take precedence. 517 Indices are based on order of horizontal graph traversal (bottom-up). 518 519 Arguments: 520 name: String, name of layer. 521 index: Integer, index of layer. 522 523 Returns: 524 A layer instance. 525 526 Raises: 527 ValueError: In case of invalid layer name or index. 528 """ 529 # TODO(fchollet): We could build a dictionary based on layer names 530 # since they are constant, but we have not done that yet. 531 if index is not None: 532 if len(self.layers) <= index: 533 raise ValueError('Was asked to retrieve layer at index ' + str(index) + 534 ' but model only has ' + str(len(self.layers)) + 535 ' layers.') 536 else: 537 return self.layers[index] 538 else: 539 if not name: 540 raise ValueError('Provide either a layer name or layer index.') 541 for layer in self.layers: 542 if layer.name == name: 543 return layer 544 raise ValueError('No such layer: ' + name) 545 546 @property 547 def trainable_weights(self): 548 self._assert_weights_created() 549 return self._dedup_weights( 550 trackable_layer_utils.gather_trainable_weights( 551 trainable=self.trainable, 552 sub_layers=self._layers, 553 extra_variables=self._trainable_weights)) 554 555 @property 556 def non_trainable_weights(self): 557 self._assert_weights_created() 558 return self._dedup_weights( 559 trackable_layer_utils.gather_non_trainable_weights( 560 trainable=self.trainable, 561 sub_layers=self._layers, 562 extra_variables=self._non_trainable_weights + 563 self._trainable_weights)) 564 565 @property 566 def input_spec(self): 567 """Gets the network's input specs. 568 569 Returns: 570 A list of `InputSpec` instances (one per input to the model) 571 or a single instance if the model has only one input. 572 """ 573 # If subclassed model, can't assume anything. 574 if not self._is_graph_network: 575 return None 576 577 specs = [] 578 for layer in self._input_layers: 579 if layer.input_spec is None: 580 specs.append(None) 581 else: 582 if not isinstance(layer.input_spec, list): 583 raise TypeError('Layer ' + layer.name + 584 ' has an input_spec attribute that ' 585 'is not a list. We expect a list. 
                          'Found input_spec = ' + str(layer.input_spec))
        specs += layer.input_spec
    if len(specs) == 1:
      return specs[0]
    return specs

  @base_layer_utils.default
  def build(self, input_shape):
    """Builds the model based on input shapes received.

    This is to be used for subclassed models, which do not know at
    instantiation time what their inputs look like.

    This method only exists for users who want to call `model.build()` in a
    standalone way (as a substitute for calling the model on real data to
    build it). It will never be called by the framework (and thus it will
    never throw unexpected errors in an unrelated workflow).

    Args:
      input_shape: Single tuple, TensorShape, or list of shapes, where shapes
        are tuples, integers, or TensorShapes.

    Raises:
      ValueError:
        1. In case of invalid user-provided data (not of type tuple,
           list, or TensorShape).
        2. If the model requires call arguments that are agnostic
           to the input shapes (positional or kwarg in call signature).
        3. If not all layers were properly built.
        4. If float type inputs are not supported within the layers.

      In each of these cases, the user should build their model by calling it
      on real tensor data.
    """
    if self._is_graph_network:
      self.built = True
      return

    # If subclass network
    if input_shape is None:
      raise ValueError('Input shape must be defined when calling build on a '
                       'model subclass network.')
    valid_types = (tuple, list, tensor_shape.TensorShape)
    if not isinstance(input_shape, valid_types):
      raise ValueError('Specified input shape is not one of the valid types. '
                       'Please specify a batch input shape of type tuple or '
                       'list of input shapes. User provided '
                       'input type: {}'.format(type(input_shape)))

    if input_shape and not self.inputs:
      # We create placeholders for the `None`s in the shape and build the model
      # in a Graph. Since tf.Variable is compatible with both eager execution
      # and graph building, the variables created after building the model in
      # a Graph are still valid when executing eagerly.
      if context.executing_eagerly():
        graph = func_graph.FuncGraph('build_graph')
      else:
        graph = backend.get_graph()
      with graph.as_default():
        if isinstance(input_shape, list):
          x = [base_layer_utils.generate_placeholders_from_shape(shape)
               for shape in input_shape]
        else:
          x = base_layer_utils.generate_placeholders_from_shape(input_shape)

        kwargs = {}
        call_signature = self._call_full_argspec
        call_args = call_signature.args
        # Exclude `self`, `inputs`, and any argument with a default value.
        if len(call_args) > 2:
          if call_signature.defaults:
            call_args = call_args[2:-len(call_signature.defaults)]
          else:
            call_args = call_args[2:]
          for arg in call_args:
            if arg == 'training':
              # Case where `training` is a positional arg with no default.
              kwargs['training'] = False
            else:
              # Has invalid call signature with unknown positional arguments.
              raise ValueError(
                  'Currently, you cannot build your model if it has '
                  'positional or keyword arguments that are not '
                  'inputs to the model, but are required for its '
                  '`call` method. Instead, in order to instantiate '
                  'and build your model, `call` your model on real '
                  'tensor data with all expected call arguments.')
        elif len(call_args) < 2:
          # Signature without `inputs`.
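          # Fewer than two argspec entries means `call` accepts only `self`,
          # e.g. `def call(self):`, so there is nothing to feed the generated
          # placeholders into.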
          raise ValueError('You can only call `build` on a model if its '
                           '`call` method accepts an `inputs` argument.')
        try:
          self.call(x, **kwargs)
        except (errors.InvalidArgumentError, TypeError):
          raise ValueError('You cannot build your model by calling `build` '
                           'if your layers do not support float type inputs. '
                           'Instead, in order to instantiate and build your '
                           'model, `call` your model on real tensor data (of '
                           'the correct dtype).')

    self.built = True

  def call(self, inputs, training=None, mask=None):
    """Calls the model on new inputs.

    In this case `call` just reapplies
    all ops in the graph to the new inputs
    (e.g. build a new computational graph from the provided inputs).

    Arguments:
      inputs: A tensor or list of tensors.
      training: Boolean or boolean scalar tensor, indicating whether to run
        the `Network` in training mode or inference mode.
      mask: A mask or list of masks. A mask can be
        either a tensor or None (no mask).

    Returns:
      A tensor if there is a single output, or
      a list of tensors if there is more than one output.
    """
    if not self._is_graph_network:
      raise NotImplementedError('When subclassing the `Model` class, you '
                                'should implement a `call` method.')

    return self._run_internal_graph(
        inputs, training=training, mask=mask,
        convert_kwargs_to_constants=base_layer_utils.call_context().saving)

  def compute_output_shape(self, input_shape):
    if not self._is_graph_network:
      return super(Network, self).compute_output_shape(input_shape)

    # Convert any shapes in tuple format to TensorShapes.
    input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)

    if len(nest.flatten(input_shape)) != len(nest.flatten(self._input_layers)):
      raise ValueError('Invalid input_shape argument ' + str(input_shape) +
                       ': model has ' + str(len(self._input_layers)) +
                       ' tensor inputs.')

    cache_key = generic_utils.object_list_uid(input_shape)
    if cache_key in self._output_shape_cache:
      # Cache hit. Return shapes as TensorShapes.
      return self._output_shape_cache[cache_key]

    layers_to_output_shapes = {}
    for layer, shape in zip(self._input_layers, nest.flatten(input_shape)):
      # It's an input layer: then `compute_output_shape` is identity,
      # and there is only one node and one tensor.
      shape_key = layer.name + '_0_0'
      layers_to_output_shapes[shape_key] = shape

    depth_keys = list(self._nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    # Iterate over nodes, by depth level.
    if len(depth_keys) > 1:
      for depth in depth_keys:
        nodes = self._nodes_by_depth[depth]
        for node in nodes:
          # This is always a single layer, never a list.
          layer = node.outbound_layer
          if layer in self._input_layers:
            # We've already covered the input layers
            # a few lines above.
            continue
          # Potentially redundant list,
          # same size as node.input_tensors.
          layer_input_shapes = []
          for inbound_layer, node_id, tensor_id, _ in node.iterate_inbound():
            input_layer_key = inbound_layer.name + '_%s_%s' % (node_id,
                                                               tensor_id)
            layer_input_shapes.append(layers_to_output_shapes[input_layer_key])
          layer_input_shapes = nest.pack_sequence_as(node.inbound_layers,
                                                     layer_input_shapes)
          # Layers expect shapes to be tuples for `compute_output_shape`.
          layer_input_shapes = tf_utils.convert_shapes(
              layer_input_shapes, to_tuples=True)
          layer_output_shapes = layer.compute_output_shape(layer_input_shapes)
          # Convert back to TensorShapes.
          layer_output_shapes = tf_utils.convert_shapes(
              layer_output_shapes, to_tuples=False)

          node_index = layer._inbound_nodes.index(node)  # pylint: disable=protected-access
          for j, shape in enumerate(nest.flatten(layer_output_shapes)):
            shape_key = layer.name + '_%s_%s' % (node_index, j)
            layers_to_output_shapes[shape_key] = shape

    # Read final output shapes from layers_to_output_shapes.
    output_shapes = []
    for i in range(len(self._output_layers)):
      layer, node_index, tensor_index = self._output_coordinates[i]
      shape_key = layer.name + '_%s_%s' % (node_index, tensor_index)
      output_shapes.append(layers_to_output_shapes[shape_key])
    output_shapes = nest.pack_sequence_as(self._nested_outputs, output_shapes)
    # Store in cache.
    self._output_shape_cache[cache_key] = output_shapes

    # Return shapes as TensorShapes.
    return output_shapes

  def _run_internal_graph(self, inputs, training=None, mask=None,
                          convert_kwargs_to_constants=False):
    """Computes output tensors for new inputs.

    Note:
      - Can be run on non-Keras tensors.

    Arguments:
      inputs: Tensor or nested structure of Tensors.
      training: Boolean learning phase.
      mask: (Optional) Tensor or nested structure of Tensors.
      convert_kwargs_to_constants: Whether to convert Tensor kwargs to
        constants. This is used when tracing the model call function during
        saving to ensure that external tensors aren't captured.

    Returns:
      Two lists: output_tensors, output_masks
    """
    # Note: masking support is relevant mainly for Keras.
    # It cannot be factored out without having to fully reimplement the
    # network calling logic on the Keras side. We choose to incorporate it in
    # Network because 1) it may be useful to fully support in tf.layers in
    # the future and 2) Keras is a major user of Network. If you don't
    # use masking, it does not interfere with regular behavior at all and you
    # can ignore it.

    if isinstance(inputs, dict) and isinstance(self._nested_inputs,
                                               (list, tuple)):
      # Backwards compat: Allows passing a dict to a Model constructed with a
      # list. Matches dict keys to input names.
      inputs = [
          inputs[inp._keras_history.layer.name] for inp in self._nested_inputs
      ]
    else:
      inputs = nest.flatten(inputs)

    if mask is None:
      masks = [None for _ in range(len(inputs))]
    else:
      masks = nest.flatten(mask)

    for input_t, mask in zip(inputs, masks):
      input_t._keras_mask = mask

    # Dictionary mapping reference tensors to computed tensors.
    tensor_dict = {}

    for x, y in zip(self.inputs, inputs):
      x_id = str(id(x))
      tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id]
      if isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor):
        try:
          y.set_shape(y.shape.merge_with(x.shape))
        except ValueError:
          logging.warning(
              'Model was constructed with shape {} for input {}, but it was '
              're-called on a Tensor with incompatible shape {}.'
              .format(x.shape, x, y.shape))

    depth_keys = list(self._nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    # Ignore the InputLayers when computing the graph.
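    # (Depths are sorted in decreasing order, so the first, largest depth
    # holds the InputLayers; their tensors were already seeded into
    # `tensor_dict` above.)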
    depth_keys = depth_keys[1:]

    for depth in depth_keys:
      nodes = self._nodes_by_depth[depth]
      for node in nodes:
        # This is always a single layer, never a list.
        layer = node.outbound_layer

        if all(
            str(id(tensor)) in tensor_dict
            for tensor in nest.flatten(node.input_tensors)):

          # Call layer (reapplying ops to new inputs).
          computed_tensors = nest.map_structure(
              lambda t: tensor_dict[str(id(t))].pop(), node.input_tensors)

          # Ensure `training` arg propagation if applicable.
          kwargs = copy.copy(node.arguments) if node.arguments else {}
          if convert_kwargs_to_constants:
            kwargs = _map_tensors_to_constants(kwargs)

          argspec = self._layer_call_argspecs[layer].args
          if 'training' in argspec:
            kwargs.setdefault('training', training)
            if (type(kwargs['training']) is ops.Tensor and  # pylint: disable=unidiomatic-typecheck
                any([kwargs['training'] is x
                     for x in backend._GRAPH_LEARNING_PHASES.values()])):
              kwargs['training'] = training  # Materialize placeholder.

          # Map Keras tensors in kwargs to their computed value.
          def _map_tensor_if_from_keras_layer(t):
            if isinstance(t, ops.Tensor) and hasattr(t, '_keras_history'):
              t_id = str(id(t))
              return tensor_dict[t_id].pop()
            return t

          kwargs = nest.map_structure(_map_tensor_if_from_keras_layer, kwargs)

          # Compute outputs.
          output_tensors = layer(computed_tensors, **kwargs)

          # Update tensor_dict.
          for x, y in zip(
              nest.flatten(node.output_tensors), nest.flatten(output_tensors)):
            x_id = str(id(x))
            tensor_dict[x_id] = [y] * self._tensor_usage_count[x_id]

    output_tensors = []
    output_shapes = []
    for x in self.outputs:
      assert str(id(x)) in tensor_dict, 'Could not compute output ' + str(x)
      tensor = tensor_dict[str(id(x))].pop()
      output_shapes.append(x.shape)
      output_tensors.append(tensor)

    if output_shapes is not None:
      input_shapes = [x.shape for x in inputs]
      cache_key = generic_utils.object_list_uid(input_shapes)
      self._output_shape_cache[cache_key] = nest.pack_sequence_as(
          self._nested_outputs, output_shapes)

    output_tensors = nest.pack_sequence_as(self._nested_outputs,
                                           output_tensors)
    return output_tensors

  def get_config(self):
    if not self._is_graph_network:
      raise NotImplementedError
    return copy.deepcopy(get_network_config(self))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Instantiates a Model from its config (output of `get_config()`).

    Arguments:
      config: Model config dictionary.
      custom_objects: Optional dictionary mapping names
        (strings) to custom classes or functions to be
        considered during deserialization.

    Returns:
      A model instance.

    Raises:
      ValueError: In case of improperly formatted config dict.
    """
    input_tensors, output_tensors, created_layers = reconstruct_from_config(
        config, custom_objects)
    model = cls(inputs=input_tensors, outputs=output_tensors,
                name=config.get('name'))
    connect_ancillary_layers(model, created_layers)
    return model

  def save(self,
           filepath,
           overwrite=True,
           include_optimizer=True,
           save_format=None,
           signatures=None,
           options=None):
    """Saves the model to Tensorflow SavedModel or a single HDF5 file.

    The savefile includes:
      - The model architecture, allowing the model to be re-instantiated.
      - The model weights.
      - The state of the optimizer, allowing you to resume training
        exactly where you left off.

    This allows you to save the entirety of the state of a model
    in a single file.

    Saved models can be reinstantiated via `keras.models.load_model`.
    The model returned by `load_model` is a compiled model ready to be used
    (unless the saved model was never compiled in the first place).

    Models built with the Sequential and Functional API can be saved to both
    the HDF5 and SavedModel formats. Subclassed models can only be saved with
    the SavedModel format.

    Note that the model weights may have different scoped names after being
    loaded. Scoped names include the model/layer names, such as
    `"dense_1/kernel:0"`. It is recommended that you use the layer properties
    to access specific variables, e.g. `model.get_layer("dense_1").kernel`.

    Arguments:
      filepath: String, path to SavedModel or H5 file to save the model.
      overwrite: Whether to silently overwrite any existing file at the
        target location, or provide the user with a manual prompt.
      include_optimizer: If True, save optimizer's state together.
      save_format: Either 'tf' or 'h5', indicating whether to save the model
        to Tensorflow SavedModel or HDF5. Defaults to 'tf' in TF 2.X, and
        'h5' in TF 1.X.
      signatures: Signatures to save with the SavedModel. Applicable to the
        'tf' format only. Please see the `signatures` argument in
        `tf.saved_model.save` for details.
      options: Optional `tf.saved_model.SaveOptions` object that specifies
        options for saving to SavedModel.

    Example:

    ```python
    from keras.models import load_model

    model.save('my_model.h5')  # creates a HDF5 file 'my_model.h5'
    del model  # deletes the existing model

    # returns a compiled model
    # identical to the previous one
    model = load_model('my_model.h5')
    ```
    """
    save.save_model(self, filepath, overwrite, include_optimizer, save_format,
                    signatures, options)

  def save_weights(self, filepath, overwrite=True, save_format=None):
    """Saves all layer weights.

    Either saves in HDF5 or in TensorFlow format based on the `save_format`
    argument.

    When saving in HDF5 format, the weight file has:
      - `layer_names` (attribute), a list of strings
        (ordered names of model layers).
      - For every layer, a `group` named `layer.name`
        - For every such layer group, a group attribute `weight_names`,
          a list of strings
          (ordered names of weights tensor of the layer).
        - For every weight in the layer, a dataset
          storing the weight value, named after the weight tensor.

    When saving in TensorFlow format, all objects referenced by the network
    are saved in the same format as `tf.train.Checkpoint`, including any
    `Layer` instances or `Optimizer` instances assigned to object attributes.
    For networks constructed from inputs and outputs using
    `tf.keras.Model(inputs, outputs)`, `Layer` instances used by the network
    are tracked/saved automatically. For user-defined classes which inherit
    from `tf.keras.Model`, `Layer` instances must be assigned to object
    attributes, typically in the constructor. See the documentation of
    `tf.train.Checkpoint` and `tf.keras.Model` for details.

    While the formats are the same, do not mix `save_weights` and
    `tf.train.Checkpoint`. Checkpoints saved by `Model.save_weights` should be
    loaded using `Model.load_weights`. Checkpoints saved using
    `tf.train.Checkpoint.save` should be restored using the corresponding
    `tf.train.Checkpoint.restore`. Prefer `tf.train.Checkpoint` over
    `save_weights` for training checkpoints.

    The TensorFlow format matches objects and variables by starting at a root
    object, `self` for `save_weights`, and greedily matching attribute
    names. For `Model.save` this is the `Model`, and for `Checkpoint.save`
    this is the `Checkpoint` even if the `Checkpoint` has a model attached.
    This means saving a `tf.keras.Model` using `save_weights` and loading
    into a `tf.train.Checkpoint` with a `Model` attached (or vice versa) will
    not match the `Model`'s variables. See the [guide to training
    checkpoints](https://www.tensorflow.org/guide/checkpoint) for details on
    the TensorFlow format.

    Arguments:
      filepath: String, path to the file to save the weights to. When saving
        in TensorFlow format, this is the prefix used for checkpoint files
        (multiple files are generated). Note that the '.h5' suffix causes
        weights to be saved in HDF5 format.
      overwrite: Whether to silently overwrite any existing file at the
        target location, or provide the user with a manual prompt.
      save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
        '.keras' will default to HDF5 if `save_format` is `None`. Otherwise
        `None` defaults to 'tf'.

    Raises:
      ImportError: If h5py is not available when attempting to save in HDF5
        format.
      ValueError: For invalid/unknown format arguments.
    """
    self._assert_weights_created()
    filepath_is_h5 = _is_hdf5_filepath(filepath)
    if save_format is None:
      if filepath_is_h5:
        save_format = 'h5'
      else:
        save_format = 'tf'
    else:
      user_format = save_format.lower().strip()
      if user_format in ('tensorflow', 'tf'):
        save_format = 'tf'
      elif user_format in ('hdf5', 'h5', 'keras'):
        save_format = 'h5'
      else:
        raise ValueError(
            'Unknown format "%s". Was expecting one of {"tf", "h5"}.' % (
                save_format,))
    if save_format == 'tf' and filepath_is_h5:
      raise ValueError(
          ('save_weights got save_format="tf"/"tensorflow", but the '
           'filepath ("%s") looks like an HDF5 file. Omit the ".h5"/".keras" '
           'when saving in TensorFlow format.')
          % filepath)

    if save_format == 'h5' and h5py is None:
      raise ImportError(
          '`save_weights` requires h5py when saving in hdf5.')
    if save_format == 'tf':
      check_filepath = filepath + '.index'
    else:
      check_filepath = filepath
    # If file exists and should not be overwritten:
    if not overwrite and os.path.isfile(check_filepath):
      proceed = ask_to_proceed_with_overwrite(check_filepath)
      if not proceed:
        return
    if save_format == 'h5':
      with h5py.File(filepath, 'w') as f:
        hdf5_format.save_weights_to_hdf5_group(f, self.layers)
    else:
      if context.executing_eagerly():
        session = None
      else:
        session = backend.get_session()
      optimizer = getattr(self, 'optimizer', None)
      if (optimizer
          and not isinstance(optimizer, trackable.Trackable)):
        logging.warning(
            ('This model was compiled with a Keras optimizer (%s) but is '
             'being saved in TensorFlow format with `save_weights`. '
             'The model\'s weights will be saved, but unlike with TensorFlow '
             'optimizers in the TensorFlow format the optimizer\'s state '
             'will not be saved.\n\nConsider using a TensorFlow optimizer '
             'from `tf.train`.')
            % (optimizer,))
      self._trackable_saver.save(filepath, session=session)
      # Record this checkpoint so it's visible from tf.train.latest_checkpoint.
      checkpoint_management.update_checkpoint_state_internal(
          save_dir=os.path.dirname(filepath),
          model_checkpoint_path=filepath,
          save_relative_paths=True,
          all_model_checkpoint_paths=[filepath])

  def load_weights(self, filepath, by_name=False, skip_mismatch=False):
    """Loads all layer weights, either from a TensorFlow or an HDF5 weight file.

    If `by_name` is False, weights are loaded based on the network's
    topology. This means the architecture should be the same as when the
    weights were saved. Note that layers that don't have weights are not taken
    into account in the topological ordering, so adding or removing layers is
    fine as long as they don't have weights.

    If `by_name` is True, weights are loaded into layers only if they share
    the same name. This is useful for fine-tuning or transfer-learning models
    where some of the layers have changed.

    Only topological loading (`by_name=False`) is supported when loading
    weights from the TensorFlow format. Note that topological loading differs
    slightly between TensorFlow and HDF5 formats for user-defined classes
    inheriting from `tf.keras.Model`: HDF5 loads based on a flattened list of
    weights, while the TensorFlow format loads based on the object-local names
    of attributes to which layers are assigned in the `Model`'s constructor.

    Arguments:
      filepath: String, path to the weights file to load. For weight files in
        TensorFlow format, this is the file prefix (the same as was passed
        to `save_weights`).
      by_name: Boolean, whether to load weights by name or by topological
        order. Only topological loading is supported for weight files in
        TensorFlow format.
      skip_mismatch: Boolean, whether to skip loading of layers where there is
        a mismatch in the number of weights, or a mismatch in the shape of
        the weight (only valid when `by_name=True`).

    Returns:
      When loading a weight file in TensorFlow format, returns the same status
      object as `tf.train.Checkpoint.restore`. When graph building, restore
      ops are run automatically as soon as the network is built (on first call
      for user-defined classes inheriting from `Model`, immediately if it is
      already built).

      When loading weights in HDF5 format, returns `None`.

    Raises:
      ImportError: If h5py is not available and the weight file is in HDF5
        format.
      ValueError: If `skip_mismatch` is set to `True` when `by_name` is
        `False`.
    """

    if skip_mismatch and not by_name:
      raise ValueError(
          'When calling model.load_weights, skip_mismatch can only be set to '
          'True when by_name is True.')

    if _is_hdf5_filepath(filepath):
      save_format = 'h5'
    else:
      try:
        py_checkpoint_reader.NewCheckpointReader(filepath)
        save_format = 'tf'
      except errors_impl.DataLossError:
        # The checkpoint is not readable in TensorFlow format. Try HDF5.
        save_format = 'h5'
    if save_format == 'tf':
      status = self._trackable_saver.restore(filepath)
      if by_name:
        raise NotImplementedError(
            'Weights may only be loaded based on topology into Models when '
            'loading TensorFlow-formatted weights (got by_name=True to '
            'load_weights).')
      if not context.executing_eagerly():
        session = backend.get_session()
        # Restore existing variables (if any) immediately, and set up a
        # streaming restore for any variables created in the future.
        trackable_utils.streaming_restore(status=status, session=session)
      status.assert_nontrivial_match()
      return status
    if h5py is None:
      raise ImportError(
          '`load_weights` requires h5py when loading weights from HDF5.')
    if not self._is_graph_network and not self.built:
      raise NotImplementedError(
          'Unable to load weights saved in HDF5 format into a subclassed '
          'Model which has not created its variables yet. Call the Model '
          'first, then load the weights.')
    self._assert_weights_created()
    with h5py.File(filepath, 'r') as f:
      if 'layer_names' not in f.attrs and 'model_weights' in f:
        f = f['model_weights']
      if by_name:
        hdf5_format.load_weights_from_hdf5_group_by_name(
            f, self.layers, skip_mismatch=skip_mismatch)
      else:
        hdf5_format.load_weights_from_hdf5_group(f, self.layers)

  def _updated_config(self):
    """Util shared between different serialization methods.

    Returns:
      Model config with Keras version information added.
    """
    from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top

    config = self.get_config()
    model_config = {
        'class_name': self.__class__.__name__,
        'config': config,
        'keras_version': keras_version,
        'backend': backend.backend()
    }
    return model_config

  def to_json(self, **kwargs):
    """Returns a JSON string containing the network configuration.

    To load a network from a JSON save file, use
    `keras.models.model_from_json(json_string, custom_objects={})`.

    Arguments:
      **kwargs: Additional keyword arguments
        to be passed to `json.dumps()`.

    Returns:
      A JSON string.
    """
    model_config = self._updated_config()
    return json.dumps(
        model_config, default=serialization.get_json_type, **kwargs)

  def to_yaml(self, **kwargs):
    """Returns a yaml string containing the network configuration.

    To load a network from a yaml save file, use
    `keras.models.model_from_yaml(yaml_string, custom_objects={})`.

    `custom_objects` should be a dictionary mapping
    the names of custom losses / layers / etc to the corresponding
    functions / classes.

    Arguments:
      **kwargs: Additional keyword arguments
        to be passed to `yaml.dump()`.

    Returns:
      A YAML string.

    Raises:
      ImportError: if yaml module is not found.
    """
    if yaml is None:
      raise ImportError(
          'Requires yaml module installed (`pip install pyyaml`).')
    return yaml.dump(self._updated_config(), **kwargs)

  def summary(self, line_length=None, positions=None, print_fn=None):
    """Prints a string summary of the network.

    Arguments:
      line_length: Total length of printed lines
        (e.g. set this to adapt the display to different
        terminal window sizes).
      positions: Relative or absolute positions of log elements
        in each line. If not provided,
        defaults to `[.33, .55, .67, 1.]`.
      print_fn: Print function to use. Defaults to `print`.
        It will be called on each line of the summary.
        You can set it to a custom function
        in order to capture the string summary.

    Raises:
      ValueError: if `summary()` is called before the model is built.
    """
    if not self.built:
      raise ValueError('This model has not yet been built. '
                       'Build the model first by calling `build()` or calling '
                       '`fit()` with some data, or specify '
                       'an `input_shape` argument in the first layer(s) for '
                       'automatic build.')
    layer_utils.print_summary(self,
                              line_length=line_length,
                              positions=positions,
                              print_fn=print_fn)

  def _validate_graph_inputs_and_outputs(self):
    """Validates the inputs and outputs of a Graph Network."""
    # Check for redundancy in inputs.
    if len({id(i) for i in self.inputs}) != len(self.inputs):
      raise ValueError('The list of inputs passed to the model '
                       'is redundant. '
                       'All inputs should only appear once.'
                       ' Found: ' + str(self.inputs))

    for x in self.inputs:
      # Check that x has appropriate `_keras_history` metadata.
      if not hasattr(x, '_keras_history'):
        cls_name = self.__class__.__name__
        raise ValueError('Input tensors to a ' + cls_name + ' ' +
                         'must come from `tf.keras.Input`. '
                         'Received: ' + str(x) +
                         ' (missing previous layer metadata).')
      # Check that x is an input tensor.
      # pylint: disable=protected-access
      layer = x._keras_history.layer
      if len(layer._inbound_nodes) > 1 or (
          layer._inbound_nodes and layer._inbound_nodes[0].inbound_layers):
        cls_name = self.__class__.__name__
        logging.warning(cls_name + ' inputs must come from '
                        '`tf.keras.Input` (thus holding past layer metadata), '
                        'they cannot be the output of '
                        'a previous non-Input layer. '
                        'Here, a tensor specified as '
                        'input to "' + self.name + '" was not an Input tensor, '
                        'it was generated by layer ' + layer.name + '.\n'
                        'Note that input tensors are '
                        'instantiated via `tensor = tf.keras.Input(shape)`.\n'
                        'The tensor that caused the issue was: ' + str(x.name))
      if isinstance(x, ragged_tensor.RaggedTensor):
        self._supports_ragged_inputs = True

    # Check compatibility of batch sizes of Input Layers.
    input_batch_sizes = [
        training_utils.get_static_batch_size(x._keras_history.layer)
        for x in self.inputs
    ]
    consistent_batch_size = None
    for batch_size in input_batch_sizes:
      if batch_size is not None:
        if (consistent_batch_size is not None and
            batch_size != consistent_batch_size):
          raise ValueError('The specified batch sizes of the Input Layers '
                           'are incompatible. Found batch sizes: {}'.format(
                               input_batch_sizes))
        consistent_batch_size = batch_size

    for x in self.outputs:
      if not hasattr(x, '_keras_history'):
        cls_name = self.__class__.__name__
        raise ValueError('Output tensors to a ' + cls_name + ' must be '
                         'the output of a TensorFlow `Layer` '
                         '(thus holding past layer metadata). Found: ' + str(x))

  def _insert_layers(self, layers, relevant_nodes=None):
    """Inserts Layers into the Network after Network creation.

    This is only valid for Keras Graph Networks. Layers added via this function
    will be included in the `call` computation and `get_config` of this
    Network.
    They will not be added to the Network's outputs.

    Arguments:
      layers: Arbitrary nested structure of Layers. Layers must be reachable
        from one or more of the `keras.Input` Tensors that correspond to this
        Network's inputs.
      relevant_nodes: Nodes from the Layers that should be considered part of
        this Network. If `None`, all Nodes will be considered part of this
        Network.

    Raises:
      ValueError: If the layers depend on `Input`s not found in this Model.
    """
    layers = nest.flatten(layers)
    tf_utils.assert_no_legacy_layers(layers)
    node_to_depth = {}
    for depth, nodes in self._nodes_by_depth.items():
      node_to_depth.update({node: depth for node in nodes})
    # The nodes of these Layers that are relevant to this Network. If not
    # provided, assume all Nodes are relevant.
    if not relevant_nodes:
      relevant_nodes = nest.flatten([layer._inbound_nodes for layer in layers])
    network_nodes = set(relevant_nodes + list(node_to_depth.keys()))

    def _get_min_depth(node):
      """Gets the minimum depth at which node can be computed."""
      min_depth = 0
      for layer, node_id, _, _ in node.iterate_inbound(include_arguments=True):
        inbound_node = layer._inbound_nodes[node_id]
        if inbound_node in node_to_depth:
          min_depth = min(min_depth, node_to_depth[inbound_node])
        elif inbound_node not in network_nodes:
          continue
        else:
          # Previous relevant nodes haven't been processed yet.
          return None
      # New node is one shallower than its shallowest input.
      return min_depth - 1

    # Insert nodes into `_nodes_by_depth` and other node attrs.
    unprocessed_nodes = copy.copy(relevant_nodes)
    i = 0
    while unprocessed_nodes:
      i += 1
      # Do a sanity check. This can occur if `Input`s from outside this Model
      # are being relied on.
      if i > 10000:
        raise ValueError('Layers could not be added due to missing '
                         'dependencies.')

      node = unprocessed_nodes.pop(0)
      depth = _get_min_depth(node)
      if depth is None:  # Defer until inbound nodes are processed.
        unprocessed_nodes.append(node)
        continue
      node_key = _make_node_key(node.outbound_layer.name,
                                node.outbound_layer._inbound_nodes.index(node))
      if node_key not in self._network_nodes:
        node_to_depth[node] = depth
        self._network_nodes.add(node_key)
        self._nodes_by_depth[depth].append(node)

    # Insert layers and update other layer attrs.
    layer_set = set(self._layers)
    deferred_layers = []
    for layer in layers:
      if layer not in layer_set:
        self._layers.append(layer)
        deferred_layers.append(layer)
        self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)

        # This allows the added layer to broadcast mutations to the current
        # layer, which is necessary to ensure cache correctness.
        layer._attribute_sentinel.add_parent(self._attribute_sentinel)
        layer_set.add(layer)
    self._handle_deferred_layer_dependencies(deferred_layers)

    self._compute_tensor_usage_count()

  def _compute_tensor_usage_count(self):
    """Computes the number of tensor usages for all output tensors of layers.

    The computed tensor usage count is saved as `self._tensor_usage_count`.
    This is later used for saving memory in eager computation by releasing
    no-longer-needed tensors as early as possible.
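
    For example, a tensor consumed by two downstream nodes that is also one
    of the model outputs ends up with a usage count of 3.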
    """
    tensor_usage_count = collections.Counter()
    available_tensors = set(str(id(tensor)) for tensor in self.inputs)

    depth_keys = list(self._nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    depth_keys = depth_keys[1:]

    for depth in depth_keys:
      for node in self._nodes_by_depth[depth]:
        input_tensors = {
            str(id(tensor)) for tensor in nest.flatten(node.input_tensors)
        }
        if input_tensors.issubset(available_tensors):
          kwargs = copy.copy(node.arguments) if node.arguments else {}

          for tensor in nest.flatten(kwargs):
            if isinstance(tensor, ops.Tensor) and hasattr(tensor,
                                                          '_keras_history'):
              tensor_usage_count[str(id(tensor))] += 1

          for tensor in nest.flatten(node.input_tensors):
            tensor_usage_count[str(id(tensor))] += 1

          for output_tensor in nest.flatten(node.output_tensors):
            available_tensors.add(str(id(output_tensor)))

    for tensor in self.outputs:
      tensor_usage_count[str(id(tensor))] += 1

    self._tensor_usage_count = tensor_usage_count

  def _assert_weights_created(self):
    """Asserts that all the weights for the network have been created.

    For a non-dynamic network, the weights must already be created after the
    layer has been called. For a dynamic network, the exact list of weights
    can never be known for certain since it may change at any time during
    execution.

    We run this check right before accessing weights or getting the Numpy
    value for the current weights. Otherwise, if the layer has never been
    called, the user would just get an empty list, which is misleading.

    Raises:
      ValueError: if the weights of the network have not yet been created.
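
    Example (sketch; `MyModel` is a hypothetical subclassed model with a
    custom `build()` method):

    ```
    model = MyModel()
    model.weights   # Raises: weights have not yet been created.
    model.build((None, 10))
    model.weights   # Returns the created weights.
    ```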
    """
    if self.dynamic:
      return
    if (not self._is_graph_network and
        'build' in self.__class__.__dict__ and
        not self.built):
      # For any model with a customized build() method that hasn't been
      # invoked yet; this covers both Sequential and subclassed models.
      raise ValueError('Weights for model %s have not yet been created. '
                       'Weights are created when the Model is first called on '
                       'inputs or `build()` is called with an `input_shape`.' %
                       self.name)

  def _graph_network_add_loss(self, symbolic_loss):
    new_nodes, new_layers = _map_subgraph_network(self.inputs, [symbolic_loss])
    # Losses must be keyed on inputs no matter what in order to be supported in
    # DistributionStrategy.
    add_loss_layer = base_layer.AddLoss(
        unconditional=False, dtype=symbolic_loss.dtype)
    add_loss_layer(symbolic_loss)
    new_nodes.extend(add_loss_layer.inbound_nodes)
    new_layers.append(add_loss_layer)
    self._insert_layers(new_layers, new_nodes)

  def _graph_network_add_metric(self, value, aggregation, name):
    new_nodes, new_layers = _map_subgraph_network(self.inputs, [value])
    add_metric_layer = base_layer.AddMetric(
        aggregation, name, dtype=value.dtype)
    add_metric_layer(value)
    new_nodes.extend(add_metric_layer.inbound_nodes)
    new_layers.append(add_metric_layer)
    self._insert_layers(new_layers, new_nodes)

  @property
  def _trackable_saved_model_saver(self):
    return network_serialization.NetworkSavedModelSaver(self)


def _is_hdf5_filepath(filepath):
  return (filepath.endswith('.h5') or filepath.endswith('.keras') or
          filepath.endswith('.hdf5'))


def _make_node_key(layer_name, node_index):
  return layer_name + '_ib-' + str(node_index)
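
# (Illustration: `_make_node_key('dense', 0)` returns 'dense_ib-0'. These keys
# identify (layer, node-index) pairs in `Network._network_nodes` and in the
# node conversion map built by `get_network_config` below.)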


def _map_graph_network(inputs, outputs):
  """Validates a network's topology and gathers its layers and nodes.

  Arguments:
    inputs: List of input tensors.
    outputs: List of output tensors.

  Returns:
    A tuple `(nodes, nodes_by_depth, layers, layers_by_depth)`.
    - nodes: set of node keys (strings) relevant to the Network.
    - nodes_by_depth: dict mapping ints (depth) to lists of node instances.
    - layers: list of Layer instances.
    - layers_by_depth: dict mapping ints (depth) to lists of layer instances.

  Raises:
    ValueError: In case the network is not valid (e.g. disconnected graph).
  """
  # `network_nodes`: set of nodes included in the graph of layers
  # (not all nodes included in the layers are relevant to the current graph).
  network_nodes = set()  # ids of all nodes relevant to the Network
  nodes_depths = {}  # dict {node: depth value}
  layers_depths = {}  # dict {layer: depth value}
  layer_indices = {}  # dict {layer: index in traversal}
  nodes_in_decreasing_depth = []

  def build_map(tensor,
                finished_nodes,
                nodes_in_progress,
                layer,
                node_index,
                tensor_index):
    """Builds a map of the graph of layers.

    This recursively updates the map `layer_indices`,
    the list `nodes_in_decreasing_depth` and the set `network_nodes`.

    Arguments:
      tensor: Some tensor in a graph.
      finished_nodes: Set of nodes whose subgraphs have been traversed
        completely. Useful to prevent duplicated work.
      nodes_in_progress: Set of nodes that are currently active on the
        recursion stack. Useful to detect cycles.
      layer: Layer from which `tensor` comes. If not provided,
        will be obtained from `tensor._keras_history`.
      node_index: Index of the node from which `tensor` comes.
      tensor_index: Index of `tensor` within its node's output tensors.

    Raises:
      ValueError: if a cycle is detected.
    """
    node = layer._inbound_nodes[node_index]  # pylint: disable=protected-access

    # Prevent cycles.
    if node in nodes_in_progress:
      raise ValueError('The tensor ' + str(tensor) + ' at layer "' +
                       layer.name + '" is part of a cycle.')

    # Don't repeat work for shared subgraphs.
    if node in finished_nodes:
      return

    node_key = _make_node_key(layer.name, node_index)
    # Update network_nodes.
    network_nodes.add(node_key)

    # Store the traversal order for layer sorting.
    if layer not in layer_indices:
      layer_indices[layer] = len(layer_indices)

    nodes_in_progress.add(node)

    # Propagate to all previous tensors connected to this node.
    for layer, node_index, tensor_index, tensor in node.iterate_inbound(
        include_arguments=True):
      build_map(tensor, finished_nodes, nodes_in_progress, layer, node_index,
                tensor_index)

    finished_nodes.add(node)
    nodes_in_progress.remove(node)
    nodes_in_decreasing_depth.append(node)

  finished_nodes = set()
  nodes_in_progress = set()
  for x in outputs:
    layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
    build_map(x, finished_nodes, nodes_in_progress,
              layer=layer,
              node_index=node_index,
              tensor_index=tensor_index)

  for node in reversed(nodes_in_decreasing_depth):
    # If the depth is not set, the node has no outbound nodes (depth 0).
    depth = nodes_depths.setdefault(node, 0)

    # Update the depth of the corresponding layer.
    previous_depth = layers_depths.get(node.outbound_layer, 0)
    # If we've seen this layer before at a higher depth,
    # we should use that depth instead of the node depth.
    # This is necessary for shared layers that have inputs at different
    # depth levels in the graph.
    depth = max(depth, previous_depth)
    layers_depths[node.outbound_layer] = depth
    nodes_depths[node] = depth

    # Update the depth of inbound nodes.
    # The "depth" of a node is the max of the depths
    # of all the nodes it is connected to + 1.
    for node_dep in node._get_all_node_dependencies():
      previous_depth = nodes_depths.get(node_dep, 0)
      nodes_depths[node_dep] = max(depth + 1, previous_depth)

  # Handle inputs that are not connected to outputs.
  # We do not error out here because the inputs may be used to compute losses
  # and metrics.
  for input_t in inputs:
    input_layer = input_t._keras_history[0]
    if input_layer not in layers_depths:
      layers_depths[input_layer] = 0
      layer_indices[input_layer] = -1
      nodes_depths[input_layer._inbound_nodes[0]] = 0
      network_nodes.add(_make_node_key(input_layer.name, 0))

  # Build a dict {depth: list of nodes with this depth}.
  nodes_by_depth = collections.defaultdict(list)
  for node, depth in nodes_depths.items():
    nodes_by_depth[depth].append(node)

  # Build a dict {depth: list of layers with this depth}.
  layers_by_depth = collections.defaultdict(list)
  for layer, depth in layers_depths.items():
    layers_by_depth[depth].append(layer)

  # Get sorted list of layer depths.
  depth_keys = list(layers_by_depth.keys())
  depth_keys.sort(reverse=True)

  # Set self.layers ordered by depth.
  layers = []
  for depth in depth_keys:
    layers_for_depth = layers_by_depth[depth]
    # Network.layers needs to have a deterministic order:
    # here we order them by traversal order.
    layers_for_depth.sort(key=lambda x: layer_indices[x])
    layers.extend(layers_for_depth)

  # Get sorted list of node depths.
  depth_keys = list(nodes_by_depth.keys())
  depth_keys.sort(reverse=True)

  # Check that all tensors required are computable.
  # computable_tensors: all tensors in the graph
  # that can be computed from the inputs provided.
  computable_tensors = set()
  for x in inputs:
    computable_tensors.add(id(x))

  layers_with_complete_input = []  # To provide a better error msg.
  for depth in depth_keys:
    for node in nodes_by_depth[depth]:
      layer = node.outbound_layer
      if layer:
        for x in nest.flatten(node.input_tensors):
          if id(x) not in computable_tensors:
            raise ValueError('Graph disconnected: '
                             'cannot obtain value for tensor ' + str(x) +
                             ' at layer "' + layer.name + '". '
                             'The following previous layers '
                             'were accessed without issue: ' +
                             str(layers_with_complete_input))
        for x in nest.flatten(node.output_tensors):
          computable_tensors.add(id(x))
        layers_with_complete_input.append(layer.name)

  # Ensure name uniqueness, which will be crucial for serialization
  # (since serialized nodes refer to layers by their name).
  all_names = [layer.name for layer in layers]
  for name in all_names:
    if all_names.count(name) != 1:
      raise ValueError('The name "' + name + '" is used ' +
                       str(all_names.count(name)) + ' times in the model. '
                       'All layer names should be unique.')
  return network_nodes, nodes_by_depth, layers, layers_by_depth
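
# (Illustration: for a simple chain `x = Input(...)`, `y = Dense(1)(x)`,
# `_map_graph_network([x], [y])` assigns the Dense node depth 0 and the Input
# node depth 1, and returns `layers` ordered from inputs to outputs:
# [input_layer, dense_layer].)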


def _map_subgraph_network(inputs, outputs):
  """Returns the nodes and layers in the topology from `inputs` to `outputs`.

  Args:
    inputs: List of input tensors.
    outputs: List of output tensors.

  Returns:
    A tuple of (list of Nodes, list of Layers).
  """
  base_layer_utils.create_keras_history(outputs)
  # Keep only nodes and layers in the topology between inputs and outputs.
  _, nodes_by_depth, layers, _ = _map_graph_network(inputs, outputs)
  return nest.flatten([nodes for nodes in nodes_by_depth.values()]), layers
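
# Note: `create_keras_history` above wraps any raw TensorFlow ops encountered
# while walking back from `outputs` into Keras op layers, so the subgraph can
# be expressed purely in terms of Keras layers and nodes before it is mapped.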


def _should_skip_first_node(layer):
  """Returns True if the first layer node should not be saved or loaded."""
  # Networks start with a pre-existing node linking their input to output.
  return issubclass(layer.__class__, Network) and layer._is_graph_network


def _serialize_tensors(kwargs):
  """Serializes Tensors passed to `call`."""

  def _serialize_keras_tensor(t):
    """Serializes a single Tensor passed to `call`."""
    if hasattr(t, '_keras_history'):
      kh = t._keras_history
      return [kh.layer.name, kh.node_index, kh.tensor_index]

    if isinstance(t, np.ndarray):
      return t.tolist()

    if isinstance(t, ops.Tensor):
      return backend.get_value(t).tolist()

    return t

  return nest.map_structure(_serialize_keras_tensor, kwargs)


def _map_tensors_to_constants(kwargs):
  """Maps plain Tensors (without Keras history) in `kwargs` to constants."""

  def _map_to_constants(t):
    if not hasattr(t, '_keras_history') and isinstance(t, ops.Tensor):
      return constant_op.constant(backend.get_value(t))
    return t

  return nest.map_structure(_map_to_constants, kwargs)


def _deserialize_keras_tensors(kwargs, layer_map):
  """Deserializes Keras Tensors passed to `call`."""

  def _deserialize_keras_tensor(t):
    """Deserializes a single Keras Tensor passed to `call`."""
    if isinstance(t, tf_utils.ListWrapper):
      t = t.as_list()
      layer_name = t[0]
      node_index = t[1]
      tensor_index = t[2]

      layer = layer_map[layer_name]
      node = layer._inbound_nodes[node_index]
      return nest.flatten(node.output_tensors)[tensor_index]
    return t

  kwargs = tf_utils.convert_inner_node_data(kwargs, wrap=True)
  return nest.map_structure(_deserialize_keras_tensor, kwargs)


def connect_ancillary_layers(model, created_layers):
  """Adds layers that are not connected to the outputs to the model."""
  # Layers not connected to outputs, such as those added in `add_loss`.
  ancillary_layers = [
      layer for layer in created_layers.values() if layer not in model.layers
  ]
  if ancillary_layers:
    relevant_nodes = nest.flatten([
        layer.inbound_nodes[1:]
        if _should_skip_first_node(layer) else layer.inbound_nodes
        for layer in created_layers.values()
    ])
    model._insert_layers(ancillary_layers, relevant_nodes)
  return model


def reconstruct_from_config(config, custom_objects=None, created_layers=None):
  """Reconstructs graph from config object.

  Args:
    config: Dictionary returned from Network.get_config().
    custom_objects: Optional dictionary mapping names (strings) to custom
      classes or functions to be considered during deserialization.
    created_layers: Optional dictionary mapping names to Layer objects. Any
      layer not in this dictionary will be created and added to the dict.
      This function will add new nodes to all layers (excluding InputLayers),
      instead of re-using pre-existing nodes in the layers.

  Returns:
    Tuple of (input tensors, output tensors, dictionary of created layers).
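
  Example (sketch; assumes an existing functional `model`):

  ```
  config = model.get_config()
  inputs, outputs, created_layers = reconstruct_from_config(config)
  new_model = Network(inputs, outputs, name=config['name'])
  ```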
  """
  # Layer instances created during the graph reconstruction process.
  created_layers = created_layers or collections.OrderedDict()

  # Maps input data (tuple of inbound layer name, node index) from the config
  # to node indices in the newly generated model. The node indices may be
  # different if the layers have already been called previously.
  node_index_map = {}
  node_count_by_layer = {}

  # Dictionary mapping layer instances to
  # node data that specifies a layer call.
  # It acts as a queue that maintains any unprocessed
  # layer call until it becomes possible to process it
  # (i.e. until the input tensors to the call all exist).
  unprocessed_nodes = {}

  def add_unprocessed_node(layer, node_data):
    if layer not in unprocessed_nodes:
      unprocessed_nodes[layer] = [node_data]
    else:
      unprocessed_nodes[layer].append(node_data)

  def get_node_index(layer, config_node_index):
    """Returns node index in layer (might differ from config_node_index)."""
    if isinstance(layer, input_layer_module.InputLayer):
      return 0
    return node_index_map.get((layer.name, config_node_index), None)

  def process_node(layer, node_data):
    """Deserializes a node.

    Arguments:
      layer: layer instance.
      node_data: Nested structure of `ListWrapper`.

    Raises:
      ValueError: In case of improperly formatted `node_data`.
    """
    input_tensors = []
    for input_data in nest.flatten(node_data):
      input_data = input_data.as_list()
      inbound_layer_name = input_data[0]
      inbound_node_index = input_data[1]
      inbound_tensor_index = input_data[2]
      if len(input_data) == 3:
        kwargs = {}
      elif len(input_data) == 4:
        kwargs = input_data[3]
        kwargs = _deserialize_keras_tensors(kwargs, created_layers)
      else:
        raise ValueError('Improperly formatted model config.')

      inbound_layer = created_layers[inbound_layer_name]
      inbound_node_index = get_node_index(inbound_layer, inbound_node_index)

      if inbound_node_index is None:
        add_unprocessed_node(layer, node_data)
        return
      inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
      input_tensors.append(
          nest.flatten(inbound_node.output_tensors)[inbound_tensor_index])
    input_tensors = nest.pack_sequence_as(node_data, input_tensors)
    # Call layer on its inputs, thus creating the node
    # and building the layer if needed.
    if input_tensors is not None:
      input_tensors = base_layer_utils.unnest_if_single_tensor(input_tensors)
      output_tensors = layer(input_tensors, **kwargs)

      # Update node index map.
      output_index = nest.flatten(output_tensors)[0]._keras_history.node_index
      node_index_map[(layer.name, node_count_by_layer[layer])] = output_index
      node_count_by_layer[layer] += 1
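
  # For illustration, a flattened `node_data` entry has the form
  # `[inbound_layer_name, inbound_node_index, inbound_tensor_index]`, with an
  # optional fourth element carrying serialized call kwargs, e.g.
  # `['dense_1', 0, 0, {'training': False}]` (the layer name and kwargs here
  # are hypothetical).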

  def process_layer(layer_data):
    """Deserializes a layer, then calls it on appropriate inputs.

    Arguments:
      layer_data: layer config dict.

    Raises:
      ValueError: In case of improperly formatted `layer_data` dict.
    """
    layer_name = layer_data['name']

    if layer_name in created_layers:
      layer = created_layers[layer_name]
    else:
      # Instantiate layer.
      from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top

      layer = deserialize_layer(layer_data, custom_objects=custom_objects)
      created_layers[layer_name] = layer

    node_count_by_layer[layer] = int(_should_skip_first_node(layer))

    # Gather layer inputs and convert to `ListWrapper` objects.
    inbound_nodes_data = layer_data['inbound_nodes']
    inbound_nodes_data = tf_utils.convert_inner_node_data(
        inbound_nodes_data, wrap=True)
    for node_data in inbound_nodes_data:
      # We don't process nodes (i.e. make layer calls)
      # on the fly because the inbound node may not yet exist,
      # in the case of a layer shared at different topological depths
      # (e.g. a model such as A(B(A(B(x))))).
      add_unprocessed_node(layer, node_data)

  # First, we create all layers and enqueue nodes to be processed.
  for layer_data in config['layers']:
    process_layer(layer_data)
  # Then we process nodes in order of layer depth.
  # Nodes that cannot yet be processed (if the inbound node
  # does not yet exist) are re-enqueued, and the process
  # is repeated until all nodes are processed.
  while unprocessed_nodes:
    for layer_data in config['layers']:
      layer = created_layers[layer_data['name']]
      if layer in unprocessed_nodes:
        for node_data in unprocessed_nodes.pop(layer):
          process_node(layer, node_data)

  input_tensors = []
  output_tensors = []

  input_layers = tf_utils.convert_inner_node_data(
      config['input_layers'], wrap=True)
  for layer_data in nest.flatten(input_layers):
    layer_name, node_index, tensor_index = layer_data.as_list()
    assert layer_name in created_layers
    layer = created_layers[layer_name]
    node_index = get_node_index(layer, node_index)
    layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
    input_tensors.append(nest.flatten(layer_output_tensors)[tensor_index])

  output_layers = tf_utils.convert_inner_node_data(
      config['output_layers'], wrap=True)
  for layer_data in nest.flatten(output_layers):
    layer_name, node_index, tensor_index = layer_data.as_list()
    assert layer_name in created_layers
    layer = created_layers[layer_name]
    node_index = get_node_index(layer, node_index)
    layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
    output_tensors.append(nest.flatten(layer_output_tensors)[tensor_index])

  input_tensors = nest.pack_sequence_as(input_layers, input_tensors)
  output_tensors = nest.pack_sequence_as(output_layers, output_tensors)
  return input_tensors, output_tensors, created_layers


def get_network_config(network, serialize_layer_fn=None):
  """Builds the config, which consists of the node graph and serialized layers.

  Args:
    network: A Network object.
    serialize_layer_fn: Function used to serialize layers.

  Returns:
    Config dictionary.
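
  Example of the returned structure (illustrative sketch):

  ```
  {
      'name': 'my_network',   # hypothetical model name
      'layers': [...],        # one serialized config per layer
      'input_layers': [...],  # nested [layer_name, node_index, tensor_index]
      'output_layers': [...]  # same coordinate format as 'input_layers'
  }
  ```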
  """
  serialize_layer_fn = (
      serialize_layer_fn or generic_utils.serialize_keras_object)
  config = {
      'name': network.name,
  }
  node_conversion_map = {}
  for layer in network.layers:
    kept_nodes = 1 if _should_skip_first_node(layer) else 0
    for original_node_index, node in enumerate(layer._inbound_nodes):
      node_key = _make_node_key(layer.name, original_node_index)
      if node_key in network._network_nodes:
        node_conversion_map[node_key] = kept_nodes
        kept_nodes += 1
  layer_configs = []
  for layer in network.layers:  # From the earliest layers on.
    filtered_inbound_nodes = []
    for original_node_index, node in enumerate(layer._inbound_nodes):
      node_key = _make_node_key(layer.name, original_node_index)
      if node_key in network._network_nodes:
        # The node is relevant to the model:
        # add to filtered_inbound_nodes.
        if node.arguments:
          kwargs = _serialize_tensors(node.arguments)
          try:
            json.dumps(kwargs)
          except TypeError:
            logging.warning(
                'Layer ' + layer.name +
                ' was passed non-serializable keyword arguments: ' +
                str(node.arguments) + '. They will not be included '
                'in the serialized model (and thus will be missing '
                'at deserialization time).')
            kwargs = {}
        else:
          kwargs = {}
        if node.inbound_layers:
          node_data = []
          for inbound_layer, node_id, tensor_id, _ in node.iterate_inbound():
            node_key = _make_node_key(inbound_layer.name, node_id)
            new_node_index = node_conversion_map.get(node_key, 0)
            node_data.append(
                tf_utils.ListWrapper(
                    [inbound_layer.name, new_node_index, tensor_id, kwargs]))
          node_data = nest.pack_sequence_as(node.input_tensors, node_data)
          if not nest.is_sequence(node_data):
            node_data = [node_data]
          # Convert ListWrapper to list for backwards compatible configs.
          node_data = tf_utils.convert_inner_node_data(node_data)
          filtered_inbound_nodes.append(node_data)

    layer_config = serialize_layer_fn(layer)
    layer_config['name'] = layer.name
    layer_config['inbound_nodes'] = filtered_inbound_nodes
    layer_configs.append(layer_config)
  config['layers'] = layer_configs

  # Gather info about inputs and outputs.
  model_inputs = []
  for i in range(len(network._input_layers)):
    layer, node_index, tensor_index = network._input_coordinates[i]
    node_key = _make_node_key(layer.name, node_index)
    if node_key not in network._network_nodes:
      continue
    new_node_index = node_conversion_map[node_key]
    model_inputs.append(
        tf_utils.ListWrapper([layer.name, new_node_index, tensor_index]))
  model_inputs = nest.pack_sequence_as(network._nested_inputs, model_inputs)
  # Preserve external Keras compat for Models with a single input.
  if not nest.is_sequence(model_inputs):
    model_inputs = [model_inputs]
  model_inputs = tf_utils.convert_inner_node_data(model_inputs)
  config['input_layers'] = model_inputs

  model_outputs = []
  for i in range(len(network._output_layers)):
    layer, node_index, tensor_index = network._output_coordinates[i]
    node_key = _make_node_key(layer.name, node_index)
    if node_key not in network._network_nodes:
      continue
    new_node_index = node_conversion_map[node_key]
    model_outputs.append(
        tf_utils.ListWrapper([layer.name, new_node_index, tensor_index]))
  model_outputs = nest.pack_sequence_as(network._nested_outputs, model_outputs)
  # Preserve external Keras compat for Models with a single output.
  if not nest.is_sequence(model_outputs):
    model_outputs = [model_outputs]
  model_outputs = tf_utils.convert_inner_node_data(model_outputs)
  config['output_layers'] = model_outputs
  return config