# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities related to layer/model functionality."""

# TODO(b/110718070): Move these functions back to tensorflow/python/keras/utils
# once __init__ files no longer require all of tf.keras to be imported together.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.training.tracking import object_identity


def is_layer(obj):
  """Implicit check for Layer-like objects.

  Args:
    obj: Any object.

  Returns:
    True if `obj` has an `_is_layer` attribute, False otherwise.
  """
  # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer).
  return hasattr(obj, "_is_layer")


def has_weights(obj):
  """Implicit check for objects that expose trainable/non-trainable weights.

  Args:
    obj: Any object.

  Returns:
    True if `obj` has both a `trainable_weights` and a
    `non_trainable_weights` attribute, False otherwise.
  """
  # TODO(b/110718070): Replace with isinstance(obj, base_layer.Layer).
  return (hasattr(obj, "trainable_weights")
          and hasattr(obj, "non_trainable_weights"))


def filter_empty_layer_containers(layer_list):
  """Filter out empty Layer-like containers and uniquify.

  Walks `layer_list` depth-first, preserving order. Layer-like objects (see
  `is_layer`) are kept. Non-layer objects that have a `layers` attribute are
  expanded in place into their `layers`. Everything else, including
  containers whose `layers` contribute nothing, is dropped. Each object is
  visited at most once, tracked with an `ObjectIdentitySet`. Duplicates
  therefore keep only their first occurrence.

  Args:
    layer_list: An iterable of objects (layers, containers of layers, or
      anything else).

  Returns:
    A list of the unique Layer-like objects found, in traversal order.
  """
  existing = object_identity.ObjectIdentitySet()
  # `to_visit` is used as a stack. Copy the input (so any iterable is
  # accepted and the caller's list is not mutated), then reverse it so that
  # `pop()` yields elements in their original order.
  to_visit = list(layer_list)
  to_visit.reverse()
  filtered = []
  while to_visit:
    obj = to_visit.pop()
    if obj in existing:
      continue
    existing.add(obj)
    if is_layer(obj):
      filtered.append(obj)
    elif hasattr(obj, "layers"):
      # Trackable data structures will not show up in ".layers" lists, but
      # the layers they contain will.
      # Push children reversed so they are visited in their listed order.
      to_visit.extend(obj.layers[::-1])
  return filtered


def gather_trainable_weights(trainable, sub_layers, extra_variables):
  """Lists the trainable weights for an object with sub-layers.

  Args:
    trainable: Whether the object collecting the variables is trainable.
    sub_layers: A flat list of Layer objects owned by this object, to collect
      variables from.
    extra_variables: Any extra variables to include. Their `.trainable` property
      is used to categorize them.

  Returns:
    A list of collected trainable weights/variables. It is empty when
    `trainable` is False. Otherwise it holds the sub-layers'
    `trainable_weights` (in `sub_layers` order), followed by the extra
    variables whose `.trainable` is truthy.
  """
  if not trainable:
    # A non-trainable container reports no trainable weights at all.
    return []
  weights = []
  for layer in sub_layers:
    weights += layer.trainable_weights
  trainable_extra_variables = [
      v for v in extra_variables if v.trainable]
  return weights + trainable_extra_variables


def gather_non_trainable_weights(trainable, sub_layers, extra_variables):
  """Lists the non-trainable weights for an object with sub-layers.

  Args:
    trainable: Whether the object collecting the variables is trainable.
    sub_layers: A flat list of Layer objects owned by this object, to collect
      variables from.
    extra_variables: Any extra variables to include. Their `.trainable` property
      is used to categorize them.

  Returns:
    A list of collected non-trainable weights/variables.

    When `trainable` is True, this is the sub-layers' `non_trainable_weights`
    followed by the extra variables whose `.trainable` is falsy.

    When `trainable` is False, every weight counts as non-trainable. The
    result is then, in order:
      1. the sub-layers' `trainable_weights`,
      2. the trainable extra variables,
      3. the sub-layers' `non_trainable_weights`,
      4. the non-trainable extra variables.
  """
  # Split the extras once; both branches below need the non-trainable part.
  trainable_extra_variables = []
  non_trainable_extra_variables = []
  for v in extra_variables:
    if v.trainable:
      trainable_extra_variables.append(v)
    else:
      non_trainable_extra_variables.append(v)
  weights = []
  for layer in sub_layers:
    weights += layer.non_trainable_weights
  if not trainable:
    # The container is frozen, so its sub-layers' trainable weights are
    # reported here instead of by `gather_trainable_weights`.
    trainable_weights = []
    for layer in sub_layers:
      trainable_weights += layer.trainable_weights
    return (trainable_weights + trainable_extra_variables
            + weights + non_trainable_extra_variables)
  return weights + non_trainable_extra_variables