• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Keras initializer serialization / deserialization.
16"""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import six
22
23from tensorflow.python import tf2
24from tensorflow.python.framework import dtypes
25from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
26from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
27from tensorflow.python.ops import init_ops_v2
28
29# These imports are brought in so that keras.initializers.deserialize
30# has them available in module_objects.
31from tensorflow.python.ops.init_ops import Constant
32from tensorflow.python.ops.init_ops import GlorotNormal
33from tensorflow.python.ops.init_ops import GlorotUniform
34from tensorflow.python.ops.init_ops import he_normal  # pylint: disable=unused-import
35from tensorflow.python.ops.init_ops import he_uniform  # pylint: disable=unused-import
36from tensorflow.python.ops.init_ops import Identity
37from tensorflow.python.ops.init_ops import Initializer  # pylint: disable=unused-import
38from tensorflow.python.ops.init_ops import lecun_normal  # pylint: disable=unused-import
39from tensorflow.python.ops.init_ops import lecun_uniform  # pylint: disable=unused-import
40from tensorflow.python.ops.init_ops import Ones
41from tensorflow.python.ops.init_ops import Orthogonal
42from tensorflow.python.ops.init_ops import RandomNormal as TFRandomNormal
43from tensorflow.python.ops.init_ops import RandomUniform as TFRandomUniform
44from tensorflow.python.ops.init_ops import TruncatedNormal as TFTruncatedNormal
45from tensorflow.python.ops.init_ops import VarianceScaling  # pylint: disable=unused-import
46from tensorflow.python.ops.init_ops import Zeros
47# pylint: disable=unused-import, disable=line-too-long
48from tensorflow.python.ops.init_ops_v2 import Constant as ConstantV2
49from tensorflow.python.ops.init_ops_v2 import GlorotNormal as GlorotNormalV2
50from tensorflow.python.ops.init_ops_v2 import GlorotUniform as GlorotUniformV2
51from tensorflow.python.ops.init_ops_v2 import he_normal as he_normalV2
52from tensorflow.python.ops.init_ops_v2 import he_uniform as he_uniformV2
53from tensorflow.python.ops.init_ops_v2 import Identity as IdentityV2
54from tensorflow.python.ops.init_ops_v2 import Initializer as InitializerV2
55from tensorflow.python.ops.init_ops_v2 import lecun_normal as lecun_normalV2
56from tensorflow.python.ops.init_ops_v2 import lecun_uniform  as lecun_uniformV2
57from tensorflow.python.ops.init_ops_v2 import Ones as OnesV2
58from tensorflow.python.ops.init_ops_v2 import Orthogonal as OrthogonalV2
59from tensorflow.python.ops.init_ops_v2 import RandomNormal as RandomNormalV2
60from tensorflow.python.ops.init_ops_v2 import RandomUniform as RandomUniformV2
61from tensorflow.python.ops.init_ops_v2 import TruncatedNormal as TruncatedNormalV2
62from tensorflow.python.ops.init_ops_v2 import VarianceScaling as VarianceScalingV2
63from tensorflow.python.ops.init_ops_v2 import Zeros as ZerosV2
64# pylint: enable=unused-import, enable=line-too-long
65
66from tensorflow.python.util.tf_export import keras_export
67
68
@keras_export(v1=['keras.initializers.TruncatedNormal',
                  'keras.initializers.truncated_normal'])
class TruncatedNormal(TFTruncatedNormal):
  """Keras (V1) initializer drawing from a truncated normal distribution.

  Samples behave like those of a `random_normal_initializer`, except that
  any sample lying more than two standard deviations from the mean is
  discarded and drawn again. This is the recommended initializer for
  neural network weights and filters.

  This subclass only supplies the Keras default arguments listed in
  `__init__`; all sampling is delegated to the parent initializer.

  Args:
    mean: Python scalar or scalar tensor; mean of the generated values.
      Defaults to 0.
    stddev: Python scalar or scalar tensor; standard deviation of the
      generated values. Defaults to 0.05.
    seed: Python integer used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: The data type. Only floating point types are supported.

  Returns:
    A TruncatedNormal instance.
  """

  def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32):
    # Forward every argument by keyword so the parent's positional order is
    # irrelevant; only the defaults in this signature are Keras-specific.
    parent_kwargs = dict(mean=mean, stddev=stddev, seed=seed, dtype=dtype)
    super(TruncatedNormal, self).__init__(**parent_kwargs)
95
96
@keras_export(v1=['keras.initializers.RandomUniform',
                  'keras.initializers.uniform',
                  'keras.initializers.random_uniform'])
class RandomUniform(TFRandomUniform):
  """Keras (V1) initializer drawing tensors from a uniform distribution.

  Values are sampled from the range bounded by `minval` and `maxval`. This
  subclass only supplies the Keras default arguments listed in `__init__`;
  all sampling is delegated to the parent initializer.

  Args:
    minval: Python scalar or scalar tensor; lower bound of the range of
      generated values. Defaults to -0.05.
    maxval: Python scalar or scalar tensor; upper bound of the range of
      generated values. Defaults to 0.05.
    seed: Python integer used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: The data type.

  Returns:
    A RandomUniform instance.
  """

  def __init__(self, minval=-0.05, maxval=0.05, seed=None,
               dtype=dtypes.float32):
    parent_init = super(RandomUniform, self).__init__
    parent_init(minval=minval,
                maxval=maxval,
                seed=seed,
                dtype=dtype)
120
121
@keras_export(v1=['keras.initializers.RandomNormal',
                  'keras.initializers.normal',
                  'keras.initializers.random_normal'])
class RandomNormal(TFRandomNormal):
  """Initializer that generates tensors with a normal distribution.

  This subclass only supplies the Keras default arguments listed in
  `__init__`; all sampling is delegated to the parent initializer.

  Args:
    mean: a python scalar or a scalar tensor. Mean of the random values to
      generate. Defaults to 0.
    stddev: a python scalar or a scalar tensor. Standard deviation of the random
      values to generate. Defaults to 0.05.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: The data type. Only floating point types are supported.

  Returns:
    A RandomNormal instance.
  """

  def __init__(self, mean=0.0, stddev=0.05, seed=None, dtype=dtypes.float32):
    # Arguments are forwarded by keyword; only the defaults differ from the
    # TF base class signature as seen by callers of this Keras alias.
    super(RandomNormal, self).__init__(
        mean=mean, stddev=stddev, seed=seed, dtype=dtype)
144
145
# Compatibility aliases

# pylint: disable=invalid-name
# Lower-case names let string identifiers such as 'zeros' or
# 'glorot_uniform' resolve when `deserialize` looks names up in this
# module's `globals()` (the non-TF2 path).
zeros = Zeros
zero = Zeros
ones = Ones
one = Ones
constant = Constant
random_uniform = RandomUniform
uniform = RandomUniform
random_normal = RandomNormal
normal = RandomNormal
truncated_normal = TruncatedNormal
identity = Identity
orthogonal = Orthogonal
glorot_normal = GlorotNormal
glorot_uniform = GlorotUniform
159
160
# Utility functions
162
163
@keras_export('keras.initializers.serialize')
def serialize(initializer):
  """Convert an initializer into its Keras serialized form.

  Args:
    initializer: The initializer object to serialize.

  Returns:
    Whatever `serialize_keras_object` produces for `initializer`.
  """
  serialized = serialize_keras_object(initializer)
  return serialized
167
168
@keras_export('keras.initializers.deserialize')
def deserialize(config, custom_objects=None):
  """Return an `Initializer` object from its config.

  Args:
    config: Serialized initializer, passed to `deserialize_keras_object`.
    custom_objects: Optional mapping of extra names to objects to consider
      during lookup.

  Returns:
    The object produced by `deserialize_keras_object`.
  """
  if not tf2.enabled():
    # V1 behavior: resolve names against this module, which exposes the
    # V1 classes plus the lower-case compatibility aliases.
    module_objects = globals()
  else:
    # V1 and V2 classes share class names, and this module binds the V2
    # ones under aliased names, so the V2 lookup table is built directly
    # from `init_ops_v2`.
    module_objects = {}
    for attr_name in dir(init_ops_v2):
      module_objects[attr_name] = getattr(init_ops_v2, attr_name)
  return deserialize_keras_object(
      config,
      module_objects=module_objects,
      custom_objects=custom_objects,
      printable_module_name='initializer')
187
188
@keras_export('keras.initializers.get')
def get(identifier):
  """Resolve `identifier` into an initializer.

  Args:
    identifier: `None`, a config dict, a string name, or a callable.

  Returns:
    `None` when `identifier` is `None`; the result of `deserialize` for
    dicts and strings; callables are returned unchanged.

  Raises:
    ValueError: If `identifier` is none of the supported kinds.
  """
  if identifier is None:
    return None
  if isinstance(identifier, dict):
    return deserialize(identifier)
  if not isinstance(identifier, six.string_types):
    if callable(identifier):
      return identifier
    raise ValueError('Could not interpret initializer identifier: ' +
                     str(identifier))
  name = str(identifier)
  # These names are functions rather than classes, so they are
  # deserialized as if they were a class config with empty kwargs.
  # TODO(omalleyt): Turn these into classes or class aliases.
  if name in ('he_normal', 'he_uniform', 'lecun_normal', 'lecun_uniform'):
    return deserialize({'class_name': name, 'config': {}})
  return deserialize(name)
209
210
# pylint: enable=invalid-name
212