• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras initializers for TF 2."""
16# pylint: disable=g-classes-have-attributes
17
18import math
19
20from tensorflow.python.framework import constant_op
21from tensorflow.python.framework import dtypes
22from tensorflow.python.keras import backend
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import gen_linalg_ops
25from tensorflow.python.ops import linalg_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.ops import random_ops
28from tensorflow.python.ops import stateless_random_ops
29from tensorflow.python.util.tf_export import keras_export
30
# Keyword-argument names that initializers look up in their `**kwargs`.
# When `_PARTITION_SHAPE` is present, the initializers below build a tensor of
# that shape instead of the requested `shape`. `_PARTITION_OFFSET` is
# presumably the shard's offset in the full variable (not read in this part of
# the file — confirm against `_validate_kwargs`).
_PARTITION_SHAPE, _PARTITION_OFFSET = 'partition_shape', 'partition_offset'
33
34
35@keras_export('keras.initializers.Initializer')
class Initializer(object):
  """Initializer base class: all Keras initializers inherit from this class.

  Initializers should implement a `__call__` method with the following
  signature:

  ```python
  def __call__(self, shape, dtype=None, **kwargs):
    # returns a tensor of shape `shape` and dtype `dtype`
    # containing values drawn from a distribution of your choice.
  ```

  Optionally, you can also implement the method `get_config` and the class
  method `from_config` in order to support serialization -- just like with
  any Keras object.

  Here's a simple example: a random normal initializer.

  ```python
  import tensorflow as tf

  class ExampleRandomNormal(tf.keras.initializers.Initializer):

    def __init__(self, mean, stddev):
      self.mean = mean
      self.stddev = stddev

    def __call__(self, shape, dtype=None, **kwargs):
      return tf.random.normal(
          shape, mean=self.mean, stddev=self.stddev, dtype=dtype)

    def get_config(self):  # To support serialization
      return {"mean": self.mean, "stddev": self.stddev}
  ```

  Note that we don't have to implement `from_config` in the example above since
  the constructor arguments of the class and the keys in the config returned by
  `get_config` are the same. In this case, the default `from_config`
  works fine.
  """

  def __call__(self, shape, dtype=None, **kwargs):
    """Returns a tensor object initialized as specified by the initializer.

    Subclasses must override this method.

    Args:
      shape: Shape of the tensor.
      dtype: Optional dtype of the tensor.
      **kwargs: Additional keyword arguments.

    Raises:
      NotImplementedError: Always, on this base class.
    """
    raise NotImplementedError

  def get_config(self):
    """Returns the configuration of the initializer as a JSON-serializable dict.

    Returns:
      A JSON-serializable Python dict. The base class has no configuration,
      so this returns an empty dict.
    """
    return {}

  @classmethod
  def from_config(cls, config):
    """Instantiates an initializer from a configuration dictionary.

    Example:

    ```python
    initializer = RandomUniform(-1, 1)
    config = initializer.get_config()
    initializer = RandomUniform.from_config(config)
    ```

    A `dtype` entry, if present, is discarded before the remaining entries are
    passed as keyword arguments to the constructor. The caller's `config` is
    left unmodified.

    Args:
      config: A Python dictionary, the output of `get_config`.

    Returns:
      A `tf.keras.initializers.Initializer` instance.
    """
    # Work on a shallow copy: popping from the caller's dict would silently
    # strip `dtype` from a config the caller may still be using.
    config = dict(config)
    config.pop('dtype', None)
    return cls(**config)
115
116
@keras_export('keras.initializers.Zeros', 'keras.initializers.zeros', v1=[])
class Zeros(Initializer):
  """Initializer that produces tensors filled with 0.

  Also available via the shortcut function `tf.keras.initializers.zeros`.

  Examples:

  >>> # Standalone usage:
  >>> initializer = tf.keras.initializers.Zeros()
  >>> values = initializer(shape=(2, 2))

  >>> # Usage in a Keras layer:
  >>> initializer = tf.keras.initializers.Zeros()
  >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)
  """

  def __call__(self, shape, dtype=None, **kwargs):
    """Builds a tensor of zeros.

    Args:
      shape: Shape of the tensor. When a `partition_shape` keyword argument is
        supplied, that shape is used instead.
      dtype: Optional dtype of the tensor. Only numeric or boolean dtypes are
        supported. Defaults to `tf.keras.backend.floatx()`, which is `float32`
        unless changed via `tf.keras.backend.set_floatx(float_dtype)`.
      **kwargs: Additional keyword arguments.

    Returns:
      A zero-filled tensor of the resolved shape and dtype.

    Raises:
      ValueError: If the resolved dtype is not numeric or boolean.
    """
    _validate_kwargs(type(self).__name__, kwargs)
    resolved = _get_dtype(dtype)
    if resolved == dtypes.string or not resolved.is_numpy_compatible:
      raise ValueError('Expected numeric or boolean dtype, got %s.' % resolved)
    return array_ops.zeros(kwargs.get(_PARTITION_SHAPE, shape), resolved)
152
153
@keras_export('keras.initializers.Ones', 'keras.initializers.ones', v1=[])
class Ones(Initializer):
  """Initializer that produces tensors filled with 1.

  Also available via the shortcut function `tf.keras.initializers.ones`.

  Examples:

  >>> # Standalone usage:
  >>> initializer = tf.keras.initializers.Ones()
  >>> values = initializer(shape=(2, 2))

  >>> # Usage in a Keras layer:
  >>> initializer = tf.keras.initializers.Ones()
  >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)
  """

  def __call__(self, shape, dtype=None, **kwargs):
    """Builds a tensor of ones.

    Args:
      shape: Shape of the tensor. When a `partition_shape` keyword argument is
        supplied, that shape is used instead.
      dtype: Optional dtype of the tensor. Only numeric or boolean dtypes are
        supported. Defaults to `tf.keras.backend.floatx()`, which is `float32`
        unless changed via `tf.keras.backend.set_floatx(float_dtype)`.
      **kwargs: Additional keyword arguments.

    Returns:
      A one-filled tensor of the resolved shape and dtype.

    Raises:
      ValueError: If the resolved dtype is not numeric or boolean.
    """
    _validate_kwargs(type(self).__name__, kwargs)
    resolved = _get_dtype(dtype)
    if resolved == dtypes.string or not resolved.is_numpy_compatible:
      raise ValueError('Expected numeric or boolean dtype, got %s.' % resolved)
    return array_ops.ones(kwargs.get(_PARTITION_SHAPE, shape), resolved)
189
190
@keras_export('keras.initializers.Constant',
              'keras.initializers.constant',
              v1=[])
class Constant(Initializer):
  """Initializer that generates tensors with constant values.

  Also available via the shortcut function `tf.keras.initializers.constant`.

  Only scalar values are allowed.
  The constant value provided must be convertible to the dtype requested
  when calling the initializer.

  Examples:

  >>> # Standalone usage:
  >>> initializer = tf.keras.initializers.Constant(3.)
  >>> values = initializer(shape=(2, 2))

  >>> # Usage in a Keras layer:
  >>> initializer = tf.keras.initializers.Constant(3.)
  >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)

  Args:
    value: A Python scalar.
  """

  def __init__(self, value=0):
    self.value = value

  def __call__(self, shape, dtype=None, **kwargs):
    """Returns a tensor object initialized to `self.value`.

    Args:
      shape: Shape of the tensor.
      dtype: Optional dtype of the tensor. If not specified,
        `tf.keras.backend.floatx()` is used, which defaults to `float32`
        unless you configured it otherwise
        (via `tf.keras.backend.set_floatx(float_dtype)`).
      **kwargs: Additional keyword arguments. If `partition_shape` is given,
        the returned tensor has that shape instead of `shape`; any other
        keyword arguments are ignored.

    Returns:
      A tensor filled with `self.value`.
    """
    # Like the other initializers in this module, build only the requested
    # shard when a partition shape is supplied; using the full `shape` here
    # would produce a tensor that does not match the partition.
    if _PARTITION_SHAPE in kwargs:
      shape = kwargs[_PARTITION_SHAPE]
    return constant_op.constant(
        self.value, dtype=_get_dtype(dtype), shape=shape)

  def get_config(self):
    """Returns the config dict: the constant `value`."""
    return {'value': self.value}
237
238
@keras_export('keras.initializers.RandomUniform',
              'keras.initializers.random_uniform',
              v1=[])
class RandomUniform(Initializer):
  """Initializer that draws tensor values from a uniform distribution.

  Also available via the shortcut function
  `tf.keras.initializers.random_uniform`.

  Examples:

  >>> # Standalone usage:
  >>> initializer = tf.keras.initializers.RandomUniform(minval=0., maxval=1.)
  >>> values = initializer(shape=(2, 2))

  >>> # Usage in a Keras layer:
  >>> initializer = tf.keras.initializers.RandomUniform(minval=0., maxval=1.)
  >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)

  Args:
    minval: A python scalar or a scalar tensor. Inclusive lower bound of the
      range of random values to generate.
    maxval: A python scalar or a scalar tensor. Exclusive upper bound of the
      range of random values to generate.
    seed: A Python integer. An initializer created with a given seed will
      always produce the same random tensor for a given shape and dtype.
  """

  def __init__(self, minval=-0.05, maxval=0.05, seed=None):
    self.minval, self.maxval, self.seed = minval, maxval, seed
    self._random_generator = _RandomGenerator(seed)

  def __call__(self, shape, dtype=None, **kwargs):
    """Draws a tensor of uniformly distributed values.

    Args:
      shape: Shape of the tensor. When a `partition_shape` keyword argument is
        supplied, that shape is used instead.
      dtype: Optional dtype of the tensor. Only floating point and integer
        types are supported. Defaults to `tf.keras.backend.floatx()`, which is
        `float32` unless changed via `tf.keras.backend.set_floatx(float_dtype)`.
      **kwargs: Additional keyword arguments.

    Raises:
      ValueError: If the resolved dtype is neither floating point nor integer.
    """
    _validate_kwargs(type(self).__name__, kwargs)
    target_dtype = _get_dtype(dtype)
    if not (target_dtype.is_floating or target_dtype.is_integer):
      raise ValueError(
          'Expected float or integer dtype, got %s.' % target_dtype)
    target_shape = kwargs.get(_PARTITION_SHAPE, shape)
    return self._random_generator.random_uniform(
        target_shape, self.minval, self.maxval, target_dtype)

  def get_config(self):
    """Returns the config dict: bounds and seed."""
    return dict(minval=self.minval, maxval=self.maxval, seed=self.seed)
300
301
@keras_export('keras.initializers.RandomNormal',
              'keras.initializers.random_normal',
              v1=[])
class RandomNormal(Initializer):
  """Initializer that draws tensor values from a normal distribution.

  Also available via the shortcut function
  `tf.keras.initializers.random_normal`.

  Examples:

  >>> # Standalone usage:
  >>> initializer = tf.keras.initializers.RandomNormal(mean=0., stddev=1.)
  >>> values = initializer(shape=(2, 2))

  >>> # Usage in a Keras layer:
  >>> initializer = tf.keras.initializers.RandomNormal(mean=0., stddev=1.)
  >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)

  Args:
    mean: a python scalar or a scalar tensor. Mean of the random values to
      generate.
    stddev: a python scalar or a scalar tensor. Standard deviation of the random
      values to generate.
    seed: A Python integer. An initializer created with a given seed will
      always produce the same random tensor for a given shape and dtype.
  """

  def __init__(self, mean=0.0, stddev=0.05, seed=None):
    self.mean, self.stddev, self.seed = mean, stddev, seed
    self._random_generator = _RandomGenerator(seed)

  def __call__(self, shape, dtype=None, **kwargs):
    """Draws a tensor of normally distributed values.

    Args:
      shape: Shape of the tensor. When a `partition_shape` keyword argument is
        supplied, that shape is used instead.
      dtype: Optional dtype of the tensor. Only floating point types are
        supported. Defaults to `tf.keras.backend.floatx()`, which is `float32`
        unless changed via `tf.keras.backend.set_floatx(float_dtype)`.
      **kwargs: Additional keyword arguments.
    """
    _validate_kwargs(type(self).__name__, kwargs)
    float_dtype = _assert_float_dtype(_get_dtype(dtype))
    target_shape = kwargs.get(_PARTITION_SHAPE, shape)
    return self._random_generator.random_normal(
        target_shape, self.mean, self.stddev, float_dtype)

  def get_config(self):
    """Returns the config dict: mean, stddev and seed."""
    return dict(mean=self.mean, stddev=self.stddev, seed=self.seed)
360
361
@keras_export('keras.initializers.TruncatedNormal',
              'keras.initializers.truncated_normal',
              v1=[])
class TruncatedNormal(Initializer):
  """Initializer that draws tensor values from a truncated normal distribution.

  Also available via the shortcut function
  `tf.keras.initializers.truncated_normal`.

  Values are generated like those of a `tf.keras.initializers.RandomNormal`
  initializer, except that any value farther than two standard deviations
  from the mean is discarded and re-drawn.

  Examples:

  >>> # Standalone usage:
  >>> initializer = tf.keras.initializers.TruncatedNormal(mean=0., stddev=1.)
  >>> values = initializer(shape=(2, 2))

  >>> # Usage in a Keras layer:
  >>> initializer = tf.keras.initializers.TruncatedNormal(mean=0., stddev=1.)
  >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)

  Args:
    mean: a python scalar or a scalar tensor. Mean of the random values
      to generate.
    stddev: a python scalar or a scalar tensor. Standard deviation of the
      random values to generate before truncation.
    seed: A Python integer. An initializer created with a given seed will
      always produce the same random tensor for a given shape and dtype.
  """

  def __init__(self, mean=0.0, stddev=0.05, seed=None):
    self.mean, self.stddev, self.seed = mean, stddev, seed
    self._random_generator = _RandomGenerator(seed)

  def __call__(self, shape, dtype=None, **kwargs):
    """Draws a tensor of truncated-normal values.

    Args:
      shape: Shape of the tensor. When a `partition_shape` keyword argument is
        supplied, that shape is used instead.
      dtype: Optional dtype of the tensor. Only floating point types are
        supported. Defaults to `tf.keras.backend.floatx()`, which is `float32`
        unless changed via `tf.keras.backend.set_floatx(float_dtype)`.
      **kwargs: Additional keyword arguments.
    """
    _validate_kwargs(type(self).__name__, kwargs)
    float_dtype = _assert_float_dtype(_get_dtype(dtype))
    target_shape = kwargs.get(_PARTITION_SHAPE, shape)
    return self._random_generator.truncated_normal(
        target_shape, self.mean, self.stddev, float_dtype)

  def get_config(self):
    """Returns the config dict: mean, stddev and seed."""
    return dict(mean=self.mean, stddev=self.stddev, seed=self.seed)
425
426
@keras_export('keras.initializers.VarianceScaling',
              'keras.initializers.variance_scaling',
              v1=[])
class VarianceScaling(Initializer):
  """Initializer capable of adapting its scale to the shape of weights tensors.

  Also available via the shortcut function
  `tf.keras.initializers.variance_scaling`.

  With `distribution="truncated_normal" or "untruncated_normal"`, samples are
  drawn from a truncated/untruncated normal distribution with a mean of zero and
  a standard deviation (after truncation, if used) `stddev = sqrt(scale / n)`,
  where `n` is:

  - number of input units in the weight tensor, if `mode="fan_in"`
  - number of output units, if `mode="fan_out"`
  - average of the numbers of input and output units, if `mode="fan_avg"`

  With `distribution="uniform"`, samples are drawn from a uniform distribution
  within `[-limit, limit]`, where `limit = sqrt(3 * scale / n)`.

  Examples:

  >>> # Standalone usage:
  >>> initializer = tf.keras.initializers.VarianceScaling(
  ... scale=0.1, mode='fan_in', distribution='uniform')
  >>> values = initializer(shape=(2, 2))

  >>> # Usage in a Keras layer:
  >>> initializer = tf.keras.initializers.VarianceScaling(
  ... scale=0.1, mode='fan_in', distribution='uniform')
  >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)

  Args:
    scale: Scaling factor (positive float).
    mode: One of "fan_in", "fan_out", "fan_avg".
    distribution: Random distribution to use. One of "truncated_normal",
      "untruncated_normal" and "uniform". "normal" is accepted as an alias
      for "truncated_normal". Matching is case-insensitive.
    seed: A Python integer. An initializer created with a given seed will
      always produce the same random tensor for a given shape and dtype.

  Raises:
    ValueError: If `scale` is not positive, or `mode` / `distribution` is not
      one of the accepted values.
  """

  def __init__(self,
               scale=1.0,
               mode='fan_in',
               distribution='truncated_normal',
               seed=None):
    if scale <= 0.:
      raise ValueError('`scale` must be positive float.')
    if mode not in {'fan_in', 'fan_out', 'fan_avg'}:
      # Format the value into the message: passing it as a second positional
      # argument to ValueError makes str(err) render as a tuple.
      raise ValueError(
          'Invalid `mode` argument: %r. Expected one of "fan_in", '
          '"fan_out", "fan_avg".' % (mode,))
    distribution = distribution.lower()
    # Compatibility with keras-team/keras.
    if distribution == 'normal':
      distribution = 'truncated_normal'
    if distribution not in {'uniform', 'truncated_normal',
                            'untruncated_normal'}:
      raise ValueError(
          'Invalid `distribution` argument: %r. Expected one of '
          '"truncated_normal", "untruncated_normal", "uniform".'
          % (distribution,))
    self.scale = scale
    self.mode = mode
    self.distribution = distribution
    self.seed = seed
    self._random_generator = _RandomGenerator(seed)

  def __call__(self, shape, dtype=None, **kwargs):
    """Returns a tensor object initialized as specified by the initializer.

    Args:
      shape: Shape of the tensor. Fans are always computed from this full
        shape, even when a `partition_shape` keyword argument requests a
        smaller shard.
      dtype: Optional dtype of the tensor. Only floating point types are
        supported. If not specified, `tf.keras.backend.floatx()` is used, which
        defaults to `float32` unless you configured it otherwise (via
        `tf.keras.backend.set_floatx(float_dtype)`)
      **kwargs: Additional keyword arguments.
    """
    _validate_kwargs(self.__class__.__name__, kwargs)
    dtype = _assert_float_dtype(_get_dtype(dtype))
    scale = self.scale
    # Fans come from the full variable shape, computed before the partition
    # override below, so every shard is scaled as part of the whole variable.
    fan_in, fan_out = _compute_fans(shape)
    if _PARTITION_SHAPE in kwargs:
      shape = kwargs[_PARTITION_SHAPE]
    # max(1., ...) keeps the divisor at least 1 (e.g. for zero-sized fans).
    if self.mode == 'fan_in':
      scale /= max(1., fan_in)
    elif self.mode == 'fan_out':
      scale /= max(1., fan_out)
    else:
      scale /= max(1., (fan_in + fan_out) / 2.)
    if self.distribution == 'truncated_normal':
      # constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
      # Dividing by it makes the post-truncation stddev equal sqrt(scale).
      stddev = math.sqrt(scale) / .87962566103423978
      return self._random_generator.truncated_normal(shape, 0.0, stddev, dtype)
    elif self.distribution == 'untruncated_normal':
      stddev = math.sqrt(scale)
      return self._random_generator.random_normal(shape, 0.0, stddev, dtype)
    else:
      # A uniform on [-limit, limit] has variance limit**2 / 3 == scale.
      limit = math.sqrt(3.0 * scale)
      return self._random_generator.random_uniform(shape, -limit, limit, dtype)

  def get_config(self):
    """Returns the config dict: scale, mode, distribution and seed."""
    return {
        'scale': self.scale,
        'mode': self.mode,
        'distribution': self.distribution,
        'seed': self.seed
    }
532
533
@keras_export('keras.initializers.Orthogonal',
              'keras.initializers.orthogonal',
              v1=[])
class Orthogonal(Initializer):
  """Initializer that produces an orthogonal matrix.

  Also available via the shortcut function `tf.keras.initializers.orthogonal`.

  For a two-dimensional shape, the result is the orthogonal factor of the QR
  decomposition of a matrix of normally distributed random numbers. When the
  matrix has fewer rows than columns its rows are orthogonal; otherwise its
  columns are.

  For a shape with more than two dimensions, a matrix of shape
  `(shape[0] * ... * shape[n - 2], shape[n - 1])` is initialized instead
  (`n` being the length of the shape vector) and then reshaped to the
  requested shape.

  Examples:

  >>> # Standalone usage:
  >>> initializer = tf.keras.initializers.Orthogonal()
  >>> values = initializer(shape=(2, 2))

  >>> # Usage in a Keras layer:
  >>> initializer = tf.keras.initializers.Orthogonal()
  >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)

  Args:
    gain: multiplicative factor to apply to the orthogonal matrix
    seed: A Python integer. An initializer created with a given seed will
      always produce the same random tensor for a given shape and dtype.

  References:
      [Saxe et al., 2014](https://openreview.net/forum?id=_wzZwKpTDF_9C)
      ([pdf](https://arxiv.org/pdf/1312.6120.pdf))
  """

  def __init__(self, gain=1.0, seed=None):
    self.gain, self.seed = gain, seed
    self._random_generator = _RandomGenerator(seed)

  def __call__(self, shape, dtype=None, **kwargs):
    """Builds a (possibly reshaped) orthogonal matrix scaled by `self.gain`.

    Args:
      shape: Shape of the tensor; must have at least two dimensions.
      dtype: Optional dtype of the tensor. Only floating point types are
        supported. Defaults to `tf.keras.backend.floatx()`, which is `float32`
        unless changed via `tf.keras.backend.set_floatx(float_dtype)`.
      **kwargs: Additional keyword arguments. Partition-related arguments are
        not supported by this initializer.

    Raises:
      ValueError: If `shape` has fewer than two dimensions.
    """
    _validate_kwargs(type(self).__name__, kwargs, support_partition=False)
    dtype = _assert_float_dtype(_get_dtype(dtype))
    if len(shape) < 2:
      raise ValueError('The tensor to initialize must be '
                       'at least two-dimensional')
    # Collapse every leading dimension into rows and keep the last one as
    # columns, so that e.g. conv kernels are handled as a 2-D matrix.
    rows = 1
    for size in shape[:-1]:
      rows *= size
    cols = shape[-1]
    wide = rows < cols
    # QR is taken on the "tall" orientation; a wide result is transposed back.
    flat_shape = (cols, rows) if wide else (rows, cols)

    gaussian = self._random_generator.random_normal(flat_shape, dtype=dtype)
    q, r = gen_linalg_ops.qr(gaussian, full_matrices=False)
    # Scale Q's columns by the signs of R's diagonal to make Q uniform.
    q *= math_ops.sign(array_ops.tensor_diag_part(r))
    if wide:
      q = array_ops.matrix_transpose(q)
    return self.gain * array_ops.reshape(q, shape)

  def get_config(self):
    """Returns the config dict: gain and seed."""
    return dict(gain=self.gain, seed=self.seed)
616
617
@keras_export('keras.initializers.Identity',
              'keras.initializers.identity',
              v1=[])
class Identity(Initializer):
  """Initializer that produces an identity matrix.

  Also available via the shortcut function `tf.keras.initializers.identity`.

  Only 2D matrices can be generated.

  Examples:

  >>> # Standalone usage:
  >>> initializer = tf.keras.initializers.Identity()
  >>> values = initializer(shape=(2, 2))

  >>> # Usage in a Keras layer:
  >>> initializer = tf.keras.initializers.Identity()
  >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)

  Args:
    gain: Multiplicative factor to apply to the identity matrix.
  """

  def __init__(self, gain=1.0):
    self.gain = gain

  def __call__(self, shape, dtype=None, **kwargs):
    """Builds a 2D identity matrix scaled by `self.gain`.

    Args:
      shape: Shape of the tensor. It must have exactly rank 2.
      dtype: Optional dtype of the tensor. Only floating point types are
        supported. Defaults to `tf.keras.backend.floatx()`, which is `float32`
        unless changed via `tf.keras.backend.set_floatx(float_dtype)`.
      **kwargs: Additional keyword arguments. Partition-related arguments are
        not supported by this initializer.

    Raises:
      ValueError: If `shape` does not have exactly two dimensions.
    """
    _validate_kwargs(type(self).__name__, kwargs, support_partition=False)
    dtype = _assert_float_dtype(_get_dtype(dtype))
    if len(shape) != 2:
      raise ValueError(
          'Identity matrix initializer can only be used for 2D matrices.')
    num_rows, num_cols = shape
    return self.gain * linalg_ops.eye(num_rows, num_cols, dtype=dtype)

  def get_config(self):
    """Returns the config dict: gain."""
    return dict(gain=self.gain)
666
667
@keras_export('keras.initializers.GlorotUniform',
              'keras.initializers.glorot_uniform',
              v1=[])
class GlorotUniform(VarianceScaling):
  """Glorot uniform initializer (also known as the Xavier uniform initializer).

  Also available via the shortcut function
  `tf.keras.initializers.glorot_uniform`.

  Samples are drawn from a uniform distribution on `[-limit, limit]` with
  `limit = sqrt(6 / (fan_in + fan_out))`, where `fan_in` and `fan_out` are the
  numbers of input and output units of the weight tensor.

  Examples:

  >>> # Standalone usage:
  >>> initializer = tf.keras.initializers.GlorotUniform()
  >>> values = initializer(shape=(2, 2))

  >>> # Usage in a Keras layer:
  >>> initializer = tf.keras.initializers.GlorotUniform()
  >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)

  Args:
    seed: A Python integer. An initializer created with a given seed will
      always produce the same random tensor for a given shape and dtype.

  References:
      [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html)
      ([pdf](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf))
  """

  def __init__(self, seed=None):
    # VarianceScaling with scale 1 over the averaged fans, uniform sampling.
    super(GlorotUniform, self).__init__(
        distribution='uniform', mode='fan_avg', scale=1.0, seed=seed)

  def get_config(self):
    """Returns the config dict; only the seed is configurable."""
    return dict(seed=self.seed)
709
710
@keras_export('keras.initializers.GlorotNormal',
              'keras.initializers.glorot_normal',
              v1=[])
class GlorotNormal(VarianceScaling):
  """Glorot normal initializer (also known as the Xavier normal initializer).

  Also available via the shortcut function
  `tf.keras.initializers.glorot_normal`.

  Samples are drawn from a truncated normal distribution centered on 0 with
  `stddev = sqrt(2 / (fan_in + fan_out))`, where `fan_in` and `fan_out` are the
  numbers of input and output units of the weight tensor.

  Examples:

  >>> # Standalone usage:
  >>> initializer = tf.keras.initializers.GlorotNormal()
  >>> values = initializer(shape=(2, 2))

  >>> # Usage in a Keras layer:
  >>> initializer = tf.keras.initializers.GlorotNormal()
  >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)

  Args:
    seed: A Python integer. An initializer created with a given seed will
      always produce the same random tensor for a given shape and dtype.

  References:
      [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html)
      ([pdf](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf))
  """

  def __init__(self, seed=None):
    # VarianceScaling with scale 1 over the averaged fans, truncated normal.
    super(GlorotNormal, self).__init__(
        distribution='truncated_normal', mode='fan_avg', scale=1.0, seed=seed)

  def get_config(self):
    """Returns the config dict; only the seed is configurable."""
    return dict(seed=self.seed)
753
754
@keras_export('keras.initializers.LecunNormal',
              'keras.initializers.lecun_normal',
              v1=[])
class LecunNormal(VarianceScaling):
  """Lecun normal initializer.

  Also available via the shortcut function
  `tf.keras.initializers.lecun_normal`.

  Initializers let you pre-specify an initialization strategy, encoded in the
  Initializer object, without knowing the shape and dtype of the variable
  being initialized.

  Samples are drawn from a truncated normal distribution centered on 0 with
  `stddev = sqrt(1 / fan_in)`, where `fan_in` is the number of input units in
  the weight tensor.

  Examples:

  >>> # Standalone usage:
  >>> initializer = tf.keras.initializers.LecunNormal()
  >>> values = initializer(shape=(2, 2))

  >>> # Usage in a Keras layer:
  >>> initializer = tf.keras.initializers.LecunNormal()
  >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)

  Args:
    seed: A Python integer. Used to seed the random generator.

  References:
      - Self-Normalizing Neural Networks,
      [Klambauer et al., 2017]
      (https://papers.nips.cc/paper/6698-self-normalizing-neural-networks)
      ([pdf]
      (https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf))
      - Efficient Backprop,
      [Lecun et al., 1998](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
  """

  def __init__(self, seed=None):
    # VarianceScaling with scale 1 over fan_in, truncated normal.
    super(LecunNormal, self).__init__(
        distribution='truncated_normal', mode='fan_in', scale=1., seed=seed)

  def get_config(self):
    """Returns the config dict; only the seed is configurable."""
    return dict(seed=self.seed)
801
802
@keras_export('keras.initializers.LecunUniform',
              'keras.initializers.lecun_uniform',
              v1=[])
class LecunUniform(VarianceScaling):
  """Lecun uniform initializer.

  Also available via the shortcut function
  `tf.keras.initializers.lecun_uniform`.

  Samples are drawn from a uniform distribution on `[-limit, limit]` with
  `limit = sqrt(3 / fan_in)`, where `fan_in` is the number of input units in
  the weight tensor.

  Examples:

  >>> # Standalone usage:
  >>> initializer = tf.keras.initializers.LecunUniform()
  >>> values = initializer(shape=(2, 2))

  >>> # Usage in a Keras layer:
  >>> initializer = tf.keras.initializers.LecunUniform()
  >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)

  Args:
    seed: A Python integer. An initializer created with a given seed will
      always produce the same random tensor for a given shape and dtype.

  References:
      - Self-Normalizing Neural Networks,
      [Klambauer et al., 2017](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks) # pylint: disable=line-too-long
      ([pdf](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf))
      - Efficient Backprop,
      [Lecun et al., 1998](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
  """

  def __init__(self, seed=None):
    # VarianceScaling with scale 1 over fan_in, uniform sampling.
    super(LecunUniform, self).__init__(
        distribution='uniform', mode='fan_in', scale=1., seed=seed)

  def get_config(self):
    """Returns the config dict; only the seed is configurable."""
    return dict(seed=self.seed)
844
845
@keras_export('keras.initializers.HeNormal',
              'keras.initializers.he_normal',
              v1=[])
class HeNormal(VarianceScaling):
  """He normal initializer.

  Also available via the shortcut function
  `tf.keras.initializers.he_normal`.

  Samples are drawn from a truncated normal distribution centered on 0 with
  `stddev = sqrt(2 / fan_in)`, where `fan_in` is the number of input units in
  the weight tensor.

  Examples:

  >>> # Standalone usage:
  >>> initializer = tf.keras.initializers.HeNormal()
  >>> values = initializer(shape=(2, 2))

  >>> # Usage in a Keras layer:
  >>> initializer = tf.keras.initializers.HeNormal()
  >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)

  Args:
    seed: A Python integer. An initializer created with a given seed will
      always produce the same random tensor for a given shape and dtype.

  References:
      [He et al., 2015](https://www.cv-foundation.org/openaccess/content_iccv_2015/html/He_Delving_Deep_into_ICCV_2015_paper.html) # pylint: disable=line-too-long
      ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf))
  """

  def __init__(self, seed=None):
    # A scale of 2 over fan_in with a truncated normal is the He normal
    # recipe; only the seed is left to the caller.
    super(HeNormal, self).__init__(
        seed=seed, scale=2., mode='fan_in', distribution='truncated_normal')

  def get_config(self):
    # scale/mode/distribution are constants of this class, so the seed is
    # the whole config.
    return dict(seed=self.seed)
884
885
@keras_export('keras.initializers.HeUniform',
              'keras.initializers.he_uniform',
              v1=[])
class HeUniform(VarianceScaling):
  """He uniform variance scaling initializer.

  Also available via the shortcut function
  `tf.keras.initializers.he_uniform`.

  Samples are drawn from a uniform distribution over `[-limit, limit]`, with
  `limit = sqrt(6 / fan_in)`, where `fan_in` is the number of input units in
  the weight tensor.

  Examples:

  >>> # Standalone usage:
  >>> initializer = tf.keras.initializers.HeUniform()
  >>> values = initializer(shape=(2, 2))

  >>> # Usage in a Keras layer:
  >>> initializer = tf.keras.initializers.HeUniform()
  >>> layer = tf.keras.layers.Dense(3, kernel_initializer=initializer)

  Args:
    seed: A Python integer. An initializer created with a given seed will
      always produce the same random tensor for a given shape and dtype.

  References:
      [He et al., 2015](https://www.cv-foundation.org/openaccess/content_iccv_2015/html/He_Delving_Deep_into_ICCV_2015_paper.html) # pylint: disable=line-too-long
      ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf))
  """

  def __init__(self, seed=None):
    # A scale of 2 over fan_in with a uniform distribution is the He uniform
    # recipe; only the seed is configurable.
    super(HeUniform, self).__init__(
        seed=seed, scale=2., mode='fan_in', distribution='uniform')

  def get_config(self):
    # The remaining VarianceScaling arguments are fixed by __init__, so the
    # seed alone reconstructs this initializer.
    return dict(seed=self.seed)
924
925
def _get_dtype(dtype):
  """Converts `dtype` to a `DType`, defaulting to `backend.floatx()`.

  Args:
    dtype: Anything accepted by `dtypes.as_dtype`, or `None`.

  Returns:
    The `DType` for `dtype`; when `dtype` is `None`, the `DType` for the
    Keras backend's configured float type.
  """
  requested = backend.floatx() if dtype is None else dtype
  return dtypes.as_dtype(requested)
930
931
def _assert_float_dtype(dtype):
  """Converts `dtype` to a `DType` and checks that it is floating point.

  Args:
    dtype: The data type to validate; anything `dtypes.as_dtype` accepts.

  Returns:
    The converted `DType`.

  Raises:
    ValueError: if the converted type is not a floating point type.
  """
  checked = dtypes.as_dtype(dtype)
  if checked.is_floating:
    return checked
  raise ValueError('Expected floating point type, got %s.' % checked)
950
951
class _RandomGenerator(object):
  """Dispatches to stateless (seeded) or stateful (unseeded) random ops.

  When constructed with a seed, every sampling method calls the matching
  `stateless_random_ops` function with that seed, so repeated calls with the
  same arguments are deterministic. Without a seed, the stateful
  `random_ops` functions are called with `seed=None`.
  """

  def __init__(self, seed=None):
    super(_RandomGenerator, self).__init__()
    # Stateless random ops requires 2-int seed, so pad the user's int with 0.
    self.seed = None if seed is None else [seed, 0]

  def random_normal(self, shape, mean=0.0, stddev=1, dtype=dtypes.float32):
    """A deterministic random normal if seed is passed."""
    sampler = (stateless_random_ops.stateless_random_normal if self.seed
               else random_ops.random_normal)
    return sampler(
        shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=self.seed)

  def random_uniform(self, shape, minval, maxval, dtype):
    """A deterministic random uniform if seed is passed."""
    sampler = (stateless_random_ops.stateless_random_uniform if self.seed
               else random_ops.random_uniform)
    return sampler(
        shape=shape, minval=minval, maxval=maxval, dtype=dtype, seed=self.seed)

  def truncated_normal(self, shape, mean, stddev, dtype):
    """A deterministic truncated normal if seed is passed."""
    sampler = (stateless_random_ops.stateless_truncated_normal if self.seed
               else random_ops.truncated_normal)
    return sampler(
        shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=self.seed)
989
990
def _compute_fans(shape):
  """Computes the number of input and output units for a weight shape.

  Rank-0 shapes get fans of 1. Rank-1 shapes use their only dimension for
  both fans. Rank-2 shapes are read as `(fan_in, fan_out)`. Higher-rank shapes
  are treated as convolution kernels laid out as `(..., input_depth, depth)`:
  the product of the leading dimensions (the receptive field size) multiplies
  both the input and the output depth.

  Args:
    shape: Integer shape tuple or TF tensor shape.

  Returns:
    A tuple of integer scalars (fan_in, fan_out).
  """
  rank = len(shape)
  if rank < 1:
    # Just to avoid errors for constants.
    return 1, 1
  if rank == 1:
    return int(shape[0]), int(shape[0])
  if rank == 2:
    return int(shape[0]), int(shape[1])
  receptive_field_size = 1
  for extent in shape[:-2]:
    receptive_field_size *= extent
  return (int(shape[-2] * receptive_field_size),
          int(shape[-1] * receptive_field_size))
1016
1017
def _validate_kwargs(cls_name, kwargs, support_partition=True):
  """Rejects keyword arguments an initializer call does not understand.

  Only the partition keys (`_PARTITION_SHAPE`, `_PARTITION_OFFSET`) are
  recognized, and they are accepted only when `support_partition` is True.

  Args:
    cls_name: Initializer class name, used in the error message.
    kwargs: The extra keyword arguments passed to the initializer call.
    support_partition: Whether the initializer accepts the partition keys.

  Raises:
    TypeError: on the first key that is not a partition key.
    ValueError: on the first partition key when partitions are unsupported.
  """
  for key in kwargs:
    if key in (_PARTITION_SHAPE, _PARTITION_OFFSET):
      if support_partition:
        continue
      raise ValueError('%s initializer doesn\'t support partition-related '
                       'arguments' % cls_name)
    raise TypeError('Unknown keyword arguments: %s' % key)
1025