# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras initializers for TF 1.

Publishes TF1 `init_ops` initializer classes under their `keras.initializers.*`
v1 API names via `keras_export(v1=[...])`. Some classes are re-exported
unchanged (the `_v1_*` aliases below). The others are thin subclasses that
either supply their own constructor defaults (`RandomNormal`, `RandomUniform`,
`TruncatedNormal`) or pin fixed `VarianceScaling` arguments (`LecunNormal`,
`LecunUniform`, `HeNormal`, `HeUniform`).
"""

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import init_ops
from tensorflow.python.util.tf_export import keras_export


# init_ops classes re-exported unchanged. Each alias is the very same class
# object as its init_ops counterpart (no subclassing), so whatever keras_export
# records below applies to that shared class.
# NOTE(review): presumably the private module-level names exist so the API
# generator finds these symbols in this module -- confirm before renaming.
_v1_zeros_initializer = init_ops.Zeros
_v1_ones_initializer = init_ops.Ones
_v1_constant_initializer = init_ops.Constant
_v1_variance_scaling_initializer = init_ops.VarianceScaling
_v1_orthogonal_initializer = init_ops.Orthogonal
_v1_identity = init_ops.Identity
_v1_glorot_uniform_initializer = init_ops.GlorotUniform
_v1_glorot_normal_initializer = init_ops.GlorotNormal

# Register the v1 Keras API names for the aliases above. The value returned by
# keras_export(...)(symbol) is discarded: these calls are made only for their
# side effect (presumably attaching the listed API names to the symbol).
# Only the names listed here are registered; e.g. VarianceScaling gets no
# lowercase alias and the glorot initializers get only lowercase names.
keras_export(v1=['keras.initializers.Zeros', 'keras.initializers.zeros'])(
    _v1_zeros_initializer)
keras_export(v1=['keras.initializers.Ones', 'keras.initializers.ones'])(
    _v1_ones_initializer)
keras_export(v1=['keras.initializers.Constant', 'keras.initializers.constant'])(
    _v1_constant_initializer)
keras_export(v1=['keras.initializers.VarianceScaling'])(
    _v1_variance_scaling_initializer)
keras_export(v1=['keras.initializers.Orthogonal',
                 'keras.initializers.orthogonal'])(_v1_orthogonal_initializer)
keras_export(v1=['keras.initializers.Identity',
                 'keras.initializers.identity'])(_v1_identity)
keras_export(v1=['keras.initializers.glorot_uniform'])(
    _v1_glorot_uniform_initializer)
keras_export(v1=['keras.initializers.glorot_normal'])(
    _v1_glorot_normal_initializer)


@keras_export(v1=['keras.initializers.RandomNormal',
                  'keras.initializers.random_normal',
                  'keras.initializers.normal'])
class RandomNormal(init_ops.RandomNormal):
  """`init_ops.RandomNormal` registered under the v1 Keras names.

  Only `__init__` is overridden, and it forwards every argument unchanged, so
  behavior is the parent's. What this subclass changes is the default argument
  values (mean=0.0, stddev=0.05, seed=None, dtype=float32).
  NOTE(review): presumably these defaults are chosen to match Keras and may
  differ from the parent's own defaults -- confirm against `init_ops`.
  """

  def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32):
    """Forwards all arguments, by keyword, to `init_ops.RandomNormal`.

    Args:
      mean: Passed through as `mean`. Default 0.0.
      stddev: Passed through as `stddev`. Default 0.05.
      seed: Passed through as `seed`. Default None.
      dtype: Passed through as `dtype`. Default `dtypes.float32`.
    """
    super(RandomNormal, self).__init__(
        mean=mean, stddev=stddev, seed=seed, dtype=dtype)


@keras_export(v1=['keras.initializers.RandomUniform',
                  'keras.initializers.random_uniform',
                  'keras.initializers.uniform'])
class RandomUniform(init_ops.RandomUniform):
  """`init_ops.RandomUniform` registered under the v1 Keras names.

  Only `__init__` is overridden, and it forwards every argument unchanged, so
  behavior is the parent's. What this subclass changes is the default argument
  values (minval=-0.05, maxval=0.05, seed=None, dtype=float32).
  NOTE(review): presumably these defaults are chosen to match Keras and may
  differ from the parent's own defaults -- confirm against `init_ops`.
  """

  def __init__(self, minval=-0.05, maxval=0.05, seed=None,
               dtype=dtypes.float32):
    """Forwards all arguments, by keyword, to `init_ops.RandomUniform`.

    Args:
      minval: Passed through as `minval`. Default -0.05.
      maxval: Passed through as `maxval`. Default 0.05.
      seed: Passed through as `seed`. Default None.
      dtype: Passed through as `dtype`. Default `dtypes.float32`.
    """
    super(RandomUniform, self).__init__(
        minval=minval, maxval=maxval, seed=seed, dtype=dtype)


@keras_export(v1=['keras.initializers.TruncatedNormal',
                  'keras.initializers.truncated_normal'])
class TruncatedNormal(init_ops.TruncatedNormal):
  """`init_ops.TruncatedNormal` registered under the v1 Keras names.

  Only `__init__` is overridden, and it forwards every argument unchanged, so
  behavior is the parent's. What this subclass changes is the default argument
  values (mean=0.0, stddev=0.05, seed=None, dtype=float32).
  NOTE(review): presumably these defaults are chosen to match Keras and may
  differ from the parent's own defaults -- confirm against `init_ops`.
  """

  def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32):
    """Forwards all arguments, by keyword, to `init_ops.TruncatedNormal`.

    Args:
      mean: Passed through as `mean`. Default 0.0.
      stddev: Passed through as `stddev`. Default 0.05.
      seed: Passed through as `seed`. Default None.
      dtype: Passed through as `dtype`. Default `dtypes.float32`.
    """
    super(TruncatedNormal, self).__init__(
        mean=mean, stddev=stddev, seed=seed, dtype=dtype)


@keras_export(v1=['keras.initializers.lecun_normal'])
class LecunNormal(init_ops.VarianceScaling):
  """`VarianceScaling` fixed to scale=1., mode='fan_in', 'truncated_normal'.

  Only `seed` is configurable. `dtype` is not accepted, so it stays at whatever
  default `init_ops.VarianceScaling` uses.
  """

  def __init__(self, seed=None):
    """Builds the fixed-configuration initializer.

    Args:
      seed: Passed through to `init_ops.VarianceScaling` as `seed`.
    """
    super(LecunNormal, self).__init__(
        scale=1., mode='fan_in', distribution='truncated_normal', seed=seed)

  def get_config(self):
    """Returns `{'seed': self.seed}`.

    `__init__` accepts only `seed` (scale, mode and distribution are fixed), so
    the config holds only `seed`; the fixed settings are re-applied by the
    constructor. NOTE(review): presumably this keeps `cls(**config)`
    round-trippable -- confirm how the config is consumed.
    """
    return {'seed': self.seed}


@keras_export(v1=['keras.initializers.lecun_uniform'])
class LecunUniform(init_ops.VarianceScaling):
  """`VarianceScaling` fixed to scale=1., mode='fan_in', 'uniform'.

  Only `seed` is configurable. `dtype` is not accepted, so it stays at whatever
  default `init_ops.VarianceScaling` uses.
  """

  def __init__(self, seed=None):
    """Builds the fixed-configuration initializer.

    Args:
      seed: Passed through to `init_ops.VarianceScaling` as `seed`.
    """
    super(LecunUniform, self).__init__(
        scale=1., mode='fan_in', distribution='uniform', seed=seed)

  def get_config(self):
    """Returns `{'seed': self.seed}`.

    `__init__` accepts only `seed` (scale, mode and distribution are fixed), so
    the config holds only `seed`; the fixed settings are re-applied by the
    constructor. NOTE(review): presumably this keeps `cls(**config)`
    round-trippable -- confirm how the config is consumed.
    """
    return {'seed': self.seed}


@keras_export(v1=['keras.initializers.he_normal'])
class HeNormal(init_ops.VarianceScaling):
  """`VarianceScaling` fixed to scale=2., mode='fan_in', 'truncated_normal'.

  Only `seed` is configurable. `dtype` is not accepted, so it stays at whatever
  default `init_ops.VarianceScaling` uses.
  """

  def __init__(self, seed=None):
    """Builds the fixed-configuration initializer.

    Args:
      seed: Passed through to `init_ops.VarianceScaling` as `seed`.
    """
    super(HeNormal, self).__init__(
        scale=2., mode='fan_in', distribution='truncated_normal', seed=seed)

  def get_config(self):
    """Returns `{'seed': self.seed}`.

    `__init__` accepts only `seed` (scale, mode and distribution are fixed), so
    the config holds only `seed`; the fixed settings are re-applied by the
    constructor. NOTE(review): presumably this keeps `cls(**config)`
    round-trippable -- confirm how the config is consumed.
    """
    return {'seed': self.seed}


@keras_export(v1=['keras.initializers.he_uniform'])
class HeUniform(init_ops.VarianceScaling):
  """`VarianceScaling` fixed to scale=2., mode='fan_in', 'uniform'.

  Only `seed` is configurable. `dtype` is not accepted, so it stays at whatever
  default `init_ops.VarianceScaling` uses.
  """

  def __init__(self, seed=None):
    """Builds the fixed-configuration initializer.

    Args:
      seed: Passed through to `init_ops.VarianceScaling` as `seed`.
    """
    super(HeUniform, self).__init__(
        scale=2., mode='fan_in', distribution='uniform', seed=seed)

  def get_config(self):
    """Returns `{'seed': self.seed}`.

    `__init__` accepts only `seed` (scale, mode and distribution are fixed), so
    the config holds only `seed`; the fixed settings are re-applied by the
    constructor. NOTE(review): presumably this keeps `cls(**config)`
    round-trippable -- confirm how the config is consumed.
    """
    return {'seed': self.seed}