# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utilities related to layer/model functionality."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import weakref

import numpy as np
import six

from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.utils.get_source_inputs')
def get_source_inputs(tensor, layer=None, node_index=None):
  """Returns the list of input tensors necessary to compute `tensor`.

  Output will always be a list of tensors
  (potentially with 1 element).

  Args:
    tensor: The tensor to start from.
    layer: Origin layer of the tensor. Will be
      determined via tensor._keras_history if not provided.
    node_index: Origin node index of the tensor.

  Returns:
    List of input tensors.
  """
  if not hasattr(tensor, '_keras_history'):
    return tensor

  if layer is None or node_index:
    layer, node_index, _ = tensor._keras_history
  if not layer._inbound_nodes:
    return [tensor]
  else:
    node = layer._inbound_nodes[node_index]
    if node.is_input:
      # Reached an Input layer, stop recursion.
      return nest.flatten(node.input_tensors)
    else:
      source_tensors = []
      for layer, node_index, _, tensor in node.iterate_inbound():
        previous_sources = get_source_inputs(tensor, layer, node_index)
        # Avoid input redundancy.
        for x in previous_sources:
          if all(x is not t for t in source_tensors):
            source_tensors.append(x)
      return source_tensors


def validate_string_arg(input_data,
                        allowable_strings,
                        layer_name,
                        arg_name,
                        allow_none=False,
                        allow_callables=False):
  """Validates the correctness of a string-based arg."""
  if allow_none and input_data is None:
    return
  elif allow_callables and callable(input_data):
    return
  elif isinstance(input_data,
                  six.string_types) and input_data in allowable_strings:
    return
  else:
    allowed_args = '`None`, ' if allow_none else ''
    allowed_args += 'a `Callable`, ' if allow_callables else ''
    allowed_args += 'or one of the following values: %s' % (allowable_strings,)
    raise ValueError(('The %s argument of layer %s received an invalid '
                      'value %s. Allowed values are: %s.') %
                     (arg_name, layer_name, input_data, allowed_args))
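
# Example usage of the two helpers above (a minimal sketch, not part of this
# module; assumes `tf` is `import tensorflow as tf` and a functional model
# built with the standard tf.keras API):
#
#   inputs = tf.keras.Input(shape=(8,))
#   outputs = tf.keras.layers.Dense(4)(inputs)
#   get_source_inputs(outputs)  # -> [inputs]: walks _keras_history back to
#                               # the Input layer feeding `outputs`.
#
#   # `validate_string_arg` returns silently for allowed values and raises a
#   # ValueError otherwise:
#   validate_string_arg('same', {'valid', 'same'}, 'Conv2D', 'padding')
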

def count_params(weights):
  """Count the total number of scalars composing the weights.

  Args:
    weights: An iterable containing the weights on which to compute params.

  Returns:
    The total number of scalars composing the weights.
  """
  unique_weights = {id(w): w for w in weights}.values()
  weight_shapes = [w.shape.as_list() for w in unique_weights]
  standardized_weight_shapes = [
      [0 if w_i is None else w_i for w_i in w] for w in weight_shapes
  ]
  return int(sum(np.prod(p) for p in standardized_weight_shapes))
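
# Example for `count_params` (a sketch; assumes `tf` is `import tensorflow as
# tf` and the layer has been built so its variables exist):
#
#   dense = tf.keras.layers.Dense(4)
#   dense.build((None, 8))
#   count_params(dense.weights)  # kernel (8 * 4) + bias (4) == 36
#
# Weights appearing multiple times in the iterable (e.g. from shared layers)
# are counted once, since they are deduplicated by `id()` above; unknown
# (`None`) dimensions are treated as 0.
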

def print_summary(model, line_length=None, positions=None, print_fn=None):
  """Prints a summary of a model.

  Args:
    model: Keras model instance.
    line_length: Total length of printed lines
      (e.g. set this to adapt the display to different
      terminal window sizes).
    positions: Relative or absolute positions of log elements in each line.
      If not provided, defaults to `[.33, .55, .67, 1.]`.
    print_fn: Print function to use.
      It will be called on each line of the summary.
      You can set it to a custom function
      in order to capture the string summary.
      It defaults to `print` (prints to stdout).
  """
  if print_fn is None:
    print_fn = print

  if model.__class__.__name__ == 'Sequential':
    sequential_like = True
  elif not model._is_graph_network:
    # We treat subclassed models as a simple sequence of layers, for logging
    # purposes.
    sequential_like = True
  else:
    sequential_like = True
    nodes_by_depth = model._nodes_by_depth.values()
    nodes = []
    for v in nodes_by_depth:
      if (len(v) > 1) or (len(v) == 1 and
                          len(nest.flatten(v[0].keras_inputs)) > 1):
        # If the model has multiple nodes,
        # or if the nodes have multiple inbound_layers,
        # the model is no longer sequential.
        sequential_like = False
        break
      nodes += v
    if sequential_like:
      # Search for shared layers.
      for layer in model.layers:
        flag = False
        for node in layer._inbound_nodes:
          if node in nodes:
            if flag:
              sequential_like = False
              break
            else:
              flag = True
        if not sequential_like:
          break

  if sequential_like:
    line_length = line_length or 65
    positions = positions or [.45, .85, 1.]
    if positions[-1] <= 1:
      positions = [int(line_length * p) for p in positions]
    # Header names for the different log elements.
    to_display = ['Layer (type)', 'Output Shape', 'Param #']
  else:
    line_length = line_length or 98
    positions = positions or [.33, .55, .67, 1.]
    if positions[-1] <= 1:
      positions = [int(line_length * p) for p in positions]
    # Header names for the different log elements.
    to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to']
    relevant_nodes = []
    for v in model._nodes_by_depth.values():
      relevant_nodes += v

  def print_row(fields, positions):
    line = ''
    for i in range(len(fields)):
      if i > 0:
        line = line[:-1] + ' '
      line += str(fields[i])
      line = line[:positions[i]]
      line += ' ' * (positions[i] - len(line))
    print_fn(line)

  print_fn('Model: "{}"'.format(model.name))
  print_fn('_' * line_length)
  print_row(to_display, positions)
  print_fn('=' * line_length)

  def print_layer_summary(layer):
    """Prints a summary for a single layer.

    Args:
      layer: target layer.
    """
    try:
      output_shape = layer.output_shape
    except AttributeError:
      output_shape = 'multiple'
    except RuntimeError:  # output_shape unknown in Eager mode.
      output_shape = '?'
    name = layer.name
    cls_name = layer.__class__.__name__
    if not layer.built and not getattr(layer, '_is_graph_network', False):
      # If a subclassed model has a layer that is not called in Model.call, the
      # layer will not be built and we cannot call layer.count_params().
      params = '0 (unused)'
    else:
      params = layer.count_params()
    fields = [name + ' (' + cls_name + ')', output_shape, params]
    print_row(fields, positions)

  def print_layer_summary_with_connections(layer):
    """Prints a summary for a single layer (including topological connections).

    Args:
      layer: target layer.
    """
    try:
      output_shape = layer.output_shape
    except AttributeError:
      output_shape = 'multiple'
    connections = []
    for node in layer._inbound_nodes:
      if relevant_nodes and node not in relevant_nodes:
        # Node is not part of the current network.
        continue

      for inbound_layer, node_index, tensor_index, _ in node.iterate_inbound():
        connections.append('{}[{}][{}]'.format(inbound_layer.name, node_index,
                                               tensor_index))

    name = layer.name
    cls_name = layer.__class__.__name__
    if not connections:
      first_connection = ''
    else:
      first_connection = connections[0]
    fields = [
        name + ' (' + cls_name + ')', output_shape,
        layer.count_params(), first_connection
    ]
    print_row(fields, positions)
    if len(connections) > 1:
      for i in range(1, len(connections)):
        fields = ['', '', '', connections[i]]
        print_row(fields, positions)

  layers = model.layers
  for i in range(len(layers)):
    if sequential_like:
      print_layer_summary(layers[i])
    else:
      print_layer_summary_with_connections(layers[i])
    if i == len(layers) - 1:
      print_fn('=' * line_length)
    else:
      print_fn('_' * line_length)

  if hasattr(model, '_collected_trainable_weights'):
    trainable_count = count_params(model._collected_trainable_weights)
  else:
    trainable_count = count_params(model.trainable_weights)

  non_trainable_count = count_params(model.non_trainable_weights)

  print_fn('Total params: {:,}'.format(trainable_count + non_trainable_count))
  print_fn('Trainable params: {:,}'.format(trainable_count))
  print_fn('Non-trainable params: {:,}'.format(non_trainable_count))
  print_fn('_' * line_length)
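
# Example for `print_summary` (a sketch; `model` is a hypothetical built Keras
# model). Passing a custom `print_fn` captures the summary as a string instead
# of writing to stdout:
#
#   lines = []
#   print_summary(model, print_fn=lines.append)
#   summary_text = '\n'.join(lines)
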

def convert_dense_weights_data_format(dense,
                                      previous_feature_map_shape,
                                      target_data_format='channels_first'):
  """Utility useful when changing a convnet's `data_format`.

  When porting the weights of a convnet from one data format to the other,
  if the convnet includes a `Flatten` layer
  (applied to the last convolutional feature map)
  followed by a `Dense` layer, the weights of that `Dense` layer
  should be updated to reflect the new dimension ordering.

  Args:
    dense: The target `Dense` layer.
    previous_feature_map_shape: A shape tuple of 3 integers,
      e.g. `(512, 7, 7)`. The shape of the convolutional
      feature map right before the `Flatten` layer that
      came before the target `Dense` layer.
    target_data_format: One of "channels_last", "channels_first".
      Set it to "channels_last"
      if converting a "channels_first" model to "channels_last",
      or reciprocally.
  """
  assert target_data_format in {'channels_last', 'channels_first'}
  kernel, bias = dense.get_weights()
  for i in range(kernel.shape[1]):
    if target_data_format == 'channels_first':
      c, h, w = previous_feature_map_shape
      original_fm_shape = (h, w, c)
      ki = kernel[:, i].reshape(original_fm_shape)
      ki = np.transpose(ki, (2, 0, 1))  # last -> first
    else:
      h, w, c = previous_feature_map_shape
      original_fm_shape = (c, h, w)
      ki = kernel[:, i].reshape(original_fm_shape)
      ki = np.transpose(ki, (1, 2, 0))  # first -> last
    kernel[:, i] = np.reshape(ki, (np.prod(previous_feature_map_shape),))
  dense.set_weights([kernel, bias])
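
# Example for `convert_dense_weights_data_format` (a sketch; `dense_layer` is
# a hypothetical built Dense layer that followed a Flatten). Note that, per
# the tuple unpacking above, `previous_feature_map_shape` is interpreted in
# the *target* data format. Porting a channels_last convnet to channels_first,
# where the feature map ahead of the Flatten is 512 channels of 7x7:
#
#   convert_dense_weights_data_format(
#       dense_layer, (512, 7, 7), target_data_format='channels_first')
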

def is_builtin_layer(layer):
  if not getattr(layer, '_keras_api_names', None):
    return False

  # Subclasses of `Layer` that are not exported inherit the export name
  # of the base layer class.
  return (layer._keras_api_names != ('keras.layers.Layer',) and
          layer._keras_api_names_v1 != ('keras.layers.Layer',))


def cached_per_instance(f):
  """Lightweight decorator for caching lazily constructed properties.

  When to use:
  This decorator provides simple caching with minimal overhead. It is designed
  for properties which are expensive to compute and static over the life of a
  class instance, and provides no mechanism for cache invalidation. Thus it is
  best suited for lazily exposing derived properties of other static data.

  For classes with custom getattr / setattr behavior (such as trackable
  objects), storing cache results as object attributes is not performant.
  Instead, a specialized cache can significantly reduce property lookup
  overhead. (While still allowing the decorated property to be lazily computed.)
  Consider the following class:

  ```
  class MyClass(object):
    def __setattr__(self, key, value):
      # Some expensive class specific code
      # ...
      # ...

      super(MyClass, self).__setattr__(key, value)

    @property
    def thing(self):
      # `thing` is expensive to compute (and may not even be requested), so we
      # want to lazily compute it and then cache it.
      output = getattr(self, '_thing', None)
      if output is None:
        self._thing = output = compute_thing(self)
      return output
  ```

  It's also worth noting that ANY overriding of __setattr__, even something as
  simple as:
  ```
  def __setattr__(self, key, value):
    super(MyClass, self).__setattr__(key, value)
  ```

  slows down attribute assignment by nearly 10x.

  By contrast, replacing the definition of `thing` with the following sidesteps
  the expensive __setattr__ altogether:

  ```
  @property
  @tracking.cached_per_instance
  def thing(self):
    # `thing` is expensive to compute (and may not even be requested), so we
    # want to lazily compute it and then cache it.
    return compute_thing(self)
  ```

  Performance:
  The overhead for this decorator is ~0.4 us / call. A much lower overhead
  implementation (~0.085 us / call) can be achieved by using a custom dict type:

  ```
  def dict_based_cache(f):
    class Cache(dict):
      __slots__ = ()
      def __missing__(self, key):
        self[key] = output = f(key)
        return output

    return property(Cache().__getitem__)
  ```

  However, that implementation holds class instances as keys, and as a result
  blocks garbage collection. (And modifying it to use weakrefs as keys raises
  the lookup overhead to ~0.4 us.) As a result, the WeakKeyDictionary
  implementation below turns out to be more prudent.

  Args:
    f: The function to cache.

  Returns:
    f decorated with simple caching behavior.
  """

  cache = weakref.WeakKeyDictionary()

  @functools.wraps(f)
  def wrapped(item):
    output = cache.get(item)
    if output is None:
      cache[item] = output = f(item)
    return output

  wrapped.cache = cache
  return wrapped


def filter_empty_layer_containers(layer_list):
  """Filter out empty Layer-like containers and uniquify."""
  # TODO(b/130381733): Make this an attribute in base_layer.Layer.
  existing = set()
  to_visit = layer_list[::-1]
  while to_visit:
    obj = to_visit.pop()
    if id(obj) in existing:
      continue
    existing.add(id(obj))
    if hasattr(obj, '_is_layer') and not isinstance(obj, type):
      yield obj
    else:
      sub_layers = getattr(obj, 'layers', None) or []

      # Trackable data structures will not show up in ".layers" lists, but
      # the layers they contain will.
      to_visit.extend(sub_layers[::-1])
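
# Example for `cached_per_instance` (a sketch; `compute_thing` is the
# hypothetical expensive function from the docstring above):
#
#   class MyClass(object):
#
#     @property
#     @cached_per_instance
#     def thing(self):
#       return compute_thing(self)  # Runs once per instance, then cached.
#
# Because the cache is a WeakKeyDictionary keyed on the instance, entries are
# dropped automatically when the instance is garbage collected; instances must
# therefore be hashable and weakly referenceable. Note that a computed value
# of `None` is treated as a cache miss and recomputed on the next lookup.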