# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras initializer serialization / deserialization."""

import threading

from tensorflow.python import tf2
from tensorflow.python.keras.initializers import initializers_v1
from tensorflow.python.keras.initializers import initializers_v2
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_inspect as inspect
from tensorflow.python.ops import init_ops
from tensorflow.python.util.tf_export import keras_export


# LOCAL.ALL_OBJECTS is meant to be a global mutable. Hence we need to make it
# thread-local to avoid concurrent mutations.
LOCAL = threading.local()


def populate_deserializable_objects():
  """Fills the thread-local `LOCAL.ALL_OBJECTS` dict with built-in initializers.

  The dict maps every accepted identifier (class names, snake_case functional
  aliases and legacy compatibility aliases) to an initializer class. The value
  of `tf2.enabled()` that was in effect when the dict was built is recorded in
  `LOCAL.GENERATED_WITH_V2`; the dict is rebuilt from scratch whenever it is
  empty or that recorded value no longer matches `tf2.enabled()`.
  """
  if not hasattr(LOCAL, 'ALL_OBJECTS'):
    # First call on this thread: create the attributes read below.
    LOCAL.ALL_OBJECTS = {}
    LOCAL.GENERATED_WITH_V2 = None

  if LOCAL.ALL_OBJECTS and LOCAL.GENERATED_WITH_V2 == tf2.enabled():
    # The registry already matches the active TF version; nothing to rebuild.
    return

  # Publish the fresh dict first, then fill it in place through a local name.
  registry = {}
  LOCAL.ALL_OBJECTS = registry
  LOCAL.GENERATED_WITH_V2 = tf2.enabled()

  # Compatibility aliases (need to exist in both V1 and V2): '<ClassName>V2'.
  v2_class_names = (
      'Constant',
      'GlorotNormal',
      'GlorotUniform',
      'HeNormal',
      'HeUniform',
      'Identity',
      'LecunNormal',
      'LecunUniform',
      'Ones',
      'Orthogonal',
      'RandomNormal',
      'RandomUniform',
      'TruncatedNormal',
      'VarianceScaling',
      'Zeros',
  )
  for class_name in v2_class_names:
    registry[class_name + 'V2'] = getattr(initializers_v2, class_name)

  # Out of an abundance of caution we also include these aliases that have
  # a non-zero probability of having been included in saved configs in the
  # past.
  legacy_snake_aliases = (
      ('glorot_normalV2', 'GlorotNormal'),
      ('glorot_uniformV2', 'GlorotUniform'),
      ('he_normalV2', 'HeNormal'),
      ('he_uniformV2', 'HeUniform'),
      ('lecun_normalV2', 'LecunNormal'),
      ('lecun_uniformV2', 'LecunUniform'),
  )
  for alias, class_name in legacy_snake_aliases:
    registry[alias] = getattr(initializers_v2, class_name)

  if tf2.enabled():
    # For V2, entries are generated automatically based on the content of
    # initializers_v2.py.
    base_cls = initializers_v2.Initializer

    def _is_initializer_class(candidate):
      return inspect.isclass(candidate) and issubclass(candidate, base_cls)

    versioned_objs = {}
    generic_utils.populate_dict_with_module_objects(
        versioned_objs,
        [initializers_v2],
        obj_filter=_is_initializer_class)
  else:
    # V1 initializers.
    versioned_objs = {
        'Constant': init_ops.Constant,
        'GlorotNormal': init_ops.GlorotNormal,
        'GlorotUniform': init_ops.GlorotUniform,
        'Identity': init_ops.Identity,
        'Ones': init_ops.Ones,
        'Orthogonal': init_ops.Orthogonal,
        'VarianceScaling': init_ops.VarianceScaling,
        'Zeros': init_ops.Zeros,
        'HeNormal': initializers_v1.HeNormal,
        'HeUniform': initializers_v1.HeUniform,
        'LecunNormal': initializers_v1.LecunNormal,
        'LecunUniform': initializers_v1.LecunUniform,
        'RandomNormal': initializers_v1.RandomNormal,
        'RandomUniform': initializers_v1.RandomUniform,
        'TruncatedNormal': initializers_v1.TruncatedNormal,
    }

  # Register each class under its own name and under its functional
  # (snake_case) alias.
  for class_name, initializer_cls in versioned_objs.items():
    registry[class_name] = initializer_cls
    registry[generic_utils.to_snake_case(class_name)] = initializer_cls

  # More compatibility aliases, pointing at entries registered just above.
  for alias, target in (
      ('normal', 'random_normal'),
      ('uniform', 'random_uniform'),
      ('one', 'ones'),
      ('zero', 'zeros'),
  ):
    registry[alias] = registry[target]


# For backwards compatibility, we populate this file with the objects
# from ALL_OBJECTS. We make no guarantees as to whether these objects will be
# using their correct version.
populate_deserializable_objects()
globals().update(LOCAL.ALL_OBJECTS)

# Utility functions


@keras_export('keras.initializers.serialize')
def serialize(initializer):
  """Serialize an initializer with `generic_utils.serialize_keras_object`.

  Args:
    initializer: The initializer to serialize.

  Returns:
    Whatever `generic_utils.serialize_keras_object` returns for `initializer`.
  """
  return generic_utils.serialize_keras_object(initializer)


@keras_export('keras.initializers.deserialize')
def deserialize(config, custom_objects=None):
  """Return an `Initializer` object from its config.

  The table of built-in initializers is (re)populated for the calling thread
  before the lookup is delegated to `generic_utils.deserialize_keras_object`.

  Args:
    config: Initializer identifier or configuration to deserialize.
    custom_objects: Optional value forwarded unchanged as the `custom_objects`
      argument of `generic_utils.deserialize_keras_object`.

  Returns:
    The result of `generic_utils.deserialize_keras_object`.
  """
  populate_deserializable_objects()
  return generic_utils.deserialize_keras_object(
      config,
      module_objects=LOCAL.ALL_OBJECTS,
      custom_objects=custom_objects,
      printable_module_name='initializer')


@keras_export('keras.initializers.get')
def get(identifier):
  """Retrieve a Keras initializer by the identifier.

  The `identifier` may be the string name of an initializer function or class
  (case-sensitively).

  >>> identifier = 'Ones'
  >>> tf.keras.initializers.get(identifier)
  <...keras.initializers.initializers_v2.Ones...>

  You can also specify `config` of the initializer to this function by passing
  a dict containing `class_name` and `config` as an identifier. Also note that
  the `class_name` must map to an `Initializer` class.

  >>> cfg = {'class_name': 'Ones', 'config': {}}
  >>> tf.keras.initializers.get(cfg)
  <...keras.initializers.initializers_v2.Ones...>

  In the case that the `identifier` is a class, this method will return a new
  instance of the class by its constructor. Any other callable is returned
  unchanged, and `None` is returned as `None`.

  Args:
    identifier: String or dict that contains the initializer name or
      configurations. May also be `None`, a class, or another callable.

  Returns:
    Initializer instance based on the input identifier, or `None` when the
    identifier is `None`.

  Raises:
    ValueError: If the input identifier is not a supported type or in a bad
      format.
  """
  if identifier is None:
    return None
  if isinstance(identifier, dict):
    return deserialize(identifier)
  if isinstance(identifier, str):
    # Keep the historical normalisation to a plain `str` before the lookup.
    return deserialize(str(identifier))
  if callable(identifier):
    # Classes are instantiated with no arguments; other callables pass through.
    return identifier() if inspect.isclass(identifier) else identifier
  raise ValueError('Could not interpret initializer identifier: ' +
                   str(identifier))