# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras initializer serialization / deserialization.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python import tf2
from tensorflow.python.framework import dtypes
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.ops import init_ops_v2

# These imports are brought in so that keras.initializers.deserialize
# has them available in module_objects.
from tensorflow.python.ops.init_ops import Constant
from tensorflow.python.ops.init_ops import GlorotNormal
from tensorflow.python.ops.init_ops import GlorotUniform
from tensorflow.python.ops.init_ops import he_normal  # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import he_uniform  # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Identity
from tensorflow.python.ops.init_ops import Initializer  # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import lecun_normal  # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import lecun_uniform  # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Ones
from tensorflow.python.ops.init_ops import Orthogonal
from tensorflow.python.ops.init_ops import RandomNormal as TFRandomNormal
from tensorflow.python.ops.init_ops import RandomUniform as TFRandomUniform
from tensorflow.python.ops.init_ops import TruncatedNormal as TFTruncatedNormal
from tensorflow.python.ops.init_ops import VarianceScaling  # pylint: disable=unused-import
from tensorflow.python.ops.init_ops import Zeros
# pylint: disable=unused-import, disable=line-too-long
from tensorflow.python.ops.init_ops_v2 import Constant as ConstantV2
from tensorflow.python.ops.init_ops_v2 import GlorotNormal as GlorotNormalV2
from tensorflow.python.ops.init_ops_v2 import GlorotUniform as GlorotUniformV2
from tensorflow.python.ops.init_ops_v2 import he_normal as he_normalV2
from tensorflow.python.ops.init_ops_v2 import he_uniform as he_uniformV2
from tensorflow.python.ops.init_ops_v2 import Identity as IdentityV2
from tensorflow.python.ops.init_ops_v2 import Initializer as InitializerV2
from tensorflow.python.ops.init_ops_v2 import lecun_normal as lecun_normalV2
from tensorflow.python.ops.init_ops_v2 import lecun_uniform as lecun_uniformV2
from tensorflow.python.ops.init_ops_v2 import Ones as OnesV2
from tensorflow.python.ops.init_ops_v2 import Orthogonal as OrthogonalV2
from tensorflow.python.ops.init_ops_v2 import RandomNormal as RandomNormalV2
from tensorflow.python.ops.init_ops_v2 import RandomUniform as RandomUniformV2
from tensorflow.python.ops.init_ops_v2 import TruncatedNormal as TruncatedNormalV2
from tensorflow.python.ops.init_ops_v2 import VarianceScaling as VarianceScalingV2
from tensorflow.python.ops.init_ops_v2 import Zeros as ZerosV2
# pylint: enable=unused-import, enable=line-too-long

from tensorflow.python.util.tf_export import keras_export


@keras_export(v1=['keras.initializers.TruncatedNormal',
                  'keras.initializers.truncated_normal'])
class TruncatedNormal(TFTruncatedNormal):
  """Initializer that generates a truncated normal distribution.

  These values are similar to values from a `random_normal_initializer`
  except that values more than two standard deviations from the mean
  are discarded and re-drawn. This is the recommended initializer for
  neural network weights and filters.

  This subclass only changes the default `stddev` to 0.05; everything else
  is delegated to the underlying `init_ops.TruncatedNormal`.

  Args:
    mean: a python scalar or a scalar tensor. Mean of the random values to
      generate. Defaults to 0.
    stddev: a python scalar or a scalar tensor. Standard deviation of the random
      values to generate. Defaults to 0.05.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: The data type. Only floating point types are supported.

  Returns:
    A TruncatedNormal instance.
  """

  def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32):
    super(TruncatedNormal, self).__init__(
        mean=mean, stddev=stddev, seed=seed, dtype=dtype)


@keras_export(v1=['keras.initializers.RandomUniform',
                  'keras.initializers.uniform',
                  'keras.initializers.random_uniform'])
class RandomUniform(TFRandomUniform):
  """Initializer that generates tensors with a uniform distribution.

  This subclass only changes the default range to [-0.05, 0.05]; everything
  else is delegated to the underlying `init_ops.RandomUniform`.

  Args:
    minval: A python scalar or a scalar tensor. Lower bound of the range of
      random values to generate. Defaults to -0.05.
    maxval: A python scalar or a scalar tensor. Upper bound of the range of
      random values to generate. Defaults to 0.05.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: The data type.

  Returns:
    A RandomUniform instance.
  """

  def __init__(self, minval=-0.05, maxval=0.05, seed=None,
               dtype=dtypes.float32):
    super(RandomUniform, self).__init__(
        minval=minval, maxval=maxval, seed=seed, dtype=dtype)


@keras_export(v1=['keras.initializers.RandomNormal',
                  'keras.initializers.normal',
                  'keras.initializers.random_normal'])
class RandomNormal(TFRandomNormal):
  """Initializer that generates tensors with a normal distribution.

  This subclass only changes the default `stddev` to 0.05; everything else
  is delegated to the underlying `init_ops.RandomNormal`.

  Args:
    mean: a python scalar or a scalar tensor. Mean of the random values to
      generate. Defaults to 0.
    stddev: a python scalar or a scalar tensor. Standard deviation of the random
      values to generate. Defaults to 0.05.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: The data type. Only floating point types are supported.

  Returns:
    A RandomNormal instance.
  """

  def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32):
    super(RandomNormal, self).__init__(
        mean=mean, stddev=stddev, seed=seed, dtype=dtype)


# Compatibility aliases
#
# Lower-case names let string identifiers such as 'zeros' or 'glorot_uniform'
# resolve when `deserialize` looks names up in this module's `globals()`.

# pylint: disable=invalid-name
zero = zeros = Zeros
one = ones = Ones
constant = Constant
uniform = random_uniform = RandomUniform
normal = random_normal = RandomNormal
truncated_normal = TruncatedNormal
identity = Identity
orthogonal = Orthogonal
glorot_normal = GlorotNormal
glorot_uniform = GlorotUniform


# Utility functions


@keras_export('keras.initializers.serialize')
def serialize(initializer):
  """Serialize an initializer with `serialize_keras_object`.

  Args:
    initializer: The initializer to serialize.

  Returns:
    Whatever `serialize_keras_object` produces for `initializer` (presumably
    a config that `deserialize` can consume -- confirm against generic_utils).
  """
  return serialize_keras_object(initializer)


@keras_export('keras.initializers.deserialize')
def deserialize(config, custom_objects=None):
  """Return an `Initializer` object from its config.

  Args:
    config: The serialized initializer. `get` passes either a config dict or
      a name string; both are forwarded unchanged to
      `deserialize_keras_object`.
    custom_objects: Optional extra name -> object mapping, forwarded
      unchanged to `deserialize_keras_object`.

  Returns:
    The object built by `deserialize_keras_object`.
  """
  if tf2.enabled():
    # Class names are the same for V1 and V2 but the V2 classes
    # are aliased in this file so we need to grab them directly
    # from `init_ops_v2`.
    module_objects = {
        obj_name: getattr(init_ops_v2, obj_name)
        for obj_name in dir(init_ops_v2)
    }
  else:
    # NOTE(review): `globals()` exposes every module-level name here
    # (including imports such as `six` and the helpers in this file), not
    # only initializers.
    module_objects = globals()
  return deserialize_keras_object(
      config,
      module_objects=module_objects,
      custom_objects=custom_objects,
      printable_module_name='initializer')


@keras_export('keras.initializers.get')
def get(identifier):
  """Retrieve a Keras initializer from a flexible identifier.

  Args:
    identifier: One of:
      - `None`, which is returned unchanged;
      - a `dict` config, which is passed to `deserialize`;
      - a string naming an initializer, which is resolved via `deserialize`;
      - any callable (for example an initializer instance), which is
        returned unchanged.

  Returns:
    `None`, the deserialized initializer, or `identifier` itself when it is
    already callable.

  Raises:
    ValueError: If `identifier` is none of the accepted kinds.
  """
  if identifier is None:
    return None
  if isinstance(identifier, dict):
    return deserialize(identifier)
  elif isinstance(identifier, six.string_types):
    identifier = str(identifier)
    # We have to special-case functions that return classes.
    # TODO(omalleyt): Turn these into classes or class aliases.
    # Deliberately a local tuple rather than a module-level constant: in
    # non-TF2 mode `deserialize` resolves names through `globals()`, so any
    # new module-level name would silently become deserializable.
    special_cases = ('he_normal', 'he_uniform', 'lecun_normal',
                     'lecun_uniform')
    if identifier in special_cases:
      # Treat like a class. Wrapping the name in a config dict (instead of
      # passing the bare string) presumably makes `deserialize_keras_object`
      # call the factory function rather than return it -- confirm.
      return deserialize({'class_name': identifier, 'config': {}})
    return deserialize(identifier)
  elif callable(identifier):
    return identifier
  else:
    raise ValueError('Could not interpret initializer identifier: ' +
                     str(identifier))


# pylint: enable=invalid-name