# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utilities related to layer/model functionality.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import six

from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils.conv_utils import convert_kernel
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.utils.get_source_inputs')
def get_source_inputs(tensor, layer=None, node_index=None):
  """Returns the list of input tensors necessary to compute `tensor`.

  Output is a list of tensors (potentially with 1 element). The one
  exception: if `tensor` has no `_keras_history` attribute it is returned
  unchanged.

  Arguments:
      tensor: The tensor to start from.
      layer: Origin layer of the tensor. Will be
          determined via tensor._keras_history if not provided.
      node_index: Origin node index of the tensor. Will be
          determined via tensor._keras_history if not provided.

  Returns:
      List of input tensors, de-duplicated by object identity and in
      first-seen order.
  """
  if not hasattr(tensor, '_keras_history'):
    return tensor

  # Fall back to the tensor's own history unless both origin values were
  # supplied explicitly (a node index of 0 is a valid explicit value).
  if layer is None or node_index is None:
    layer, node_index, _ = tensor._keras_history
  if not layer._inbound_nodes:
    return [tensor]
  node = layer._inbound_nodes[node_index]
  if not node.inbound_layers:
    # Reached an Input layer, stop recursion.
    return nest.flatten(node.input_tensors)

  source_tensors = []
  # Identity-keyed set: tensors are compared with `is`, not `==`, and the
  # membership test stays O(1) instead of rescanning `source_tensors`.
  seen = object_identity.ObjectIdentitySet()
  for inbound_layer, inbound_node_index, _, inbound_tensor in (
      node.iterate_inbound()):
    previous_sources = get_source_inputs(inbound_tensor, inbound_layer,
                                         inbound_node_index)
    # Avoid input redundancy.
    for x in previous_sources:
      if x not in seen:
        seen.add(x)
        source_tensors.append(x)
  return source_tensors


def validate_string_arg(input_data,
                        allowable_strings,
                        layer_name,
                        arg_name,
                        allow_none=False,
                        allow_callables=False):
  """Validates the correctness of a string-based arg.

  Arguments:
    input_data: The value to validate.
    allowable_strings: Container of accepted string values. Membership is
      tested with `in`, so if a plain string is passed the check is a
      substring test.
    layer_name: Name of the layer, used in the error message.
    arg_name: Name of the argument, used in the error message.
    allow_none: Whether `None` is an accepted value.
    allow_callables: Whether any callable is an accepted value.

  Raises:
    ValueError: If `input_data` is not one of the accepted values.
  """
  if allow_none and input_data is None:
    return
  elif allow_callables and callable(input_data):
    return
  elif isinstance(input_data,
                  six.string_types) and input_data in allowable_strings:
    return
  else:
    allowed_args = '`None`, ' if allow_none else ''
    allowed_args += 'a `Callable`, ' if allow_callables else ''
    # Wrap in a 1-tuple: a tuple `allowable_strings` would otherwise be
    # spread across the single `%s` and raise TypeError instead of the
    # intended ValueError below.
    allowed_args += ('or one of the following values: %s' %
                     (allowable_strings,))
    raise ValueError(("%s's %s arg received an invalid value %s. " +
                      'Allowed values are %s.') %
                     (layer_name, arg_name, input_data, allowed_args))


def count_params(weights):
  """Count the total number of scalars composing the weights.

  Each weight is counted once even if it appears several times in
  `weights` (de-duplication is by object identity). Unknown (`None`)
  dimensions contribute 0.

  Arguments:
      weights: An iterable containing the weights on which to compute params

  Returns:
      The total number of scalars composing the weights
  """
  unique_weights = object_identity.ObjectIdentitySet(weights)
  weight_shapes = [w.shape.as_list() for w in unique_weights]
  standardized_weight_shapes = [
      [0 if w_i is None else w_i for w_i in w] for w in weight_shapes
  ]
  # np.prod of an empty shape (a scalar weight) is 1, as desired.
  return int(sum(np.prod(p) for p in standardized_weight_shapes))


def print_summary(model, line_length=None, positions=None, print_fn=None):
  """Prints a summary of a model.

  Arguments:
      model: Keras model instance.
      line_length: Total length of printed lines
          (e.g. set this to adapt the display to different
          terminal window sizes). Defaults to 65 for sequential-like models
          and 98 otherwise.
      positions: Relative or absolute positions of log elements in each line.
          If not provided, defaults to `[.45, .85, 1.]` for sequential-like
          models and `[.33, .55, .67, 1.]` otherwise. Values are treated as
          relative when the last one is <= 1.
      print_fn: Print function to use.
          It will be called on each line of the summary.
          You can set it to a custom function
          in order to capture the string summary.
          It defaults to `print` (prints to stdout).
  """
  if print_fn is None:
    print_fn = print

  if model.__class__.__name__ == 'Sequential':
    sequential_like = True
  elif not model._is_graph_network:
    # We treat subclassed models as a simple sequence of layers, for logging
    # purposes.
    sequential_like = True
  else:
    sequential_like = True
    nodes_by_depth = model._nodes_by_depth.values()
    nodes = []
    for v in nodes_by_depth:
      if (len(v) > 1) or (len(v) == 1 and
                          len(nest.flatten(v[0].inbound_layers)) > 1):
        # if the model has multiple nodes
        # or if the nodes have multiple inbound_layers
        # the model is no longer sequential
        sequential_like = False
        break
      nodes += v
    if sequential_like:
      # search for shared layers: a layer with more than one inbound node
      # inside this model makes it non-sequential.
      for layer in model.layers:
        flag = False
        for node in layer._inbound_nodes:
          if node in nodes:
            if flag:
              sequential_like = False
              break
            else:
              flag = True
        if not sequential_like:
          break

  if sequential_like:
    line_length = line_length or 65
    positions = positions or [.45, .85, 1.]
    if positions[-1] <= 1:
      positions = [int(line_length * p) for p in positions]
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #']
  else:
    line_length = line_length or 98
    positions = positions or [.33, .55, .67, 1.]
    if positions[-1] <= 1:
      positions = [int(line_length * p) for p in positions]
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to']
    relevant_nodes = []
    for v in model._nodes_by_depth.values():
      relevant_nodes += v

  def print_row(fields, positions):
    """Prints `fields` left-aligned, each truncated/padded to its column end."""
    line = ''
    for i, field in enumerate(fields):
      if i > 0:
        # Guarantee at least one space between adjacent columns.
        line = line[:-1] + ' '
      line += str(field)
      line = line[:positions[i]]
      line += ' ' * (positions[i] - len(line))
    print_fn(line)

  print_fn('Model: "{}"'.format(model.name))
  print_fn('_' * line_length)
  print_row(to_display, positions)
  print_fn('=' * line_length)

  def print_layer_summary(layer):
    """Prints a summary for a single layer.

    Arguments:
        layer: target layer.
    """
    try:
      output_shape = layer.output_shape
    except AttributeError:
      output_shape = 'multiple'
    except RuntimeError:  # output_shape unknown in Eager mode.
      output_shape = '?'
    name = layer.name
    cls_name = layer.__class__.__name__
    fields = [name + ' (' + cls_name + ')', output_shape, layer.count_params()]
    print_row(fields, positions)

  def print_layer_summary_with_connections(layer):
    """Prints a summary for a single layer (including topological connections).

    Arguments:
        layer: target layer.
    """
    try:
      output_shape = layer.output_shape
    except AttributeError:
      output_shape = 'multiple'
    except RuntimeError:  # output_shape unknown in Eager mode.
      output_shape = '?'
    connections = []
    for node in layer._inbound_nodes:
      if relevant_nodes and node not in relevant_nodes:
        # node is not part of the current network
        continue

      for inbound_layer, node_index, tensor_index, _ in node.iterate_inbound():
        connections.append('{}[{}][{}]'.format(inbound_layer.name, node_index,
                                               tensor_index))

    name = layer.name
    cls_name = layer.__class__.__name__
    first_connection = connections[0] if connections else ''
    fields = [
        name + ' (' + cls_name + ')', output_shape,
        layer.count_params(), first_connection
    ]
    print_row(fields, positions)
    # Remaining connections go on their own rows under "Connected to".
    for connection in connections[1:]:
      print_row(['', '', '', connection], positions)

  layers = model.layers
  for i, layer in enumerate(layers):
    if sequential_like:
      print_layer_summary(layer)
    else:
      print_layer_summary_with_connections(layer)
    if i == len(layers) - 1:
      print_fn('=' * line_length)
    else:
      print_fn('_' * line_length)

  model._check_trainable_weights_consistency()
  if hasattr(model, '_collected_trainable_weights'):
    trainable_count = count_params(model._collected_trainable_weights)
  else:
    trainable_count = count_params(model.trainable_weights)

  non_trainable_count = count_params(model.non_trainable_weights)

  print_fn('Total params: {:,}'.format(trainable_count + non_trainable_count))
  print_fn('Trainable params: {:,}'.format(trainable_count))
  print_fn('Non-trainable params: {:,}'.format(non_trainable_count))
  print_fn('_' * line_length)


def gather_trainable_weights(trainable, sub_layers, extra_variables):
  """Lists the trainable weights for an object with sub-layers.

  Args:
    trainable: Whether the object collecting the variables is trainable.
      If False, an empty list is returned.
    sub_layers: A flat list of Layer objects owned by this object, to collect
      variables from.
    extra_variables: Any extra variables to include. Their `.trainable` property
      is used to categorize them.

  Returns:
    A list of collected trainable weights/variables: the sub-layers'
    trainable weights in order, followed by the trainable extra variables.
  """
  if not trainable:
    return []
  weights = []
  for layer in sub_layers:
    weights += layer.trainable_weights
  trainable_extra_variables = [
      v for v in extra_variables if v.trainable]
  return weights + trainable_extra_variables


def gather_non_trainable_weights(trainable, sub_layers, extra_variables):
  """Lists the non-trainable weights for an object with sub-layers.

  Args:
    trainable: Whether the object collecting the variables is trainable.
      If False, every weight (trainable or not) is reported as non-trainable.
    sub_layers: A flat list of Layer objects owned by this object, to collect
      variables from.
    extra_variables: Any extra variables to include. Their `.trainable` property
      is used to categorize them.

  Returns:
    A list of collected non-trainable weights/variables. When `trainable` is
    False the order is: sub-layer trainable weights, trainable extra
    variables, sub-layer non-trainable weights, non-trainable extra variables.
  """
  trainable_extra_variables = []
  non_trainable_extra_variables = []
  for v in extra_variables:
    if v.trainable:
      trainable_extra_variables.append(v)
    else:
      non_trainable_extra_variables.append(v)
  weights = []
  for layer in sub_layers:
    weights += layer.non_trainable_weights
  if not trainable:
    trainable_weights = []
    for layer in sub_layers:
      trainable_weights += layer.trainable_weights
    return (trainable_weights + trainable_extra_variables
            + weights + non_trainable_extra_variables)
  return weights + non_trainable_extra_variables


@deprecation.deprecated('2020-06-23',
                        'The Theano kernel format is legacy; '
                        'this utility will be removed.')
@keras_export('keras.utils.convert_all_kernels_in_model')
def convert_all_kernels_in_model(model):
  """Converts all convolution kernels in a model from Theano to TensorFlow.

  Also works from TensorFlow to Theano.

  This is used for converting legacy Theano-saved model files.

  Only layers whose class name is exactly one of `Conv1D`, `Conv2D`,
  `Conv3D` or `Conv2DTranspose` are converted; all new kernel values are
  assigned in a single batch at the end.

  Arguments:
      model: target model for the conversion.
  """
  # Note: SeparableConvolution not included
  # since only supported by TF.
  conv_classes = {
      'Conv1D',
      'Conv2D',
      'Conv3D',
      'Conv2DTranspose',
  }
  to_assign = []
  for layer in model.layers:
    if layer.__class__.__name__ in conv_classes:
      original_kernel = K.get_value(layer.kernel)
      converted_kernel = convert_kernel(original_kernel)
      to_assign.append((layer.kernel, converted_kernel))
  K.batch_set_value(to_assign)


def convert_dense_weights_data_format(dense,
                                      previous_feature_map_shape,
                                      target_data_format='channels_first'):
  """Utility useful when changing a convnet's `data_format`.

  When porting the weights of a convnet from one data format to the other,
  if the convnet includes a `Flatten` layer
  (applied to the last convolutional feature map)
  followed by a `Dense` layer, the weights of that `Dense` layer
  should be updated to reflect the new dimension ordering.

  Arguments:
      dense: The target `Dense` layer. Its kernel is expected to have one row
          per element of the flattened feature map. Any weights after the
          kernel (e.g. the bias, if present) are written back unchanged.
      previous_feature_map_shape: A shape tuple of 3 integers,
          e.g. `(512, 7, 7)`. The shape of the convolutional
          feature map right before the `Flatten` layer that
          came before the target `Dense` layer.
      target_data_format: One of "channels_last", "channels_first".
          Set it "channels_last"
          if converting a "channels_first" model to "channels_last",
          or reciprocally.
  """
  assert target_data_format in {'channels_last', 'channels_first'}
  weights = dense.get_weights()
  kernel = weights[0]
  units = kernel.shape[1]
  # Permute the rows of every output column at once: view the kernel as
  # (feature map..., units), move the channel axis, and flatten back.
  if target_data_format == 'channels_first':
    c, h, w = previous_feature_map_shape
    # Rows currently follow a flattened (h, w, c) map.
    kernel = np.transpose(kernel.reshape((h, w, c, units)),
                          (2, 0, 1, 3))  # last -> first
  else:
    h, w, c = previous_feature_map_shape
    # Rows currently follow a flattened (c, h, w) map.
    kernel = np.transpose(kernel.reshape((c, h, w, units)),
                          (1, 2, 0, 3))  # first -> last
  kernel = kernel.reshape((int(np.prod(previous_feature_map_shape)), units))
  dense.set_weights([kernel] + list(weights[1:]))


def is_builtin_layer(layer):
  """Returns whether `layer` is a Keras layer exported under its own name.

  A layer counts as built-in when it has non-empty `_keras_api_names` and
  neither its v2 nor v1 export names are the generic `keras.layers.Layer`.
  """
  if not getattr(layer, '_keras_api_names', None):
    return False

  # Subclasses of `Layer` that are not exported inherit the export name
  # of the base layer class.
  return (layer._keras_api_names != ('keras.layers.Layer',) and
          getattr(layer, '_keras_api_names_v1', None) !=
          ('keras.layers.Layer',))