# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""A `Network` is a way to compose layers: the topological form of a `Model`.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import itertools
import warnings

from six.moves import zip  # pylint: disable=redefined-builtin

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_layer as input_layer_module
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.engine import keras_tensor
from tensorflow.python.keras.engine import node as node_module
from tensorflow.python.keras.engine import training as training_lib
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.saving.saved_model import network_serialization
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.tools.docs import doc_controls


# pylint: disable=g-classes-have-attributes
class Functional(training_lib.Model):
  """A `Functional` model is a `Model` defined as a directed graph of layers.

  Three types of `Model` exist: subclassed `Model`, `Functional` model,
  and `Sequential` (a special case of `Functional`).
  In general, more Keras features are supported with `Functional`
  than with subclassed `Model`s, specifically:

  - Model cloning (`keras.models.clone`)
  - Serialization (`model.get_config()/from_config`,
    `model.to_json()/to_yaml()`)
  - Whole-model saving (`model.save()`)

  A `Functional` model can be instantiated by passing two arguments to
  `__init__`. The first argument is the `keras.Input` Tensors that represent
  the inputs to the model. The second argument specifies the output
  tensors that represent the outputs of this model. Both arguments can be a
  nested structure of tensors.

  Example:

  ```
  inputs = {'x1': keras.Input(shape=(10,)), 'x2': keras.Input(shape=(1,))}
  t = keras.layers.Dense(1, activation='relu')(inputs['x1'])
  outputs = keras.layers.Add()([t, inputs['x2']])
  model = keras.Model(inputs, outputs)
  ```

  A `Functional` model constructed using the Functional API can also include
  raw TensorFlow functions, with the exception of functions that create
  Variables or assign ops.

  Example:

  ```
  inputs = keras.Input(shape=(10,))
  x = keras.layers.Dense(1)(inputs)
  outputs = tf.nn.relu(x)
  model = keras.Model(inputs, outputs)
  ```

  Args:
    inputs: List of input tensors (must be created via `tf.keras.Input()`).
    outputs: List of output tensors.
    name: String, optional. Name of the model.
    trainable: Boolean, optional. If the model's variables should be trainable.
  """

  # See tf.Module for the usage of this property.
  # The key of _layer_call_argspecs is a layer. tf.Module._flatten will fail to
  # flatten the key since it is trying to convert Trackable/Layer to a string.
  _TF_MODULE_IGNORED_PROPERTIES = frozenset(itertools.chain(
      ('_layer_call_argspecs', '_compiled_trainable_state',
       '_output_mask_cache', '_output_tensor_cache', '_output_shape_cache'),
      training_lib.Model._TF_MODULE_IGNORED_PROPERTIES
  ))

  @trackable.no_automatic_dependency_tracking
  def __init__(self, inputs, outputs, name=None, trainable=True,
               **kwargs):
    # This is used by the Model class, since we have some logic to swap the
    # class in the __new__ method, which leads to __init__ being invoked
    # twice. Use skip_init to skip one of the invocations of __init__ to
    # avoid any side effects.
    skip_init = kwargs.pop('skip_init', False)
    if skip_init:
      return
    generic_utils.validate_kwargs(kwargs, {})
    super(Functional, self).__init__(name=name, trainable=trainable)
    self._init_graph_network(inputs, outputs)

  @trackable.no_automatic_dependency_tracking
  def _init_graph_network(self, inputs, outputs):
    base_layer.keras_api_gauge.get_cell('Functional').set(True)
    # This method is needed for Sequential to reinitialize the graph network
    # when a layer is added or removed.
    self._is_graph_network = True

    # Normalize and set self.inputs, self.outputs.
    if isinstance(inputs, list) and len(nest.flatten(inputs)) == 1:
      inputs = inputs[0]
    if isinstance(outputs, list) and len(nest.flatten(outputs)) == 1:
      outputs = outputs[0]
    self._nested_inputs = inputs
    self._nested_outputs = outputs
    self.inputs = nest.flatten(inputs)
    self.outputs = nest.flatten(outputs)

    # Models constructed with a single Tensor or list of Tensors can
    # be called with a dict, where the keys of the dict are the names
    # of the `Input` objects. Extra keys are ignored with a warning.
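    # Illustrative example (an assumption for clarity, not executed here):
    # a model built with
    #   inputs = [keras.Input(shape=(1,), name='a'),
    #             keras.Input(shape=(2,), name='b')]
    # may later be called as `model({'a': t1, 'b': t2})`; each dict key is
    # matched to the `Input` with the same name.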
    if not nest.is_nested(self._nested_inputs):
      self._enable_dict_to_input_mapping = True
    elif (isinstance(self._nested_inputs, (list, tuple)) and
          not any(nest.is_nested(t) for t in self._nested_inputs)):
      self._enable_dict_to_input_mapping = True
    elif (isinstance(self._nested_inputs, dict) and
          not any(nest.is_nested(t) for t in self._nested_inputs.values())):
      self._enable_dict_to_input_mapping = True
    else:
      self._enable_dict_to_input_mapping = False

    if not keras_tensor.keras_tensors_enabled():
      if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
        base_layer_utils.create_keras_history(self._nested_outputs)

    self._validate_graph_inputs_and_outputs()

    # A Network does not create weights of its own, thus it is already built.
    self.built = True
    self._build_input_shape = nest.map_structure(lambda x: x.shape, inputs)
    self._compute_output_and_mask_jointly = True
    # `_expects_training_arg` is True since the `training` argument is always
    # present in the signature of the `call` method of a graph network.
    self._expects_training_arg = True
    self._expects_mask_arg = True
    # A graph network does not autocast inputs, as its layers will cast them
    # instead.
    self._autocast = False

    self._input_layers = []
    self._output_layers = []
    self._input_coordinates = []
    self._output_coordinates = []

    # This is for performance optimization when calling the Network on new
    # inputs. Every time the Network is called on a set of input tensors,
    # we compute the output tensors, output masks and output shapes in one
    # pass, then cache them here. When any of these outputs is queried later,
    # we retrieve it from here instead of recomputing it.
    self._output_mask_cache = {}
    self._output_tensor_cache = {}
    self._output_shape_cache = {}

    # Build self._output_layers:
    for x in self.outputs:
      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
      self._output_layers.append(layer)
      self._output_coordinates.append((layer, node_index, tensor_index))

    # Build self._input_layers:
    for x in self.inputs:
      layer, node_index, tensor_index = x._keras_history  # pylint: disable=protected-access
      # It's supposed to be an input layer, so it has only one node
      # and one output tensor.
      assert node_index == 0
      assert tensor_index == 0
      self._input_layers.append(layer)
      self._input_coordinates.append((layer, node_index, tensor_index))

    # Keep track of the network's nodes and layers.
    nodes, nodes_by_depth, layers, _ = _map_graph_network(
        self.inputs, self.outputs)
    self._network_nodes = nodes
    self._nodes_by_depth = nodes_by_depth
    self._self_tracked_trackables = layers
    self._layer_call_argspecs = {}
    for layer in self._self_tracked_trackables:
      self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call)

    # Build self.input_names and self.output_names.
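    # For example (illustrative): a model with input layers named 'x1' and
    # 'x2' ends up with `self.input_names == ['x1', 'x2']`; output names are
    # uniquified by `_set_output_names` below.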
    self._set_output_names()
    self.input_names = []
    self._feed_input_names = []
    self._feed_inputs = []
    self._feed_input_shapes = []
    for layer in self._input_layers:
      self.input_names.append(layer.name)
      if layer.is_placeholder:
        self._feed_input_names.append(layer.name)
        # Use batch_input_shape here because non-eager composite tensors may
        # not have a shape attribute that's meaningful (sparse, for instance,
        # has a tensor that's non-constant and needs to be fed). This means
        # that input layers that create placeholders will need to have the
        # batch_input_shape attr to allow for input shape validation.
        self._feed_input_shapes.append(layer._batch_input_shape)
        self._feed_inputs.append(layer.input)

    self._compute_tensor_usage_count()
    self._set_save_spec(self._nested_inputs)
    tf_utils.assert_no_legacy_layers(self.layers)

  @property
  def input(self):
    """Retrieves the input tensor(s) of a layer.

    Only applicable if the layer has exactly one input,
    i.e. if it is connected to one incoming layer.

    Returns:
      Input tensor or list of input tensors.

    Raises:
      RuntimeError: If called in Eager mode.
      AttributeError: If no inbound nodes are found.
    """
    return self._nested_inputs

  @property
  def input_shape(self):
    """Retrieves the input shape(s) of a layer.

    Only applicable if the layer has exactly one input,
    i.e. if it is connected to one incoming layer, or if all inputs
    have the same shape.

    Returns:
      Input shape, as an integer shape tuple
      (or list of shape tuples, one tuple per input tensor).

    Raises:
      AttributeError: if the layer has no defined input_shape.
      RuntimeError: if called in Eager mode.
    """
    return nest.map_structure(backend.int_shape, self.input)

  @property
  def input_spec(self):
    if hasattr(self, '_manual_input_spec'):
      return self._manual_input_spec
    if (isinstance(self._nested_inputs, (dict, list, tuple)) and
        len(self._nested_inputs) != len(self.inputs)):
      # Case where we have a nested structure.
      # In such a case we can't safely run any checks.
      return None
    if isinstance(self._nested_inputs, dict):
      # Case where `_nested_inputs` is a plain dict of Inputs.
      names = sorted(self._nested_inputs.keys())
      return [input_spec.InputSpec(
          shape=shape_with_no_batch_size(self._nested_inputs[name]),
          allow_last_axis_squeeze=True, name=name) for name in names]
    else:
      # Single input, or list / tuple of inputs.
      # The data may be passed as a dict keyed by input name.
      return [input_spec.InputSpec(
          shape=shape_with_no_batch_size(x), allow_last_axis_squeeze=True,
          name=x._keras_history.layer.name) for x in self.inputs]

  @input_spec.setter
  def input_spec(self, value):
    self._manual_input_spec = value

  @property
  def output(self):
    """Retrieves the output tensor(s) of a layer.

    Only applicable if the layer has exactly one output,
    i.e. if it is connected to one incoming layer.

    Returns:
      Output tensor or list of output tensors.

    Raises:
      AttributeError: if the layer is connected to more than one incoming
        layer.
      RuntimeError: if called in Eager mode.
    """
    return self._nested_outputs

  @property
  def output_shape(self):
    """Retrieves the output shape(s) of a layer.

    Only applicable if the layer has one output,
    or if all outputs have the same shape.

    Returns:
      Output shape, as an integer shape tuple
      (or list of shape tuples, one tuple per output tensor).

    Raises:
      AttributeError: if the layer has no defined output shape.
      RuntimeError: if called in Eager mode.
    """
    return nest.map_structure(backend.int_shape, self.output)

  def _set_output_names(self):
    """Assigns unique names to the Network's outputs.

    Output layers with multiple output tensors would otherwise lead to
    duplicate names in self.output_names.
    """
    uniquified = []
    output_names = set()
    prefix_count = {}
    for layer in self._output_layers:
      proposal = layer.name
      while proposal in output_names:
        existing_count = prefix_count.get(layer.name, 1)
        proposal = '{}_{}'.format(layer.name, existing_count)
        prefix_count[layer.name] = existing_count + 1
      output_names.add(proposal)
      uniquified.append(proposal)
    self.output_names = uniquified

  @property
  def _layer_checkpoint_dependencies(self):
    """Dictionary of layer dependencies to be included in the checkpoint."""
    weight_layer_index = 0

    dependencies = collections.OrderedDict()
    for layer_index, layer in enumerate(self.layers):
      try:
        if layer.weights:
          # Keep a separate index for layers which have weights. This allows
          # users to insert Layers without weights anywhere in the network
          # without breaking checkpoints.
          dependencies['layer_with_weights-%d' % weight_layer_index] = layer
          weight_layer_index += 1
      except ValueError:
        # The layer might have weights, but may not be built yet. We just
        # treat it as a layer without weights.
        pass

      # Even if it doesn't have weights, we should still track everything in
      # case it has/will have Trackable dependencies.
      dependencies['layer-%d' % layer_index] = layer
    return dependencies

  @property
  def _checkpoint_dependencies(self):
    dependencies = [
        trackable.TrackableReference(name=name, ref=layer)
        for name, layer in self._layer_checkpoint_dependencies.items()]
    dependencies.extend(super(Functional, self)._checkpoint_dependencies)
    return dependencies

  def _lookup_dependency(self, name):
    layer_dependencies = self._layer_checkpoint_dependencies
    if name in layer_dependencies:
      return layer_dependencies[name]
    return super(Functional, self)._lookup_dependency(name)

  def _handle_deferred_layer_dependencies(self, layers):
    """Handles layer checkpoint dependencies that are added after init."""
    layer_checkpoint_dependencies = self._layer_checkpoint_dependencies
    layer_to_name = {v: k for k, v in layer_checkpoint_dependencies.items()}
    for layer in layers:
      if layer in layer_to_name:
        self._handle_deferred_dependencies(name=layer_to_name[layer],
                                           trackable=layer)

  @property
  def _should_compute_mask(self):
    return True

  def compute_mask(self, inputs, mask):
    # TODO(omalleyt): b/123540974 This function is not really safe to call
    # by itself because it will duplicate any updates and losses in graph
    # mode by `call`ing the Layers again.
    output_tensors = self._run_internal_graph(inputs, mask=mask)
    return nest.map_structure(lambda t: getattr(t, '_keras_mask', None),
                              output_tensors)

  @doc_controls.do_not_doc_inheritable
  def call(self, inputs, training=None, mask=None):
    """Calls the model on new inputs.

    In this case `call` just reapplies
    all ops in the graph to the new inputs (e.g.
    build a new computational graph from the provided inputs).

    Args:
      inputs: A tensor or list of tensors.
      training: Boolean or boolean scalar tensor, indicating whether to run
        the `Network` in training mode or inference mode.
      mask: A mask or list of masks. A mask can be
        either a tensor or None (no mask).

    Returns:
      A tensor if there is a single output, or
      a list of tensors if there is more than one output.
    """
    return self._run_internal_graph(
        inputs, training=training, mask=mask)

  def compute_output_shape(self, input_shape):
    # Convert any shapes in tuple format to TensorShapes.
    input_shape = tf_utils.convert_shapes(input_shape, to_tuples=False)

    if len(nest.flatten(input_shape)) != len(nest.flatten(self._input_layers)):
      raise ValueError('Invalid input_shape argument ' + str(input_shape) +
                       ': model has ' + str(len(self._input_layers)) +
                       ' tensor inputs.')

    # Use the tuple of TensorShapes as the cache key, since a tuple is
    # hashable and can be used as a dict key.
    try:
      cache_key = tuple(tf_utils.convert_shapes(input_shape, to_tuples=True))
      if cache_key in self._output_shape_cache:
        # Cache hit. Return shapes as TensorShapes.
        return self._output_shape_cache[cache_key]
    except ValueError:
      # In case there is an unknown TensorShape, e.g. for a sparse tensor
      # input, we skip caching since the shape is unknown.
      pass

    layers_to_output_shapes = {}
    for layer, shape in zip(self._input_layers, nest.flatten(input_shape)):
      # It's an input layer: `compute_output_shape` is identity,
      # and there is only one node and one tensor.
      shape_key = layer.name + '_0_0'
      layers_to_output_shapes[shape_key] = shape

    depth_keys = list(self._nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    # Iterate over nodes, by depth level.
    if len(depth_keys) > 1:
      for depth in depth_keys:
        nodes = self._nodes_by_depth[depth]
        for node in nodes:
          layer = node.layer
          if layer in self._input_layers:
            # We've already covered the input layers
            # a few lines above.
            continue
          # Get the input shapes for the first argument of the node.
          layer_input_shapes = []
          layer_inputs = node.call_args[0]
          for layer_input in nest.flatten(layer_inputs):
            kh = layer_input._keras_history
            input_layer_key = kh.layer.name + '_%s_%s' % (kh.node_index,
                                                          kh.tensor_index)
            layer_input_shapes.append(layers_to_output_shapes[input_layer_key])
          layer_input_shapes = nest.pack_sequence_as(layer_inputs,
                                                     layer_input_shapes)
          # Layers expect shapes to be tuples for `compute_output_shape`.
          layer_input_shapes = tf_utils.convert_shapes(
              layer_input_shapes, to_tuples=True)
          layer_output_shapes = layer.compute_output_shape(layer_input_shapes)
          # Convert back to TensorShapes.
          layer_output_shapes = tf_utils.convert_shapes(
              layer_output_shapes, to_tuples=False)

          node_index = layer._inbound_nodes.index(node)  # pylint: disable=protected-access
          for j, shape in enumerate(nest.flatten(layer_output_shapes)):
            shape_key = layer.name + '_%s_%s' % (node_index, j)
            layers_to_output_shapes[shape_key] = shape

    # Read final output shapes from layers_to_output_shapes.
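    # Keys have the form '<layer name>_<node index>_<tensor index>', e.g.
    # 'dense_0_0' (illustrative) for the first output tensor of the first
    # node of a layer named 'dense'.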
    output_shapes = []
    for i in range(len(self._output_layers)):
      layer, node_index, tensor_index = self._output_coordinates[i]
      shape_key = layer.name + '_%s_%s' % (node_index, tensor_index)
      output_shapes.append(layers_to_output_shapes[shape_key])
    output_shapes = nest.pack_sequence_as(self._nested_outputs, output_shapes)
    # Store in cache.
    self._output_shape_cache[cache_key] = output_shapes

    # Return shapes as TensorShapes.
    return output_shapes

  def _init_set_name(self, name, zero_based=True):
    if not name:
      cls_name = self.__class__.__name__
      if self.__class__ == Functional:
        # Hide the Functional class name from the user, since it's not a
        # publicly visible class. Use "Model" instead.
        cls_name = 'Model'
      self._name = backend.unique_object_name(
          generic_utils.to_snake_case(cls_name),
          zero_based=zero_based)
    else:
      self._name = name

  def _run_internal_graph(self, inputs, training=None, mask=None):
    """Computes output tensors for new inputs.

    # Note:
      - Can be run on non-Keras tensors.

    Args:
      inputs: Tensor or nested structure of Tensors.
      training: Boolean learning phase.
      mask: (Optional) Tensor or nested structure of Tensors.

    Returns:
      output_tensors
    """
    inputs = self._flatten_to_reference_inputs(inputs)
    if mask is None:
      masks = [None] * len(inputs)
    else:
      masks = self._flatten_to_reference_inputs(mask)
    for input_t, mask in zip(inputs, masks):
      input_t._keras_mask = mask

    # Dictionary mapping reference tensors to computed tensors.
    tensor_dict = {}
    tensor_usage_count = self._tensor_usage_count
    for x, y in zip(self.inputs, inputs):
      y = self._conform_to_reference_input(y, ref_input=x)
      x_id = str(id(x))
      tensor_dict[x_id] = [y] * tensor_usage_count[x_id]

    nodes_by_depth = self._nodes_by_depth
    depth_keys = list(nodes_by_depth.keys())
    depth_keys.sort(reverse=True)

    for depth in depth_keys:
      nodes = nodes_by_depth[depth]
      for node in nodes:
        if node.is_input:
          continue  # Input tensors already exist.

        if any(t_id not in tensor_dict for t_id in node.flat_input_ids):
          continue  # Node is not computable; try skipping.

        args, kwargs = node.map_arguments(tensor_dict)
        outputs = node.layer(*args, **kwargs)

        # Update tensor_dict.
        for x_id, y in zip(node.flat_output_ids, nest.flatten(outputs)):
          tensor_dict[x_id] = [y] * tensor_usage_count[x_id]

    output_tensors = []
    for x in self.outputs:
      x_id = str(id(x))
      assert x_id in tensor_dict, 'Could not compute output ' + str(x)
      output_tensors.append(tensor_dict[x_id].pop())

    return nest.pack_sequence_as(self._nested_outputs, output_tensors)

  def _flatten_to_reference_inputs(self, tensors):
    """Maps `tensors` to their respective `keras.Input`s."""
    if self._enable_dict_to_input_mapping and isinstance(tensors, dict):
      ref_inputs = self._nested_inputs
      if not nest.is_nested(ref_inputs):
        ref_inputs = [self._nested_inputs]
      if isinstance(ref_inputs, dict):
        # In the case that the graph is constructed with dict input tensors,
        # we use the original dict keys to map to the keys in the input data.
        # Note that model.inputs uses nest.flatten to process the input
        # tensors, which means the dict input tensors are ordered by their
        # keys.
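        # For example (illustrative): inputs built as {'b': ..., 'a': ...}
        # flatten to [inputs['a'], inputs['b']], so the data dict is also
        # read in sorted-key order below.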
        ref_input_names = sorted(ref_inputs.keys())
      else:
        ref_input_names = [inp._keras_history.layer.name for inp in ref_inputs]

      # Raise a warning if the input data contains more keys than the model's
      # input tensors.
      if len(tensors) > len(ref_input_names):
        warnings.warn(
            'Input dict contained keys {} which did not match any model '
            'input. They will be ignored by the model.'.format(
                [n for n in tensors.keys() if n not in ref_input_names]))

      try:
        # Flatten in the order `Input`s were passed during Model construction.
        return [tensors[n] for n in ref_input_names]
      except KeyError:
        # TODO(b/151582614)
        return nest.flatten(tensors)

    # Otherwise both self.inputs and tensors will already be in same order.
    return nest.flatten(tensors)

  def _conform_to_reference_input(self, tensor, ref_input):
    """Sets shape and dtype based on `keras.Input`s."""
    if isinstance(tensor, ops.Tensor):
      # Allow (None,) and (None, 1) Tensors to be passed interchangeably. Use
      # the shape specified by the `keras.Input`.
      t_shape = tensor.shape
      t_rank = t_shape.rank
      ref_shape = ref_input.shape
      ref_rank = ref_shape.rank
      keras_history = getattr(tensor, '_keras_history', None)
      if t_rank is not None and ref_rank is not None:
        # Should squeeze last dimension.
        # True if tensor is (BATCH, ..., 1) and reference is (BATCH, ...).
        if t_rank == ref_rank + 1 and t_shape[-1] == 1:
          tensor = array_ops.squeeze_v2(tensor, axis=-1)
        # Should expand last dimension.
        # True if tensor is (BATCH, ...) and reference is (BATCH, ..., 1).
        elif t_rank == ref_rank - 1 and ref_shape[-1] == 1:
          tensor = array_ops.expand_dims_v2(tensor, axis=-1)
      if keras_history is not None:  # Restore keras history.
        tensor._keras_history = keras_history

      # Add shape hints to Tensors that may have None shape dims but have
      # shapes defined by the `keras.Input` (not applicable in eager mode).
      if not context.executing_eagerly():
        try:
          tensor.set_shape(tensor.shape.merge_with(ref_input.shape))
        except ValueError:
          logging.warning(
              'Model was constructed with shape {} for input {}, but it was '
              'called on an input with incompatible shape {}.'.format(
                  ref_input.shape, ref_input, tensor.shape))

      # Dtype casting.
      tensor = math_ops.cast(tensor, dtype=ref_input.dtype)
    elif tf_utils.is_extension_type(tensor):
      # Dtype casting (if the extension type has a non-variant dtype and
      # supports being cast).
      ref_input_dtype = getattr(ref_input, 'dtype', None)
      if ref_input_dtype is not None and ref_input_dtype != dtypes.variant:
        tensor = math_ops.cast(tensor, dtype=ref_input_dtype)

    return tensor

  def get_config(self):
    return copy.deepcopy(get_network_config(self))

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Instantiates a Model from its config (output of `get_config()`).

    Args:
      config: Model config dictionary.
      custom_objects: Optional dictionary mapping names
        (strings) to custom classes or functions to be
        considered during deserialization.

    Returns:
      A model instance.

    Raises:
      ValueError: In case of improperly formatted config dict.
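
    Example (an illustrative round trip; assumes `model` is a functional
    `keras.Model` and any custom layers are passed via `custom_objects`):

    ```
    config = model.get_config()
    new_model = keras.Model.from_config(config)
    ```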
673 """ 674 with generic_utils.SharedObjectLoadingScope(): 675 input_tensors, output_tensors, created_layers = reconstruct_from_config( 676 config, custom_objects) 677 model = cls(inputs=input_tensors, outputs=output_tensors, 678 name=config.get('name')) 679 connect_ancillary_layers(model, created_layers) 680 return model 681 682 def _validate_graph_inputs_and_outputs(self): 683 """Validates the inputs and outputs of a Graph Network.""" 684 # Check for redundancy in inputs. 685 if len({id(i) for i in self.inputs}) != len(self.inputs): 686 raise ValueError('The list of inputs passed to the model ' 687 'is redundant. ' 688 'All inputs should only appear once.' 689 ' Found: ' + str(self.inputs)) 690 691 for x in self.inputs: 692 # Check that x has appropriate `_keras_history` metadata. 693 if not hasattr(x, '_keras_history'): 694 cls_name = self.__class__.__name__ 695 raise ValueError('Input tensors to a ' + cls_name + ' ' + 696 'must come from `tf.keras.Input`. ' 697 'Received: ' + str(x) + 698 ' (missing previous layer metadata).') 699 # Check that x is an input tensor. 700 # pylint: disable=protected-access 701 layer = x._keras_history.layer 702 if len(layer._inbound_nodes) > 1 or ( 703 layer._inbound_nodes and not layer._inbound_nodes[0].is_input): 704 cls_name = self.__class__.__name__ 705 logging.warning(cls_name + ' model inputs must come from ' 706 '`tf.keras.Input` (thus holding past layer metadata), ' 707 'they cannot be the output of ' 708 'a previous non-Input layer. ' 709 'Here, a tensor specified as ' 710 'input to "' + self.name + '" was not an Input tensor, ' 711 'it was generated by layer ' + layer.name + '.\n' 712 'Note that input tensors are ' 713 'instantiated via `tensor = tf.keras.Input(shape)`.\n' 714 'The tensor that caused the issue was: ' + str(x.name)) 715 716 # Check compatibility of batch sizes of Input Layers. 717 input_batch_sizes = [ 718 training_utils.get_static_batch_size(x._keras_history.layer) 719 for x in self.inputs 720 ] 721 consistent_batch_size = None 722 for batch_size in input_batch_sizes: 723 if batch_size is not None: 724 if (consistent_batch_size is not None and 725 batch_size != consistent_batch_size): 726 raise ValueError('The specified batch sizes of the Input Layers' 727 ' are incompatible. Found batch sizes: {}'.format( 728 input_batch_sizes)) 729 consistent_batch_size = batch_size 730 731 for x in self.outputs: 732 if not hasattr(x, '_keras_history'): 733 cls_name = self.__class__.__name__ 734 raise ValueError('Output tensors of a ' + cls_name + ' model must be ' 735 'the output of a TensorFlow `Layer` ' 736 '(thus holding past layer metadata). Found: ' + str(x)) 737 738 def _insert_layers(self, layers, relevant_nodes=None): 739 """Inserts Layers into the Network after Network creation. 740 741 This is only valid for Keras Graph Networks. Layers added via this function 742 will be included in the `call` computation and `get_config` of this Network. 743 They will not be added to the Network's outputs. 744 745 746 Args: 747 layers: Arbitrary nested structure of Layers. Layers must be reachable 748 from one or more of the `keras.Input` Tensors that correspond to this 749 Network's inputs. 750 relevant_nodes: Nodes from the Layers that should be considered part of 751 this Network. If `None`, all Nodes will be considered part of this 752 Network. 753 754 Raises: 755 ValueError: If the layers depend on `Input`s not found in this Model. 
756 """ 757 layers = nest.flatten(layers) 758 tf_utils.assert_no_legacy_layers(layers) 759 node_to_depth = {} 760 for depth, nodes in self._nodes_by_depth.items(): 761 node_to_depth.update({node: depth for node in nodes}) 762 # The nodes of these Layers that are relevant to this Network. If not 763 # provided, assume all Nodes are relevant 764 if not relevant_nodes: 765 relevant_nodes = nest.flatten([layer._inbound_nodes for layer in layers]) 766 network_nodes = set(relevant_nodes + list(node_to_depth.keys())) 767 768 def _get_min_depth(node): 769 """Gets the minimum depth at which node can be computed.""" 770 min_depth = 0 771 for layer, node_id, _, _ in node.iterate_inbound(): 772 inbound_node = layer._inbound_nodes[node_id] 773 if inbound_node in node_to_depth: 774 min_depth = min(min_depth, node_to_depth[inbound_node]) 775 elif inbound_node not in network_nodes: 776 continue 777 else: 778 # Previous relevant nodes haven't been processed yet. 779 return None 780 # New node is one shallower than its shallowest input. 781 return min_depth - 1 782 783 # Insert nodes into `_nodes_by_depth` and other node attrs. 784 unprocessed_nodes = copy.copy(relevant_nodes) 785 i = 0 786 while unprocessed_nodes: 787 i += 1 788 # Do a sanity check. This can occur if `Input`s from outside this Model 789 # are being relied on. 790 if i > 10000: 791 raise ValueError('Layers could not be added due to missing ' 792 'dependencies.') 793 794 node = unprocessed_nodes.pop(0) 795 depth = _get_min_depth(node) 796 if depth is None: # Defer until inbound nodes are processed. 797 unprocessed_nodes.append(node) 798 continue 799 node_key = _make_node_key(node.layer.name, 800 node.layer._inbound_nodes.index(node)) 801 if node_key not in self._network_nodes: 802 node_to_depth[node] = depth 803 self._network_nodes.add(node_key) 804 self._nodes_by_depth[depth].append(node) 805 806 # Insert layers and update other layer attrs. 807 layer_set = set(self._self_tracked_trackables) 808 deferred_layers = [] 809 for layer in layers: 810 if layer not in layer_set: 811 self._self_tracked_trackables.append(layer) 812 deferred_layers.append(layer) 813 self._layer_call_argspecs[layer] = tf_inspect.getfullargspec(layer.call) 814 layer_set.add(layer) 815 self._handle_deferred_layer_dependencies(deferred_layers) 816 817 self._compute_tensor_usage_count() 818 819 def _compute_tensor_usage_count(self): 820 """Compute the #. of tensor usages for all the output tensors of layers. 821 822 The computed tensor usage count is saved as `self._tensor_usage_count`. This 823 is later used for saving memory in eager computation by releasing 824 no-longer-needed tensors as early as possible. 
825 """ 826 tensor_usage_count = collections.Counter() 827 available_tensors = set(str(id(tensor)) for tensor in self.inputs) 828 829 depth_keys = list(self._nodes_by_depth.keys()) 830 depth_keys.sort(reverse=True) 831 depth_keys = depth_keys[1:] 832 833 for depth in depth_keys: 834 for node in self._nodes_by_depth[depth]: 835 input_tensors = { 836 str(id(tensor)) for tensor in nest.flatten(node.keras_inputs) 837 } 838 if input_tensors.issubset(available_tensors): 839 for tensor in nest.flatten(node.keras_inputs): 840 tensor_usage_count[str(id(tensor))] += 1 841 842 for output_tensor in nest.flatten(node.outputs): 843 available_tensors.add(str(id(output_tensor))) 844 845 for tensor in self.outputs: 846 tensor_usage_count[str(id(tensor))] += 1 847 848 self._tensor_usage_count = tensor_usage_count 849 850 def _assert_weights_created(self): 851 # Override the implementation in Model. 852 # The Functional model should always have weight created already. 853 return 854 855 def _graph_network_add_loss(self, symbolic_loss): 856 new_nodes, new_layers = _map_subgraph_network(self.inputs, [symbolic_loss]) 857 # Losses must be keyed on inputs no matter what in order to be supported in 858 # DistributionStrategy. 859 add_loss_layer = base_layer.AddLoss( 860 unconditional=False, dtype=symbolic_loss.dtype) 861 add_loss_layer(symbolic_loss) 862 new_nodes.extend(add_loss_layer.inbound_nodes) 863 new_layers.append(add_loss_layer) 864 self._insert_layers(new_layers, new_nodes) 865 866 def _graph_network_add_metric(self, value, aggregation, name): 867 new_nodes, new_layers = _map_subgraph_network(self.inputs, [value]) 868 add_metric_layer = base_layer.AddMetric( 869 aggregation, name, dtype=value.dtype) 870 add_metric_layer(value) 871 new_nodes.extend(add_metric_layer.inbound_nodes) 872 new_layers.append(add_metric_layer) 873 self._insert_layers(new_layers, new_nodes) 874 875 @property 876 def _trackable_saved_model_saver(self): 877 return network_serialization.NetworkSavedModelSaver(self) 878 879 def _get_save_spec(self, dynamic_batch=True): 880 if getattr(self, '_has_explicit_input_shape', True): 881 # Functional models and Sequential models that have an explicit input 882 # shape should use the batch size set by the input layer. 883 dynamic_batch = False 884 return super(Functional, self)._get_save_spec(dynamic_batch) 885 886 887def _make_node_key(layer_name, node_index): 888 return layer_name + '_ib-' + str(node_index) 889 890 891def _map_graph_network(inputs, outputs): 892 """Validates a network's topology and gather its layers and nodes. 893 894 Args: 895 inputs: List of input tensors. 896 outputs: List of outputs tensors. 897 898 Returns: 899 A tuple `(nodes, nodes_by_depth, layers, layers_by_depth)`. 900 - nodes: list of Node instances. 901 - nodes_by_depth: dict mapping ints (depth) to lists of node instances. 902 - layers: list of Layer instances. 903 - layers_by_depth: dict mapping ints (depth) to lists of layer instances. 904 905 Raises: 906 ValueError: In case the network is not valid (e.g. disconnected graph). 907 """ 908 # "depth" is number of layers between output Node and the Node. 909 # Nodes are ordered from inputs -> outputs. 
  nodes_in_decreasing_depth, layer_indices = _build_map(outputs)
  network_nodes = {
      _make_node_key(node.layer.name, node.layer._inbound_nodes.index(node))
      for node in nodes_in_decreasing_depth
  }

  nodes_depths = {}  # dict {node: depth value}
  layers_depths = {}  # dict {layer: depth value}

  for node in reversed(nodes_in_decreasing_depth):
    # If the depth is not set, the node has no outbound nodes (depth 0).
    depth = nodes_depths.setdefault(node, 0)

    # Update the depth of the corresponding layer.
    previous_depth = layers_depths.get(node.layer, 0)
    # If we've seen this layer before at a higher depth,
    # we should use that depth instead of the node depth.
    # This is necessary for shared layers that have inputs at different
    # depth levels in the graph.
    depth = max(depth, previous_depth)
    layers_depths[node.layer] = depth
    nodes_depths[node] = depth

    # Update the depth of inbound nodes.
    # The "depth" of a node is the max of the depths
    # of all the nodes it is connected to + 1.
    for node_dep in node.parent_nodes:
      previous_depth = nodes_depths.get(node_dep, 0)
      nodes_depths[node_dep] = max(depth + 1, previous_depth)

  # Handle inputs that are not connected to outputs.
  # We do not error out here because the inputs may be used to compute losses
  # and metrics.
  for input_t in inputs:
    input_layer = input_t._keras_history[0]
    if input_layer not in layers_depths:
      layers_depths[input_layer] = 0
      layer_indices[input_layer] = -1
      nodes_depths[input_layer._inbound_nodes[0]] = 0
      network_nodes.add(_make_node_key(input_layer.name, 0))

  # Build a dict {depth: list of nodes with this depth}.
  nodes_by_depth = collections.defaultdict(list)
  for node, depth in nodes_depths.items():
    nodes_by_depth[depth].append(node)

  # Build a dict {depth: list of layers with this depth}.
  layers_by_depth = collections.defaultdict(list)
  for layer, depth in layers_depths.items():
    layers_by_depth[depth].append(layer)

  # Get sorted list of layer depths.
  depth_keys = list(layers_by_depth.keys())
  depth_keys.sort(reverse=True)

  # Set self.layers ordered by depth.
  layers = []
  for depth in depth_keys:
    layers_for_depth = layers_by_depth[depth]
    # Network.layers needs to have a deterministic order:
    # here we order them by traversal order.
    layers_for_depth.sort(key=lambda x: layer_indices[x])
    layers.extend(layers_for_depth)

  # Get sorted list of node depths.
  depth_keys = list(nodes_by_depth.keys())
  depth_keys.sort(reverse=True)

  # Check that all tensors required are computable.
  # computable_tensors: all tensors in the graph
  # that can be computed from the inputs provided.
  computable_tensors = set()
  for x in inputs:
    computable_tensors.add(id(x))

  layers_with_complete_input = []  # To provide a better error msg.
  for depth in depth_keys:
    for node in nodes_by_depth[depth]:
      layer = node.layer
      if layer and not node.is_input:
        for x in nest.flatten(node.keras_inputs):
          if id(x) not in computable_tensors:
            raise ValueError('Graph disconnected: '
                             'cannot obtain value for tensor ' + str(x) +
                             ' at layer "' + layer.name + '". '
                             'The following previous layers '
                             'were accessed without issue: ' +
                             str(layers_with_complete_input))
        for x in nest.flatten(node.outputs):
          computable_tensors.add(id(x))
        layers_with_complete_input.append(layer.name)

  # Ensure name uniqueness, which will be crucial for serialization
  # (since serialized nodes refer to layers by their name).
  all_names = [layer.name for layer in layers]
  for name in all_names:
    if all_names.count(name) != 1:
      raise ValueError('The name "' + name + '" is used ' +
                       str(all_names.count(name)) + ' times in the model. '
                       'All layer names should be unique.')
  return network_nodes, nodes_by_depth, layers, layers_by_depth


def _build_map(outputs):
  """Topologically sorts nodes in order from inputs to outputs.

  It uses a depth-first search to topologically sort nodes that appear in the
  _keras_history connectivity metadata of `outputs`.

  Args:
    outputs: the output tensors whose _keras_history metadata should be
      walked. This may be an arbitrary nested structure.

  Returns:
    A tuple like (ordered_nodes, layer_to_first_traversal_index)
    ordered_nodes: list of nodes appearing in the keras history,
      topologically sorted from original inputs to the `outputs`.
      (If outputs have different sets of ancestors, the inputs to one output
      may appear after a different output.)
    layer_to_first_traversal_index:
      A dict mapping a layer to its traversal index in the DFS.
      Note: if a layer is shared by several nodes, the dict will only
      store the index corresponding to the *first* time the layer is seen.
  """
  finished_nodes = set()
  nodes_in_progress = set()
  nodes_in_decreasing_depth = []  # nodes from inputs -> outputs.
  layer_indices = {}  # layer -> index in traversal order.
  for output in nest.flatten(outputs):
    _build_map_helper(output, finished_nodes, nodes_in_progress,
                      nodes_in_decreasing_depth, layer_indices)
  return nodes_in_decreasing_depth, layer_indices


def _build_map_helper(tensor, finished_nodes, nodes_in_progress,
                      nodes_in_decreasing_depth, layer_indices):
  """Recursive helper for `_build_map`."""
  layer, node_index, _ = tensor._keras_history  # pylint: disable=protected-access
  node = layer._inbound_nodes[node_index]  # pylint: disable=protected-access

  # Don't repeat work for shared subgraphs.
  if node in finished_nodes:
    return

  # Prevent cycles.
  if node in nodes_in_progress:
    raise ValueError('The tensor ' + str(tensor) + ' at layer "' +
                     layer.name + '" is part of a cycle.')

  # Store the traversal order for layer sorting.
  if layer not in layer_indices:
    layer_indices[layer] = len(layer_indices)

  # Propagate to all previous tensors connected to this node.
  nodes_in_progress.add(node)
  if not node.is_input:
    for tensor in node.keras_inputs:
      _build_map_helper(tensor, finished_nodes, nodes_in_progress,
                        nodes_in_decreasing_depth, layer_indices)

  finished_nodes.add(node)
  nodes_in_progress.remove(node)
  nodes_in_decreasing_depth.append(node)


def _map_subgraph_network(inputs, outputs):
  """Returns the nodes and layers in the topology from `inputs` to `outputs`.

  Args:
    inputs: List of input tensors.
    outputs: List of output tensors.

  Returns:
    A tuple of List[Node] and List[Layer].
1084 """ 1085 if not keras_tensor.keras_tensors_enabled(): 1086 base_layer_utils.create_keras_history(outputs) 1087 # Keep only nodes and layers in the topology between inputs and outputs. 1088 _, nodes_by_depth, layers, _ = _map_graph_network(inputs, outputs) 1089 return nest.flatten([nodes for nodes in nodes_by_depth.values()]), layers 1090 1091 1092def _should_skip_first_node(layer): 1093 """Returns True if the first layer node should not be saved or loaded.""" 1094 # Networks that are constructed with an Input layer/shape start with a 1095 # pre-existing node linking their input to output. This node is excluded from 1096 # the network config. 1097 if layer._self_tracked_trackables: 1098 return (isinstance(layer, Functional) and 1099 # Filter out Sequential models without an input shape. 1100 isinstance(layer._self_tracked_trackables[0], 1101 input_layer_module.InputLayer)) 1102 else: 1103 return isinstance(layer, Functional) 1104 1105 1106def connect_ancillary_layers(model, created_layers): 1107 """Adds layers that are not connected to the outputs to the model.""" 1108 # Layers not connected to outputs, such as those added in `add_loss`. 1109 ancillary_layers = [ 1110 layer for layer in created_layers.values() if layer not in model.layers 1111 ] 1112 if ancillary_layers: 1113 relevant_nodes = nest.flatten([ 1114 layer.inbound_nodes[1:] 1115 if _should_skip_first_node(layer) else layer.inbound_nodes 1116 for layer in created_layers.values() 1117 ]) 1118 model._insert_layers(ancillary_layers, relevant_nodes) 1119 return model 1120 1121 1122def reconstruct_from_config(config, custom_objects=None, created_layers=None): 1123 """Reconstructs graph from config object. 1124 1125 Args: 1126 config: Dictionary returned from Network.get_config() 1127 custom_objects: Optional dictionary mapping names (strings) to custom 1128 classes or functions to be considered during deserialization. 1129 created_layers: Optional dictionary mapping names to Layer objects. Any 1130 layer not in this dictionary will be created and added to the dict. 1131 This function will add new nodes to all layers (excluding InputLayers), 1132 instead of re-using pre-existing nodes in the layers. 1133 1134 Returns: 1135 Tuple of (input tensors, output tensors, dictionary of created layers) 1136 """ 1137 # Layer instances created during the graph reconstruction process. 1138 created_layers = created_layers or collections.OrderedDict() 1139 1140 # Maps input data (tuple of inbound layer name, node index) from the config 1141 # to node indices in the newly generated model. The node indices may be 1142 # different if the layers have already been called previously. 1143 node_index_map = {} 1144 node_count_by_layer = {} 1145 1146 # Dictionary mapping layer instances to 1147 # node data that specifies a layer call. 1148 # It acts as a queue that maintains any unprocessed 1149 # layer call until it becomes possible to process it 1150 # (i.e. until the input tensors to the call all exist). 
  unprocessed_nodes = {}

  def add_unprocessed_node(layer, node_data):
    if layer not in unprocessed_nodes:
      unprocessed_nodes[layer] = [node_data]
    else:
      unprocessed_nodes[layer].append(node_data)

  def get_node_index(layer, config_node_index):
    """Returns node index in layer (might differ from config_node_index)."""
    if isinstance(layer, input_layer_module.InputLayer):
      return 0
    return node_index_map.get((layer.name, config_node_index), None)

  def _deserialize_keras_tensors(kwargs, layer_map):
    """Deserializes Keras Tensors passed to `call`."""

    def _deserialize_keras_tensor(t):
      """Deserializes a single Keras Tensor passed to `call`."""
      if isinstance(t, tf_utils.ListWrapper):
        t = t.as_list()
        layer_name = t[0]
        node_index = t[1]
        tensor_index = t[2]

        layer = layer_map[layer_name]
        new_node_index = get_node_index(layer, node_index)
        if new_node_index is None:
          # The inbound node may not have been processed yet
          # (this can happen e.g. if it depends on a different set
          # of inputs than those that have been processed already).
          # Raise an IndexError so that the current node puts itself
          # back on the unprocessed queue.
          # Caution: This may lead to infinite loops for malformed
          # network configurations (or when there is a bug in
          # the network config loading code).
          raise IndexError
        node = layer._inbound_nodes[new_node_index]
        return nest.flatten(node.outputs)[tensor_index]
      return t

    kwargs = tf_utils.convert_inner_node_data(kwargs, wrap=True)
    return nest.map_structure(_deserialize_keras_tensor, kwargs)

  def process_node(layer, node_data):
    """Deserializes a node.

    Args:
      layer: layer instance.
      node_data: Nested structure of `ListWrapper`.

    Raises:
      ValueError: In case of improperly formatted `node_data`.
    """
    input_tensors = []
    for input_data in nest.flatten(node_data):
      input_data = input_data.as_list()
      inbound_layer_name = input_data[0]
      inbound_node_index = input_data[1]
      inbound_tensor_index = input_data[2]
      if len(input_data) == 3:
        kwargs = {}
      elif len(input_data) == 4:
        kwargs = input_data[3]
        try:
          kwargs = _deserialize_keras_tensors(kwargs, created_layers)
        except IndexError:
          # Happens if keras tensors in kwargs are still unprocessed.
          add_unprocessed_node(layer, node_data)
          return
      else:
        raise ValueError('Improperly formatted model config.')

      if inbound_layer_name != node_module._CONSTANT_VALUE:
        inbound_layer = created_layers[inbound_layer_name]
        inbound_node_index = get_node_index(inbound_layer, inbound_node_index)

        if inbound_node_index is None:
          add_unprocessed_node(layer, node_data)
          return
        inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
        input_tensors.append(
            nest.flatten(inbound_node.outputs)[inbound_tensor_index])
      else:
        # We received a constant with no Keras history attached.
        input_tensors.append(inbound_tensor_index)
    input_tensors = nest.pack_sequence_as(node_data, input_tensors)
    # Call layer on its inputs, thus creating the node
    # and building the layer if needed.
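    # Calling the layer below on symbolic inputs records a new inbound node
    # on the layer; its index is captured in `node_index_map` so later nodes
    # can reference it.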
    if input_tensors is not None:
      if not layer._preserve_input_structure_in_config:
        input_tensors = (
            base_layer_utils.unnest_if_single_tensor(input_tensors))
      output_tensors = layer(input_tensors, **kwargs)

      # Update node index map.
      output_index = nest.flatten(output_tensors)[0]._keras_history.node_index
      node_index_map[(layer.name, node_count_by_layer[layer])] = output_index
      node_count_by_layer[layer] += 1

  def process_layer(layer_data):
    """Deserializes a layer, then calls it on appropriate inputs.

    Args:
      layer_data: layer config dict.

    Raises:
      ValueError: In case of improperly formatted `layer_data` dict.
    """
    layer_name = layer_data['name']

    if layer_name in created_layers:
      layer = created_layers[layer_name]
    else:
      # Instantiate layer.
      from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top

      layer = deserialize_layer(layer_data, custom_objects=custom_objects)
      created_layers[layer_name] = layer

    node_count_by_layer[layer] = int(_should_skip_first_node(layer))

    # Gather layer inputs and convert to `ListWrapper` objects.
    inbound_nodes_data = layer_data['inbound_nodes']
    inbound_nodes_data = tf_utils.convert_inner_node_data(
        inbound_nodes_data, wrap=True)
    for node_data in inbound_nodes_data:
      # We don't process nodes (i.e. make layer calls)
      # on the fly because the inbound node may not yet exist,
      # in the case of a layer shared at different topological depths
      # (e.g. a model such as A(B(A(B(x))))).
      add_unprocessed_node(layer, node_data)

  # First, we create all layers and enqueue nodes to be processed.
  for layer_data in config['layers']:
    process_layer(layer_data)
  # Then we process nodes in order of layer depth.
  # Nodes that cannot yet be processed (if the inbound node
  # does not yet exist) are re-enqueued, and the process
  # is repeated until all nodes are processed.
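  # Note: a node that can never be processed (e.g. due to a malformed config)
  # would make this loop spin forever; see the caution about IndexError in
  # `_deserialize_keras_tensors`.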
  while unprocessed_nodes:
    for layer_data in config['layers']:
      layer = created_layers[layer_data['name']]
      if layer in unprocessed_nodes:
        for node_data in unprocessed_nodes.pop(layer):
          process_node(layer, node_data)

  input_tensors = []
  output_tensors = []

  input_layers = tf_utils.convert_inner_node_data(
      config['input_layers'], wrap=True)
  for layer_data in nest.flatten(input_layers):
    layer_name, node_index, tensor_index = layer_data.as_list()
    assert layer_name in created_layers
    layer = created_layers[layer_name]
    node_index = get_node_index(layer, node_index)
    layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
    input_tensors.append(nest.flatten(layer_output_tensors)[tensor_index])

  output_layers = tf_utils.convert_inner_node_data(
      config['output_layers'], wrap=True)
  for layer_data in nest.flatten(output_layers):
    layer_name, node_index, tensor_index = layer_data.as_list()
    assert layer_name in created_layers
    layer = created_layers[layer_name]
    node_index = get_node_index(layer, node_index)
    layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
    output_tensors.append(nest.flatten(layer_output_tensors)[tensor_index])

  input_tensors = nest.pack_sequence_as(input_layers, input_tensors)
  output_tensors = nest.pack_sequence_as(output_layers, output_tensors)
  return input_tensors, output_tensors, created_layers


def get_network_config(network, serialize_layer_fn=None):
  """Builds the config, which consists of the node graph and serialized layers.

  Args:
    network: A Network object.
    serialize_layer_fn: Function used to serialize layers.

  Returns:
    Config dictionary.
  """
  serialize_layer_fn = (
      serialize_layer_fn or generic_utils.serialize_keras_object)
  config = {
      'name': network.name,
  }
  node_conversion_map = {}
  for layer in network.layers:
    kept_nodes = 1 if _should_skip_first_node(layer) else 0
    for original_node_index, node in enumerate(layer._inbound_nodes):
      node_key = _make_node_key(layer.name, original_node_index)
      if node_key in network._network_nodes:
        node_conversion_map[node_key] = kept_nodes
        kept_nodes += 1
  layer_configs = []

  with generic_utils.SharedObjectSavingScope():
    for layer in network.layers:  # From the earliest layers on.
      filtered_inbound_nodes = []
      for original_node_index, node in enumerate(layer._inbound_nodes):
        node_key = _make_node_key(layer.name, original_node_index)
        if node_key in network._network_nodes and not node.is_input:
          # The node is relevant to the model:
          # add to filtered_inbound_nodes.
          node_data = node.serialize(_make_node_key, node_conversion_map)
          filtered_inbound_nodes.append(node_data)

      layer_config = serialize_layer_fn(layer)
      layer_config['name'] = layer.name
      layer_config['inbound_nodes'] = filtered_inbound_nodes
      layer_configs.append(layer_config)
  config['layers'] = layer_configs

  # Gather info about inputs and outputs.
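  # Each entry below records a [layer name, node index, tensor index] triple,
  # e.g. ['input_1', 0, 0] (illustrative) for the first tensor of the first
  # node of a layer named 'input_1'.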
  model_inputs = []
  for i in range(len(network._input_layers)):
    layer, node_index, tensor_index = network._input_coordinates[i]
    node_key = _make_node_key(layer.name, node_index)
    if node_key not in network._network_nodes:
      continue
    new_node_index = node_conversion_map[node_key]
    model_inputs.append(
        tf_utils.ListWrapper([layer.name, new_node_index, tensor_index]))
  model_inputs = nest.pack_sequence_as(network._nested_inputs, model_inputs)
  # Preserve external Keras compat for Models with single input.
  if not nest.is_nested(model_inputs):
    model_inputs = [model_inputs]
  model_inputs = tf_utils.convert_inner_node_data(model_inputs)
  config['input_layers'] = model_inputs

  model_outputs = []
  for i in range(len(network._output_layers)):
    layer, node_index, tensor_index = network._output_coordinates[i]
    node_key = _make_node_key(layer.name, node_index)
    if node_key not in network._network_nodes:
      continue
    new_node_index = node_conversion_map[node_key]
    model_outputs.append(
        tf_utils.ListWrapper([layer.name, new_node_index, tensor_index]))
  model_outputs = nest.pack_sequence_as(network._nested_outputs, model_outputs)
  # Preserve external Keras compat for Models with single output.
  if not nest.is_nested(model_outputs):
    model_outputs = [model_outputs]
  model_outputs = tf_utils.convert_inner_node_data(model_outputs)
  config['output_layers'] = model_outputs
  return config


def shape_with_no_batch_size(x):
  if x.shape.rank is None:
    return None
  shape = x.shape.as_list()
  if shape:
    shape[0] = None
  return shape


class ModuleWrapper(base_layer.Layer):
  """Wrapper for `tf.Module`s to support the Functional and Sequential API."""

  def __init__(self, module, method_name=None, **kwargs):
    """Initializes the wrapper Layer for this module.

    Args:
      module: The `tf.Module` instance to be wrapped.
      method_name: (Optional) str. The name of the method to use as the
        forward pass of the module. If not set, defaults to '__call__' if
        defined, or 'call'.
      **kwargs: Additional keyword arguments. See `tf.keras.layers.Layer`.

    Raises:
      ValueError: If `method_name` is not defined on `module`.
    """
    super(ModuleWrapper, self).__init__(**kwargs)
    if method_name is None:
      if hasattr(module, '__call__'):
        method_name = '__call__'
      elif hasattr(module, 'call'):
        method_name = 'call'
    if method_name is None or not hasattr(module, method_name):
      raise ValueError('{} is not defined on object {}'.format(
          method_name, module))

    self._module = module
    self._method_name = method_name

    # Check if module.__call__ has a `training` arg or accepts `**kwargs`.
    method = getattr(module, method_name)
    method_arg_spec = tf_inspect.getfullargspec(method)
    self._expects_training_arg = ('training' in method_arg_spec.args or
                                  method_arg_spec.varkw is not None)
    self._expects_mask_arg = ('mask' in method_arg_spec.args or
                              method_arg_spec.varkw is not None)

  def call(self, *args, **kwargs):
    if 'training' in kwargs and not self._expects_training_arg:
      kwargs.pop('training')
    if 'mask' in kwargs and not self._expects_mask_arg:
      kwargs.pop('mask')
    return getattr(self._module, self._method_name)(*args, **kwargs)
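

# Illustrative usage sketch for `ModuleWrapper` (kept as a comment since this
# is library code; the `Scale` module and shapes below are assumptions, and
# `tf`/`keras` refer to the public API names):
#
#   class Scale(tf.Module):
#
#     def __call__(self, x):
#       return 2.0 * x
#
#   inputs = keras.Input(shape=(4,))
#   outputs = ModuleWrapper(Scale())(inputs)
#   model = keras.Model(inputs, outputs)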