1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=protected-access 16"""Contains the `Node` class.""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import copy 23import json 24import numpy as np 25 26from tensorflow.python.framework import ops 27from tensorflow.python.framework import tensor_util 28from tensorflow.python.keras import backend 29from tensorflow.python.keras.engine import base_layer_utils 30from tensorflow.python.keras.engine import keras_tensor 31from tensorflow.python.keras.saving.saved_model import json_utils 32from tensorflow.python.keras.utils import tf_utils 33from tensorflow.python.util import nest 34 35_CONSTANT_VALUE = '_CONSTANT_VALUE' 36 37 38class Node(object): 39 """A `Node` describes the connectivity between two layers. 40 41 Each time a layer is connected to some new input, 42 a node is added to `layer._inbound_nodes`. 43 Each time the output of a layer is used by another layer, 44 a node is added to `layer._outbound_nodes`. 45 46 Args: 47 layer: The Layer for the Layer.__call__ this node represents. 48 call_args: The positional arguments the Layer was called with. 49 call_kwargs: The keyword arguments the Layer was called with. 50 outputs: The outputs of the Layer.__call__ 51 """ 52 53 def __init__(self, 54 layer, 55 call_args=None, 56 call_kwargs=None, 57 outputs=None): 58 call_args = [] if call_args is None else call_args 59 call_kwargs = {} if call_kwargs is None else call_kwargs 60 outputs = [] if outputs is None else outputs 61 62 self.layer = layer 63 self.is_input = not call_args and not call_kwargs 64 65 # These arguments are user-provided. Copy the structures here so that 66 # future user modifications do not affect the node's metadata. 67 # We copy using map_structure rather than python's shallow or deep copy, 68 # because the args can be data structures (so shallow copy is 69 # insufficient), but individual values might not support copy.copy 70 # or be too expensive to deep copy. 71 call_args = nest.map_structure(lambda t: t, call_args) 72 call_kwargs = nest.map_structure(lambda t: t, call_kwargs) 73 self.outputs = nest.map_structure(lambda t: t, outputs) 74 self.call_args = call_args 75 self.call_kwargs = call_kwargs 76 77 # Cached for performance. 78 self._flat_arguments = nest.flatten((self.call_args, self.call_kwargs)) 79 # Used to avoid expensive `nest` operations in the most common case. 80 self._single_positional_tensor_passed = (not self.call_kwargs and len( 81 self.call_args) == 1 and tensor_util.is_tf_type(self.call_args[0])) 82 83 if not keras_tensor.keras_tensors_enabled(): 84 # Create TensorFlowOpLayers if needed. 85 for obj in self._flat_arguments: 86 if (isinstance(obj, ops.Tensor) and 87 base_layer_utils.needs_keras_history( 88 obj, ignore_call_context=True)): 89 base_layer_utils.create_keras_history(obj) 90 91 self._keras_inputs = [] 92 self._keras_inputs_ids_and_indices = [] 93 for i, ele in enumerate(self._flat_arguments): 94 if is_keras_tensor(ele): 95 self._keras_inputs.append(ele) 96 kt_id = str(id(ele)) 97 kt_index = i 98 self._keras_inputs_ids_and_indices.append((kt_id, kt_index)) 99 100 # Wire up Node to Layers. 101 self.layer._inbound_nodes.append(self) 102 for kt in self.keras_inputs: 103 inbound_layer = kt._keras_history.layer 104 if inbound_layer is not None: # `None` for `Input` tensors. 105 inbound_layer._outbound_nodes.append(self) 106 107 # Set metadata on outputs. 108 node_index = len(self.layer._inbound_nodes) - 1 109 for i, tensor in enumerate(nest.flatten(outputs)): 110 tensor._keras_history = KerasHistory( 111 layer=layer, node_index=node_index, tensor_index=i) 112 113 # Cached for performance. 114 self.flat_input_ids = [str(id(t)) for t in self._keras_inputs] 115 self.flat_output_ids = [str(id(t)) for t in nest.flatten(self.outputs)] 116 117 @property 118 def keras_inputs(self): 119 """Tensors input to this node that can be traced back to a `keras.Input`.""" 120 return self._keras_inputs 121 122 @property 123 def parent_nodes(self): 124 """Returns all the `Node`s whose output this node immediately depends on.""" 125 node_deps = [] 126 for kt in self.keras_inputs: 127 layer = kt._keras_history.layer 128 node_index = kt._keras_history.node_index 129 if layer is not None: # `None` for `Input` tensors. 130 node_deps.append(layer._inbound_nodes[node_index]) 131 return node_deps 132 133 def iterate_inbound(self): 134 """Yields tuples representing the data inbound from other nodes. 135 136 Yields: 137 tuples like: (inbound_layer, node_index, tensor_index, tensor). 138 """ 139 for kt in self.keras_inputs: 140 keras_history = kt._keras_history 141 layer = keras_history.layer 142 node_index = keras_history.node_index 143 tensor_index = keras_history.tensor_index 144 yield layer, node_index, tensor_index, kt 145 146 def map_arguments(self, tensor_dict): 147 """Maps Keras Tensors to computed Tensors using `tensor_dict`.""" 148 if self._single_positional_tensor_passed: 149 # Performance optimization for most common case. 150 kt_id, _ = self._keras_inputs_ids_and_indices[0] 151 return (tensor_dict[kt_id].pop(),), {} 152 else: 153 flat_arguments = copy.copy(self._flat_arguments) 154 for kt_id, kt_index in self._keras_inputs_ids_and_indices: 155 flat_arguments[kt_index] = tensor_dict[kt_id].pop() 156 157 args, kwargs = nest.pack_sequence_as((self.call_args, self.call_kwargs), 158 flat_arguments) 159 return args, kwargs 160 161 def serialize(self, make_node_key, node_conversion_map): 162 """Serializes `Node` for Functional API's `get_config`.""" 163 # Serialization still special-cases first argument. 164 args, kwargs = self.call_args, self.call_kwargs 165 inputs, args, kwargs = self.layer._split_out_first_arg(args, kwargs) 166 167 # Treat everything other than first argument as a kwarg. 168 arguments = dict(zip(self.layer._call_fn_args[1:], args)) 169 arguments.update(kwargs) 170 kwargs = arguments 171 172 def _serialize_keras_tensor(t): 173 """Serializes a single Tensor passed to `call`.""" 174 if hasattr(t, '_keras_history'): 175 kh = t._keras_history 176 node_index = kh.node_index 177 node_key = make_node_key(kh.layer.name, node_index) 178 new_node_index = node_conversion_map.get(node_key, 0) 179 return [kh.layer.name, new_node_index, kh.tensor_index] 180 181 if isinstance(t, np.ndarray): 182 return t.tolist() 183 184 if isinstance(t, ops.Tensor): 185 return backend.get_value(t).tolist() 186 187 return t 188 189 kwargs = nest.map_structure(_serialize_keras_tensor, kwargs) 190 try: 191 json.dumps(kwargs, default=json_utils.get_json_type) 192 except TypeError: 193 kwarg_types = nest.map_structure(type, kwargs) 194 raise TypeError('Layer ' + self.layer.name + 195 ' was passed non-JSON-serializable arguments. ' + 196 'Arguments had types: ' + 197 str(kwarg_types) + '. They cannot be serialized out ' 198 'when saving the model.') 199 200 # `kwargs` is added to each Tensor in the first arg. This should be 201 # changed in a future version of the serialization format. 202 def serialize_first_arg_tensor(t): 203 if is_keras_tensor(t): 204 kh = t._keras_history 205 node_index = kh.node_index 206 node_key = make_node_key(kh.layer.name, node_index) 207 new_node_index = node_conversion_map.get(node_key, 0) 208 data = [kh.layer.name, new_node_index, kh.tensor_index, kwargs] 209 else: 210 # If an element in the first call argument did not originate as a 211 # keras tensor and is a constant value, we save it using the format 212 # ['_CONSTANT_VALUE', -1, serializaed_tensor_or_python_constant] 213 # (potentially including serialized kwargs in an optional 4th argument 214 data = [_CONSTANT_VALUE, -1, _serialize_keras_tensor(t), kwargs] 215 return tf_utils.ListWrapper(data) 216 217 data = nest.map_structure(serialize_first_arg_tensor, inputs) 218 if (not nest.is_nested(data) and 219 not self.layer._preserve_input_structure_in_config): 220 data = [data] 221 data = tf_utils.convert_inner_node_data(data) 222 return data 223 224 ############################################################# 225 # Properties for Backwards compatibility. 226 # These only check the first input argument 227 # As nodes are internal, they may be removed in the future. 228 ############################################################# 229 230 @property 231 def input_tensors(self): 232 if self.is_input: 233 return [self.outputs] # Used in `Layer.input`. 234 return self.call_args[0] 235 236 @property 237 def output_tensors(self): 238 if self.is_input: 239 return [self.outputs] # Used in `Layer.input`. 240 return self.outputs 241 242 @property 243 def input_shapes(self): 244 input_shapes = nest.map_structure(backend.int_shape, self.input_tensors) 245 if len(input_shapes) == 1 and not self.is_input: 246 return input_shapes[0] 247 return input_shapes 248 249 @property 250 def output_shapes(self): 251 return nest.map_structure(backend.int_shape, self.output_tensors) 252 253 @property 254 def outbound_layer(self): 255 return self.layer 256 257 @property 258 def inbound_layers(self): 259 if self.is_input: 260 return [] 261 inbound_layers = nest.map_structure(lambda t: t._keras_history.layer, 262 self.call_args[0]) 263 return inbound_layers 264 265 266class KerasHistory( 267 collections.namedtuple('KerasHistory', 268 ['layer', 'node_index', 'tensor_index'])): 269 """Tracks the Layer call that created a Tensor, for Keras Graph Networks. 270 271 During construction of Keras Graph Networks, this metadata is added to 272 each Tensor produced as the output of a Layer, starting with an 273 `InputLayer`. This allows Keras to track how each Tensor was produced, and 274 this information is later retraced by the `keras.engine.Network` class to 275 reconstruct the Keras Graph Network. 276 277 Attributes: 278 layer: The Layer that produced the Tensor. 279 node_index: The specific call to the Layer that produced this Tensor. Layers 280 can be called multiple times in order to share weights. A new node is 281 created every time a Layer is called. 282 tensor_index: The output index for this Tensor. Always zero if the Layer 283 that produced this Tensor only has one output. Nested structures of 284 Tensors are deterministically assigned an index via `nest.flatten`. 285 """ 286 # Added to maintain memory and performance characteristics of `namedtuple` 287 # while subclassing. 288 __slots__ = () 289 290 291def is_keras_tensor(obj): 292 return hasattr(obj, '_keras_history') 293