• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities related to layer/model functionality."""
16
17# TODO(b/110718070): Move these functions back to tensorflow/python/keras/utils
18# once __init__ files no longer require all of tf.keras to be imported together.
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24from tensorflow.python.training.tracking import object_identity
25
26
def is_layer(obj):
  """Duck-typed test for whether `obj` should be treated as a Layer.

  Any object carrying an `_is_layer` attribute counts, regardless of the
  attribute's value.

  Args:
    obj: Any object.

  Returns:
    True if `hasattr(obj, "_is_layer")`, otherwise False.
  """
  # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer).
  layer_marker = "_is_layer"
  return hasattr(obj, layer_marker)
31
32
def has_weights(obj):
  """Implicit check for objects that expose Layer-style weight lists.

  Unlike `is_layer`, this does not look for a Layer marker. It only checks
  that `obj` has both a `trainable_weights` and a `non_trainable_weights`
  attribute, which is what the weight-gathering helpers need.

  Args:
    obj: Any object.

  Returns:
    True if `obj` has both attributes, otherwise False.
  """
  # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer).
  return (hasattr(obj, "trainable_weights")
          and hasattr(obj, "non_trainable_weights"))
38
39
def filter_empty_layer_containers(layer_list):
  """Flattens Layer-like containers into a de-duplicated list of layers.

  Walks `layer_list` depth-first, preserving order:
    * objects for which `is_layer` is true are kept;
    * other objects that have a `layers` attribute are replaced by the
      contents of their `layers`, which are visited recursively;
    * anything else is dropped, as are containers whose `layers` contain no
      layers.
  Each object is visited at most once. Membership is tracked with an
  `object_identity.ObjectIdentitySet`, so duplicates are detected by object
  identity.

  Args:
    layer_list: An iterable of layers and/or Layer-like containers. Lists,
      tuples and other iterables are accepted.

  Returns:
    A new list of the distinct layer objects found, in first-visit order.
  """
  existing = object_identity.ObjectIdentitySet()
  # `to_visit` is a LIFO stack. It is reversed so that pop() yields items in
  # their original order. Materialize with list() first: slicing a tuple
  # returns a tuple, which has no `pop`.
  to_visit = list(layer_list)[::-1]
  filtered = []
  while to_visit:
    obj = to_visit.pop()
    if obj in existing:
      continue
    existing.add(obj)
    if is_layer(obj):
      filtered.append(obj)
    elif hasattr(obj, "layers"):
      # Trackable data structures will not show up in ".layers" lists, but
      # the layers they contain will.
      to_visit.extend(list(obj.layers)[::-1])
  return filtered
57
58
def gather_trainable_weights(trainable, sub_layers, extra_variables):
  """Collects the trainable variables of an object and its sub-layers.

  Args:
    trainable: Whether the collecting object is itself trainable. When falsy,
      nothing is collected (and `extra_variables` is not iterated).
    sub_layers: A flat list of Layer objects owned by the object. Each one's
      `trainable_weights` are included.
    extra_variables: Additional variables owned directly by the object. Only
      those whose `.trainable` attribute is truthy are included.

  Returns:
    A new list: the sub-layers' trainable weights in sub-layer order,
    followed by the trainable extra variables in their original order.
  """
  if trainable:
    collected = [
        weight for layer in sub_layers for weight in layer.trainable_weights]
    collected.extend(var for var in extra_variables if var.trainable)
    return collected
  return []
80
81
def gather_non_trainable_weights(trainable, sub_layers, extra_variables):
  """Collects the non-trainable variables of an object and its sub-layers.

  If the collecting object is not trainable, every variable it owns is
  reported as non-trainable, including ones that would otherwise be
  trainable.

  Args:
    trainable: Whether the collecting object is itself trainable.
    sub_layers: A flat list of Layer objects owned by the object.
    extra_variables: Additional variables owned directly by the object. Each
      one's `.trainable` attribute decides which group it falls into.

  Returns:
    A new list. When `trainable` is truthy, it holds the sub-layers'
    non-trainable weights followed by the non-trainable extras. Otherwise it
    holds, in this order: the sub-layers' trainable weights, the trainable
    extras, the sub-layers' non-trainable weights, and the non-trainable
    extras.
  """
  # Split the extras in a single pass, so one-shot iterables work.
  own_trainable, own_frozen = [], []
  for variable in extra_variables:
    bucket = own_trainable if variable.trainable else own_frozen
    bucket.append(variable)

  frozen_from_layers = [
      weight for layer in sub_layers for weight in layer.non_trainable_weights]
  if trainable:
    return frozen_from_layers + own_frozen

  # A non-trainable owner freezes everything, so the trainable weights are
  # reported here too, ahead of the naturally frozen ones.
  trainable_from_layers = [
      weight for layer in sub_layers for weight in layer.trainable_weights]
  return (trainable_from_layers + own_trainable
          + frozen_from_layers + own_frozen)
112