# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utilities related to layer/model functionality.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import six

from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils.conv_utils import convert_kernel
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.utils.get_source_inputs')
def get_source_inputs(tensor, layer=None, node_index=None):
35  """Returns the list of input tensors necessary to compute `tensor`.
36
37  Output will always be a list of tensors
38  (potentially with 1 element).
39
40  Arguments:
41      tensor: The tensor to start from.
42      layer: Origin layer of the tensor. Will be
43          determined via tensor._keras_history if not provided.
44      node_index: Origin node index of the tensor.
45
46  Returns:
47      List of input tensors.
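
  Example (a minimal usage sketch; the model built here is purely
  illustrative):

  ```python
  import tensorflow as tf

  a = tf.keras.Input(shape=(4,))
  b = tf.keras.layers.Dense(2)(a)
  model = tf.keras.Model(a, b)
  # Walking back from an output recovers the original input tensor.
  tf.keras.utils.get_source_inputs(model.outputs[0])  # -> [a]
  ```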
48  """
  if not hasattr(tensor, '_keras_history'):
    return tensor

  # Re-derive the origin from Keras history when either piece is missing;
  # otherwise a provided `layer` with no `node_index` would crash below.
  if layer is None or node_index is None:
    layer, node_index, _ = tensor._keras_history
  if not layer._inbound_nodes:
    return [tensor]
  else:
    node = layer._inbound_nodes[node_index]
    if not node.inbound_layers:
      # Reached an Input layer, stop recursion.
      return nest.flatten(node.input_tensors)
    else:
      source_tensors = []
      for layer, node_index, _, tensor in node.iterate_inbound():
        previous_sources = get_source_inputs(tensor, layer, node_index)
        # Avoid input redundancy.
        for x in previous_sources:
          if all(x is not t for t in source_tensors):
            source_tensors.append(x)
      return source_tensors


def validate_string_arg(input_data,
                        allowable_strings,
                        layer_name,
                        arg_name,
                        allow_none=False,
                        allow_callables=False):
78  """Validates the correctness of a string-based arg."""
  if allow_none and input_data is None:
    return
  elif allow_callables and callable(input_data):
    return
  elif isinstance(input_data,
                  six.string_types) and input_data in allowable_strings:
    return
  else:
    allowed_args = '`None`, ' if allow_none else ''
    allowed_args += 'a `Callable`, ' if allow_callables else ''
    allowed_args += 'or one of the following values: %s' % allowable_strings
    raise ValueError(("%s's %s arg received an invalid value %s. " +
                      'Allowed values are %s.') %
                     (layer_name, arg_name, input_data, allowed_args))


def count_params(weights):
96  """Count the total number of scalars composing the weights.
97
98  Arguments:
99      weights: An iterable containing the weights on which to compute params
100
101  Returns:
102      The total number of scalars composing the weights
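
  Example (a small illustrative sketch):

  ```python
  import tensorflow as tf

  w = tf.Variable(tf.zeros([3, 4]))
  b = tf.Variable(tf.zeros([4]))
  # Duplicates are counted once (deduplicated by object identity).
  count_params([w, b, w])  # -> 16
  ```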
103  """
  unique_weights = object_identity.ObjectIdentitySet(weights)
  weight_shapes = [w.shape.as_list() for w in unique_weights]
  standardized_weight_shapes = [
      [0 if w_i is None else w_i for w_i in w] for w in weight_shapes
  ]
  return int(sum(np.prod(p) for p in standardized_weight_shapes))


def print_summary(model, line_length=None, positions=None, print_fn=None):
113  """Prints a summary of a model.
114
115  Arguments:
116      model: Keras model instance.
117      line_length: Total length of printed lines
118          (e.g. set this to adapt the display to different
119          terminal window sizes).
120      positions: Relative or absolute positions of log elements in each line.
121          If not provided, defaults to `[.33, .55, .67, 1.]`.
122      print_fn: Print function to use.
123          It will be called on each line of the summary.
124          You can set it to a custom function
125          in order to capture the string summary.
126          It defaults to `print` (prints to stdout).
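
  Example (an illustrative sketch for capturing the summary as a string;
  `model` is any built Keras model):

  ```python
  lines = []
  print_summary(model, print_fn=lines.append)
  summary = '\n'.join(lines)
  ```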
127  """
  if print_fn is None:
    print_fn = print

  if model.__class__.__name__ == 'Sequential':
    sequential_like = True
  elif not model._is_graph_network:
    # We treat subclassed models as a simple sequence of layers, for logging
    # purposes.
    sequential_like = True
  else:
    sequential_like = True
    nodes_by_depth = model._nodes_by_depth.values()
    nodes = []
    for v in nodes_by_depth:
      if (len(v) > 1) or (len(v) == 1 and
                          len(nest.flatten(v[0].inbound_layers)) > 1):
        # if the model has multiple nodes
        # or if the nodes have multiple inbound_layers
        # the model is no longer sequential
        sequential_like = False
        break
      nodes += v
    if sequential_like:
      # search for shared layers
      for layer in model.layers:
        flag = False
        for node in layer._inbound_nodes:
          if node in nodes:
            if flag:
              sequential_like = False
              break
            else:
              flag = True
        if not sequential_like:
          break

  if sequential_like:
    line_length = line_length or 65
    positions = positions or [.45, .85, 1.]
    if positions[-1] <= 1:
      positions = [int(line_length * p) for p in positions]
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #']
  else:
    line_length = line_length or 98
    positions = positions or [.33, .55, .67, 1.]
    if positions[-1] <= 1:
      positions = [int(line_length * p) for p in positions]
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to']
    relevant_nodes = []
    for v in model._nodes_by_depth.values():
      relevant_nodes += v

  def print_row(fields, positions):
    line = ''
    for i in range(len(fields)):
      if i > 0:
        line = line[:-1] + ' '
      line += str(fields[i])
      line = line[:positions[i]]
      line += ' ' * (positions[i] - len(line))
    print_fn(line)

  print_fn('Model: "{}"'.format(model.name))
  print_fn('_' * line_length)
  print_row(to_display, positions)
  print_fn('=' * line_length)

  def print_layer_summary(layer):
    """Prints a summary for a single layer.

    Arguments:
        layer: target layer.
    """
    try:
      output_shape = layer.output_shape
    except AttributeError:
      output_shape = 'multiple'
    except RuntimeError:  # output_shape unknown in Eager mode.
      output_shape = '?'
    name = layer.name
    cls_name = layer.__class__.__name__
    fields = [name + ' (' + cls_name + ')', output_shape, layer.count_params()]
    print_row(fields, positions)

  def print_layer_summary_with_connections(layer):
    """Prints a summary for a single layer (including topological connections).

    Arguments:
        layer: target layer.
    """
    try:
      output_shape = layer.output_shape
    except AttributeError:
      output_shape = 'multiple'
    connections = []
    for node in layer._inbound_nodes:
      if relevant_nodes and node not in relevant_nodes:
        # node is not part of the current network
        continue

      for inbound_layer, node_index, tensor_index, _ in node.iterate_inbound():
        connections.append('{}[{}][{}]'.format(inbound_layer.name, node_index,
                                               tensor_index))

    name = layer.name
    cls_name = layer.__class__.__name__
    if not connections:
      first_connection = ''
    else:
      first_connection = connections[0]
    fields = [
        name + ' (' + cls_name + ')', output_shape,
        layer.count_params(), first_connection
    ]
    print_row(fields, positions)
    if len(connections) > 1:
      for i in range(1, len(connections)):
        fields = ['', '', '', connections[i]]
        print_row(fields, positions)

  layers = model.layers
  for i in range(len(layers)):
    if sequential_like:
      print_layer_summary(layers[i])
    else:
      print_layer_summary_with_connections(layers[i])
    if i == len(layers) - 1:
      print_fn('=' * line_length)
    else:
      print_fn('_' * line_length)

  model._check_trainable_weights_consistency()
  if hasattr(model, '_collected_trainable_weights'):
    trainable_count = count_params(model._collected_trainable_weights)
  else:
    trainable_count = count_params(model.trainable_weights)

  non_trainable_count = count_params(model.non_trainable_weights)

  print_fn('Total params: {:,}'.format(trainable_count + non_trainable_count))
  print_fn('Trainable params: {:,}'.format(trainable_count))
  print_fn('Non-trainable params: {:,}'.format(non_trainable_count))
  print_fn('_' * line_length)


def gather_trainable_weights(trainable, sub_layers, extra_variables):
276  """Lists the trainable weights for an object with sub-layers.
277
278  Args:
279    trainable: Whether the object collecting the variables is trainable.
280    sub_layers: A flat list of Layer objects owned by this object, to collect
281      variables from.
282    extra_variables: Any extra variables to include. Their `.trainable` property
283      is used to categorize them.
284
285  Returns:
286    A list of collected trainable weights/variables.
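
  Example (an illustrative sketch; `dense_a` and `dense_b` are assumed to be
  already-built Keras layers):

  ```python
  weights = gather_trainable_weights(
      trainable=True, sub_layers=[dense_a, dense_b], extra_variables=[])
  # Returns [] immediately when trainable=False.
  ```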
287  """
  if not trainable:
    return []
  weights = []
  for layer in sub_layers:
    weights += layer.trainable_weights
  trainable_extra_variables = [
      v for v in extra_variables if v.trainable]
  return weights + trainable_extra_variables


def gather_non_trainable_weights(trainable, sub_layers, extra_variables):
299  """Lists the non-trainable weights for an object with sub-layers.
300
301  Args:
302    trainable: Whether the object collecting the variables is trainable.
303    sub_layers: A flat list of Layer objects owned by this object, to collect
304      variables from.
305    extra_variables: Any extra variables to include. Their `.trainable` property
306      is used to categorize them.
307
308  Returns:
309    A list of collected non-trainable weights/variables.
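
  Example (an illustrative sketch; `dense_a` is an assumed built layer). When
  the owning object itself is frozen (`trainable=False`), the sub-layers'
  trainable weights are reported here as non-trainable:

  ```python
  frozen = gather_non_trainable_weights(
      trainable=False, sub_layers=[dense_a], extra_variables=[])
  # `frozen` also contains dense_a.trainable_weights.
  ```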
310  """
  trainable_extra_variables = []
  non_trainable_extra_variables = []
  for v in extra_variables:
    if v.trainable:
      trainable_extra_variables.append(v)
    else:
      non_trainable_extra_variables.append(v)
  weights = []
  for layer in sub_layers:
    weights += layer.non_trainable_weights
  if not trainable:
    trainable_weights = []
    for layer in sub_layers:
      trainable_weights += layer.trainable_weights
    return (trainable_weights + trainable_extra_variables
            + weights + non_trainable_extra_variables)
  return weights + non_trainable_extra_variables


@deprecation.deprecated('2020-06-23',
                        'The Theano kernel format is legacy; '
                        'this utility will be removed.')
@keras_export('keras.utils.convert_all_kernels_in_model')
def convert_all_kernels_in_model(model):
335  """Converts all convolution kernels in a model from Theano to TensorFlow.
336
337  Also works from TensorFlow to Theano.
338
339  This is used for converting legacy Theano-saved model files.
340
341  Arguments:
342      model: target model for the conversion.
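
  Example (an illustrative sketch; the file path is hypothetical):

  ```python
  model = tf.keras.models.load_model('legacy_theano_model.h5')
  convert_all_kernels_in_model(model)  # flips each conv kernel in place
  ```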
343  """
  # Note: SeparableConvolution not included
  # since only supported by TF.
  conv_classes = {
      'Conv1D',
      'Conv2D',
      'Conv3D',
      'Conv2DTranspose',
  }
  to_assign = []
  for layer in model.layers:
    if layer.__class__.__name__ in conv_classes:
      original_kernel = K.get_value(layer.kernel)
      converted_kernel = convert_kernel(original_kernel)
      to_assign.append((layer.kernel, converted_kernel))
  K.batch_set_value(to_assign)


def convert_dense_weights_data_format(dense,
                                      previous_feature_map_shape,
                                      target_data_format='channels_first'):
364  """Utility useful when changing a convnet's `data_format`.
365
366  When porting the weights of a convnet from one data format to the other,
367  if the convnet includes a `Flatten` layer
368  (applied to the last convolutional feature map)
369  followed by a `Dense` layer, the weights of that `Dense` layer
370  should be updated to reflect the new dimension ordering.
371
372  Arguments:
373      dense: The target `Dense` layer.
374      previous_feature_map_shape: A shape tuple of 3 integers,
375          e.g. `(512, 7, 7)`. The shape of the convolutional
376          feature map right before the `Flatten` layer that
377          came before the target `Dense` layer.
378      target_data_format: One of "channels_last", "channels_first".
379          Set it "channels_last"
380          if converting a "channels_first" model to "channels_last",
381          or reciprocally.
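
  Example (an illustrative sketch; `dense` is an assumed built Dense layer
  that followed a `Flatten` over a `(512, 7, 7)` feature map):

  ```python
  # Converting a channels_last model to channels_first; the shape is given
  # in the *target* format, i.e. (channels, height, width).
  convert_dense_weights_data_format(
      dense, previous_feature_map_shape=(512, 7, 7),
      target_data_format='channels_first')
  ```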
382  """
  assert target_data_format in {'channels_last', 'channels_first'}
  kernel, bias = dense.get_weights()
  for i in range(kernel.shape[1]):
    if target_data_format == 'channels_first':
      c, h, w = previous_feature_map_shape
      original_fm_shape = (h, w, c)
      ki = kernel[:, i].reshape(original_fm_shape)
      ki = np.transpose(ki, (2, 0, 1))  # last -> first
    else:
      h, w, c = previous_feature_map_shape
      original_fm_shape = (c, h, w)
      ki = kernel[:, i].reshape(original_fm_shape)
      ki = np.transpose(ki, (1, 2, 0))  # first -> last
    kernel[:, i] = np.reshape(ki, (np.prod(previous_feature_map_shape),))
  dense.set_weights([kernel, bias])


def is_builtin_layer(layer):
  """Returns True if `layer` is an exported, built-in Keras layer."""
  if not getattr(layer, '_keras_api_names', None):
    return False

  # Subclasses of `Layer` that are not exported inherit the export name
  # of the base layer class.
  return (layer._keras_api_names != ('keras.layers.Layer',) and
          layer._keras_api_names_v1 != ('keras.layers.Layer',))