# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations often used for initializing tensors.

All variable initializers returned by functions in this file should have the
following signature:

def _initializer(shape, dtype=dtypes.float32, partition_info=None):
  Args:
    shape: List of `int` representing the shape of the output `Tensor`. Some
      initializers may also be able to accept a `Tensor`.
    dtype: (Optional) Type of the output `Tensor`.
    partition_info: (Optional) variable_scope._PartitionInfo object holding
      additional information about how the variable is partitioned. May be
      `None` if the variable is not partitioned.

  Returns:
    A `Tensor` of type `dtype` and `shape`.
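
For illustration, a minimal function conforming to this signature (a sketch,
not one of the initializers defined in this file) could look like:

def _ones_initializer(shape, dtype=dtypes.float32, partition_info=None):
  del partition_info  # Unused; this sketch is not partition-aware.
  return array_ops.ones(shape, dtype)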
31"""
32from __future__ import absolute_import
33from __future__ import division
34from __future__ import print_function
35
36import math
37
38import numpy as np
39
40from tensorflow.python.framework import constant_op
41from tensorflow.python.framework import dtypes
42from tensorflow.python.framework import tensor_shape
43from tensorflow.python.ops import array_ops
44from tensorflow.python.ops import gen_linalg_ops
45from tensorflow.python.ops import linalg_ops_impl
46from tensorflow.python.ops import math_ops
47from tensorflow.python.ops import random_ops
48from tensorflow.python.util import deprecation
49from tensorflow.python.util.deprecation import deprecated
50from tensorflow.python.util.deprecation import deprecated_arg_values
51from tensorflow.python.util.deprecation import deprecated_args
52from tensorflow.python.util.tf_export import tf_export
53
54
55class Initializer(object):
56  """Initializer base class: all initializers inherit from this class."""
57
58  def __call__(self, shape, dtype=None, partition_info=None):
59    """Returns a tensor object initialized as specified by the initializer.
60
61    Args:
62      shape: Shape of the tensor.
63      dtype: Optional dtype of the tensor. If not provided use the initializer
64        dtype.
65      partition_info: Optional information about the possible partitioning of a
66        tensor.
67    """
68    raise NotImplementedError
69
70  def get_config(self):
71    """Returns the configuration of the initializer as a JSON-serializable dict.
72
73    Returns:
74      A JSON-serializable Python dict.
75    """
76    return {}
77
78  @classmethod
79  def from_config(cls, config):
80    """Instantiates an initializer from a configuration dictionary.
81
82    Example:
83
84    ```python
85    initializer = RandomUniform(-1, 1)
86    config = initializer.get_config()
87    initializer = RandomUniform.from_config(config)
88    ```
89
90    Args:
91      config: A Python dictionary. It will typically be the output of
92        `get_config`.
93
94    Returns:
95      An Initializer instance.
96    """
97    return cls(**config)
98
99
100@tf_export(v1=["initializers.zeros", "zeros_initializer"])
101@deprecation.deprecated_endpoints("initializers.zeros")
102class Zeros(Initializer):
103  """Initializer that generates tensors initialized to 0."""
104
105  @deprecated_args(None,
106                   "Call initializer instance with the dtype argument instead "
107                   "of passing it to the constructor", "dtype")
108  def __init__(self, dtype=dtypes.float32):
109    self.dtype = dtypes.as_dtype(dtype)
110
111  def __call__(self, shape, dtype=None, partition_info=None):
112    if dtype is None:
113      dtype = self.dtype
114    return array_ops.zeros(shape, dtype)
115
116  def get_config(self):
117    return {"dtype": self.dtype.name}
118
119
120@tf_export(v1=["initializers.ones", "ones_initializer"])
121@deprecation.deprecated_endpoints("initializers.ones", "ones_initializer")
122class Ones(Initializer):
123  """Initializer that generates tensors initialized to 1."""
124
125  @deprecated_args(None,
126                   "Call initializer instance with the dtype argument instead "
127                   "of passing it to the constructor", "dtype")
128  def __init__(self, dtype=dtypes.float32):
129    self.dtype = dtypes.as_dtype(dtype)
130
131  def __call__(self, shape, dtype=None, partition_info=None):
132    if dtype is None:
133      dtype = self.dtype
134    return array_ops.ones(shape, dtype)
135
136  def get_config(self):
137    return {"dtype": self.dtype.name}
138
139
140@tf_export(v1=["initializers.constant", "constant_initializer"])
141@deprecation.deprecated_endpoints("constant_initializer")
142class Constant(Initializer):
143  """Initializer that generates tensors with constant values.
144
145  The resulting tensor is populated with values of type `dtype`, as
146  specified by arguments `value` following the desired `shape` of the
147  new tensor (see examples below).
148
149  The argument `value` can be a constant value, or a list of values of type
150  `dtype`. If `value` is a list, then the length of the list must be less
151  than or equal to the number of elements implied by the desired shape of the
152  tensor. In the case where the total number of elements in `value` is less
153  than the number of elements required by the tensor shape, the last element
154  in `value` will be used to fill the remaining entries. If the total number of
155  elements in `value` is greater than the number of elements required by the
156  tensor shape, the initializer will raise a `ValueError`.
157
158  Args:
159    value: A Python scalar, list or tuple of values, or a N-dimensional numpy
160      array. All elements of the initialized variable will be set to the
161      corresponding value in the `value` argument.
162    dtype: Default data type, used if no `dtype` argument is provided when
163      calling the initializer.
164    verify_shape: Boolean that enables verification of the shape of `value`. If
165      `True`, the initializer will throw an error if the shape of `value` is not
166      compatible with the shape of the initialized tensor.
167
168  Raises:
169    TypeError: If the input `value` is not one of the expected types.
170
171  Examples:
172    The following example can be rewritten using a numpy.ndarray instead
173    of the `value` list, even reshaped, as shown in the two commented lines
174    below the `value` list initialization.
175

  >>> value = [0, 1, 2, 3, 4, 5, 6, 7]
  >>> init = tf.compat.v1.constant_initializer(value)
  >>> # Fitting shape
  >>> with tf.compat.v1.Session():
  ...   x = tf.compat.v1.get_variable('x', shape=[2, 4], initializer=init)
  ...   x.initializer.run()
  ...   print(x.eval())
  [[0. 1. 2. 3.]
   [4. 5. 6. 7.]]
  >>> # Larger shape
  >>> with tf.compat.v1.Session():
  ...   y = tf.compat.v1.get_variable('y', shape=[3, 4], initializer=init)
  ...   y.initializer.run()
  ...   print(y.eval())
  [[0.  1.  2.  3.]
   [4.  5.  6.  7.]
   [7.  7.  7.  7.]]
  >>> # Smaller shape
  >>> with tf.compat.v1.Session():
  ...   z = tf.compat.v1.get_variable('z', shape=[2, 3], initializer=init)
  Traceback (most recent call last):
  ...
  ValueError: Too many elements provided. Needed at most 6, but received 8
  >>> # Shape verification
  >>> init_verify = tf.compat.v1.constant_initializer(value, verify_shape=True)
  >>> with tf.compat.v1.Session():
  ...  u = tf.compat.v1.get_variable('u', shape=[3, 4],
  ...                                initializer=init_verify)
  Traceback (most recent call last):
  ...
  TypeError: Expected Tensor's shape: (3, 4), got (8,).
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor", "dtype")
  @deprecated_args(None, "Objects must now be the required shape or no shape "
                   "can be specified", "verify_shape")
  def __init__(self, value=0, dtype=dtypes.float32, verify_shape=False):
    if not (np.isscalar(value) or isinstance(value, (list, tuple, np.ndarray))):
      raise TypeError(
          "Invalid type for initial value: %s (expected Python scalar, list or "
          "tuple of values, or numpy.ndarray)." % type(value))

    self.value = value
    self.dtype = dtypes.as_dtype(dtype)
    self._verify_shape = verify_shape

  def __call__(self, shape, dtype=None, partition_info=None, verify_shape=None):
    if dtype is None:
      dtype = self.dtype
    if verify_shape is None:
      verify_shape = self._verify_shape
    return constant_op.constant_v1(
        self.value, dtype=dtype, shape=shape, verify_shape=verify_shape)

  def get_config(self):
    # We don't include `verify_shape` for compatibility with Keras.
    # `verify_shape` should be passed as an argument to `__call__` rather
    # than as a constructor argument: conceptually it isn't a property
    # of the initializer.
    return {"value": self.value, "dtype": self.dtype.name}

@tf_export(v1=["initializers.random_uniform", "random_uniform_initializer"])
@deprecation.deprecated_endpoints("initializers.random_uniform")
class RandomUniform(Initializer):
  """Initializer that generates tensors with a uniform distribution.

  Args:
    minval: A python scalar or a scalar tensor. Lower bound of the range of
      random values to generate.
    maxval: A python scalar or a scalar tensor. Upper bound of the range of
      random values to generate.  Defaults to 1 for float types.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer.
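
  For illustration, a minimal usage sketch (the variable name, shape and seed
  below are arbitrary):

  ```python
  init = tf.compat.v1.random_uniform_initializer(minval=-1., maxval=1., seed=42)
  # Calling the initializer instance directly returns a tensor of values.
  values = init([3, 2])
  # Or attach it to a variable.
  v = tf.compat.v1.get_variable("w", shape=[3, 2], initializer=init)
  ```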
254  """
255
256  @deprecated_args(None,
257                   "Call initializer instance with the dtype argument instead "
258                   "of passing it to the constructor", "dtype")
259  def __init__(self, minval=0, maxval=None, seed=None, dtype=dtypes.float32):
260    self.minval = minval
261    self.maxval = maxval
262    self.seed = seed
263    self.dtype = dtypes.as_dtype(dtype)
264
265  def __call__(self, shape, dtype=None, partition_info=None):
266    if dtype is None:
267      dtype = self.dtype
268    return random_ops.random_uniform(
269        shape, self.minval, self.maxval, dtype, seed=self.seed)
270
271  def get_config(self):
272    return {
273        "minval": self.minval,
274        "maxval": self.maxval,
275        "seed": self.seed,
276        "dtype": self.dtype.name
277    }
278
279
280@tf_export(v1=["initializers.random_normal", "random_normal_initializer"])
281@deprecation.deprecated_endpoints("initializers.random_normal")
282class RandomNormal(Initializer):
283  """Initializer that generates tensors with a normal distribution.
284
285  Args:
286    mean: a python scalar or a scalar tensor. Mean of the random values to
287      generate.
288    stddev: a python scalar or a scalar tensor. Standard deviation of the random
289      values to generate.
290    seed: A Python integer. Used to create random seeds. See
291      `tf.compat.v1.set_random_seed` for behavior.
292    dtype: Default data type, used if no `dtype` argument is provided when
293      calling the initializer. Only floating point types are supported.
294  """
295
296  @deprecated_args(None,
297                   "Call initializer instance with the dtype argument instead "
298                   "of passing it to the constructor", "dtype")
299  def __init__(self, mean=0.0, stddev=1.0, seed=None, dtype=dtypes.float32):
300    self.mean = mean
301    self.stddev = stddev
302    self.seed = seed
303    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
304
305  def __call__(self, shape, dtype=None, partition_info=None):
306    if dtype is None:
307      dtype = self.dtype
308    return random_ops.random_normal(
309        shape, self.mean, self.stddev, dtype, seed=self.seed)
310
311  def get_config(self):
312    return {
313        "mean": self.mean,
314        "stddev": self.stddev,
315        "seed": self.seed,
316        "dtype": self.dtype.name
317    }
318
319
320@tf_export(v1=["initializers.truncated_normal", "truncated_normal_initializer"])
321@deprecation.deprecated_endpoints("initializers.truncated_normal",
322                                  "truncated_normal_initializer")
323class TruncatedNormal(Initializer):
324  """Initializer that generates a truncated normal distribution.
325
326  These values are similar to values from a `random_normal_initializer`
327  except that values more than two standard deviations from the mean
328  are discarded and re-drawn. This is the recommended initializer for
329  neural network weights and filters.
330
331  Args:
332    mean: a python scalar or a scalar tensor. Mean of the random values to
333      generate.
334    stddev: a python scalar or a scalar tensor. Standard deviation of the random
335      values to generate.
336    seed: A Python integer. Used to create random seeds. See
337      `tf.compat.v1.set_random_seed` for behavior.
338    dtype: Default data type, used if no `dtype` argument is provided when
339      calling the initializer. Only floating point types are supported.
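
  For illustration, a sketch of the redraw guarantee (shape and seed are
  arbitrary): every generated value lies strictly within two standard
  deviations of the mean.

  ```python
  init = tf.compat.v1.truncated_normal_initializer(mean=0., stddev=1., seed=7)
  samples = init([10000])  # all entries fall in (-2, 2)
  ```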
340  """
341
342  @deprecated_args(None,
343                   "Call initializer instance with the dtype argument instead "
344                   "of passing it to the constructor", "dtype")
345  def __init__(self, mean=0.0, stddev=1.0, seed=None, dtype=dtypes.float32):
346    self.mean = mean
347    self.stddev = stddev
348    self.seed = seed
349    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
350
351  def __call__(self, shape, dtype=None, partition_info=None):
352    if dtype is None:
353      dtype = self.dtype
354    return random_ops.truncated_normal(
355        shape, self.mean, self.stddev, dtype, seed=self.seed)
356
357  def get_config(self):
358    return {
359        "mean": self.mean,
360        "stddev": self.stddev,
361        "seed": self.seed,
362        "dtype": self.dtype.name
363    }
364
365
366@tf_export(v1=[
367    "initializers.uniform_unit_scaling", "uniform_unit_scaling_initializer"
368])
369@deprecation.deprecated_endpoints("uniform_unit_scaling_initializer",
370                                  "initializers.uniform_unit_scaling")
371class UniformUnitScaling(Initializer):
372  """Initializer that generates tensors without scaling variance.
373
  When initializing a deep network, it is in principle advantageous to keep
  the scale of the input variance constant, so that it does not explode or
  diminish by the time it reaches the final layer. If the input is `x` and the
  operation is `x * W`, and we want to initialize `W` uniformly at random, we
  need to pick `W` from

      [-sqrt(3) / sqrt(dim), sqrt(3) / sqrt(dim)]

  to keep the scale intact, where `dim = W.shape[0]` (the size of the input).
  A similar calculation for convolutional networks gives an analogous result
  with `dim` equal to the product of the first 3 dimensions.  When
  nonlinearities are present, we need to multiply this by a constant `factor`.
  See (Sussillo et al., 2014) for deeper motivation, experiments
  and the calculation of constants. In section 2.3 there, the constants were
  numerically computed: for a linear layer it's 1.0, relu: ~1.43, tanh: ~1.15.

  Args:
    factor: Float.  A multiplicative factor by which the values will be scaled.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.
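
  For example (an illustrative calculation of the bound above): with
  `factor=1.15` (tanh) and an input dimension `dim = 256`, values are drawn
  uniformly from roughly [-0.124, 0.124]:

  ```python
  import math
  max_val = math.sqrt(3. / 256.) * 1.15  # ~0.124
  ```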
  References:
      [Sussillo et al., 2014](https://arxiv.org/abs/1412.6558)
      ([pdf](http://arxiv.org/pdf/1412.6558.pdf))
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor", "dtype")
  @deprecated(None,
              "Use tf.initializers.variance_scaling instead with distribution="
              "uniform to get equivalent behavior.")
  def __init__(self, factor=1.0, seed=None, dtype=dtypes.float32):
    self.factor = factor
    self.seed = seed
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))

  def __call__(self, shape, dtype=None, partition_info=None):
    if dtype is None:
      dtype = self.dtype
    scale_shape = shape
    if partition_info is not None:
      scale_shape = partition_info.full_shape

    input_size = 1.0
    # It is not possible to estimate the input size perfectly, but we try.
    # The estimate, obtained by multiplying all dimensions but the last one,
    # is the right thing for matrix multiply and convolutions (see above).
    for dim in scale_shape[:-1]:
      input_size *= float(dim)
    # Avoid errors when initializing zero-size tensors.
    input_size = max(input_size, 1.0)
    max_val = math.sqrt(3 / input_size) * self.factor
    return random_ops.random_uniform(
        shape, -max_val, max_val, dtype, seed=self.seed)

  def get_config(self):
    return {"factor": self.factor, "seed": self.seed, "dtype": self.dtype.name}


@tf_export(v1=["initializers.variance_scaling", "variance_scaling_initializer"])
@deprecation.deprecated_endpoints("initializers.variance_scaling",
                                  "variance_scaling_initializer")
class VarianceScaling(Initializer):
  """Initializer capable of adapting its scale to the shape of weights tensors.

  With `distribution="truncated_normal"` or `"untruncated_normal"`, samples are
  drawn from a truncated/untruncated normal distribution with a mean of zero
  and a standard deviation (after truncation, if used) of
  `stddev = sqrt(scale / n)`, where n is:
    - number of input units in the weight tensor, if mode = "fan_in"
    - number of output units, if mode = "fan_out"
    - average of the numbers of input and output units, if mode = "fan_avg"

  With `distribution="uniform"`, samples are drawn from a uniform distribution
  within [-limit, limit], with `limit = sqrt(3 * scale / n)`.

  Args:
    scale: Scaling factor (positive float).
    mode: One of "fan_in", "fan_out", "fan_avg".
    distribution: Random distribution to use. One of "truncated_normal",
      "untruncated_normal" or "uniform" ("normal" is a deprecated alias for
      "truncated_normal").
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.

  Raises:
    ValueError: In case of an invalid value for the "scale", "mode" or
      "distribution" arguments.
464  """
465
466  @deprecated_args(None,
467                   "Call initializer instance with the dtype argument instead "
468                   "of passing it to the constructor", "dtype")
469  @deprecated_arg_values(
470      None,
471      "`normal` is a deprecated alias for `truncated_normal`",
472      distribution="normal")
473  def __init__(self,
474               scale=1.0,
475               mode="fan_in",
476               distribution="truncated_normal",
477               seed=None,
478               dtype=dtypes.float32):
479    if scale <= 0.:
      raise ValueError("`scale` must be a positive float.")
    if mode not in {"fan_in", "fan_out", "fan_avg"}:
      raise ValueError("Invalid `mode` argument:", mode)
    distribution = distribution.lower()
    if distribution not in {
        "normal", "uniform", "truncated_normal", "untruncated_normal"
    }:
      raise ValueError("Invalid `distribution` argument:", distribution)
    self.scale = scale
    self.mode = mode
    self.distribution = distribution
    self.seed = seed
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))

  def __call__(self, shape, dtype=None, partition_info=None):
    if dtype is None:
      dtype = self.dtype
    scale = self.scale
    scale_shape = shape
    if partition_info is not None:
      scale_shape = partition_info.full_shape
    fan_in, fan_out = _compute_fans(scale_shape)
    if self.mode == "fan_in":
      scale /= max(1., fan_in)
    elif self.mode == "fan_out":
      scale /= max(1., fan_out)
    else:
      scale /= max(1., (fan_in + fan_out) / 2.)
    if self.distribution == "normal" or self.distribution == "truncated_normal":
      # constant taken from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
      stddev = math.sqrt(scale) / .87962566103423978
      return random_ops.truncated_normal(
          shape, 0.0, stddev, dtype, seed=self.seed)
    elif self.distribution == "untruncated_normal":
      stddev = math.sqrt(scale)
      return random_ops.random_normal(shape, 0.0, stddev, dtype, seed=self.seed)
    else:
      limit = math.sqrt(3.0 * scale)
      return random_ops.random_uniform(
          shape, -limit, limit, dtype, seed=self.seed)

  def get_config(self):
    return {
        "scale": self.scale,
        "mode": self.mode,
        "distribution": self.distribution,
        "seed": self.seed,
        "dtype": self.dtype.name
    }


@tf_export(v1=["initializers.orthogonal", "orthogonal_initializer"])
@deprecation.deprecated_endpoints("initializers.orthogonal",
                                  "orthogonal_initializer")
class Orthogonal(Initializer):
  """Initializer that generates an orthogonal matrix.

  If the shape of the tensor to initialize is two-dimensional, it is initialized
  with an orthogonal matrix obtained from the QR decomposition of a matrix of
  random numbers drawn from a normal distribution.
  If the matrix has fewer rows than columns then the output will have orthogonal
  rows. Otherwise, the output will have orthogonal columns.

  If the shape of the tensor to initialize is more than two-dimensional,
  a matrix of shape `(shape[0] * ... * shape[n - 2], shape[n - 1])`
  is initialized, where `n` is the length of the shape vector.
  The matrix is subsequently reshaped to give a tensor of the desired shape.

  Args:
    gain: multiplicative factor to apply to the orthogonal matrix
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.
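
  For illustration, a sketch checking the orthogonality property (shape and
  seed are arbitrary):

  ```python
  init = tf.compat.v1.orthogonal_initializer(seed=13)
  q = init([4, 4])
  # q^T q is (numerically) the 4 x 4 identity matrix.
  check = tf.linalg.matmul(q, q, transpose_a=True)
  ```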
  References:
      [Saxe et al., 2014](https://openreview.net/forum?id=_wzZwKpTDF_9C)
      ([pdf](https://arxiv.org/pdf/1312.6120.pdf))
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor", "dtype")
  def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
    self.gain = gain
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
    self.seed = seed

  def __call__(self, shape, dtype=None, partition_info=None):
    if dtype is None:
      dtype = self.dtype
    # Check the shape
    if len(shape) < 2:
      raise ValueError("The tensor to initialize must be "
                       "at least two-dimensional")
    # Flatten the input shape with the last dimension remaining
    # its original shape so it works for conv2d
    num_rows = 1
    for dim in shape[:-1]:
      num_rows *= dim
    num_rows = int(num_rows)
    num_cols = int(shape[-1])
    if num_rows < num_cols:
      flat_shape = (num_cols, num_rows)
    else:
      flat_shape = (num_rows, num_cols)

    # Generate a random matrix
    a = random_ops.random_normal(flat_shape, dtype=dtype, seed=self.seed)
    # Compute the qr factorization
    q, r = gen_linalg_ops.qr(a, full_matrices=False)
    # Make Q uniform
    d = array_ops.diag_part(r)
    q *= math_ops.sign(d)
    if num_rows < num_cols:
      q = array_ops.matrix_transpose(q)
    return self.gain * array_ops.reshape(q, shape)

  def get_config(self):
    return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name}


# Note these haven't been ported to TF2.0. They are not currently visible and
# the tests are non-trivial to port.
class ConvolutionDeltaOrthogonal(Initializer):
  """Initializer that generates a delta orthogonal kernel for ConvNets.

  The shape of the tensor must have length 3, 4 or 5. The number of input
  filters must not exceed the number of output filters. The center pixels of the
  tensor form an orthogonal matrix. Other pixels are set to be zero. See
  algorithm 2 in (Xiao et al., 2018).

  Args:
    gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1.
      The 2-norm of an input is multiplied by a factor of `gain` after applying
      this convolution.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.
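
  For illustration, a usage sketch (constructed directly, since this class is
  not currently exported as a public endpoint; the shape is arbitrary):

  ```python
  init = ConvolutionDeltaOrthogonal(gain=1.0, seed=3)
  # A 3x3 kernel with 16 input and 32 output channels. Only the center
  # pixel is nonzero: it holds a 16 x 32 slice of an orthogonal matrix.
  kernel = init([3, 3, 16, 32])
  ```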
  References:
      [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
      ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
  """

  def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
    self.gain = gain
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
    self.seed = seed

  def __call__(self, shape, dtype=None, partition_info=None):
    if dtype is None:
      dtype = self.dtype
    # Check the shape
    if len(shape) < 3 or len(shape) > 5:
      raise ValueError("The tensor to initialize must be at least "
                       "three-dimensional and at most five-dimensional")

    if shape[-2] > shape[-1]:
      raise ValueError("In_filters cannot be greater than out_filters.")

    # Generate a random matrix
    a = random_ops.random_normal([shape[-1], shape[-1]],
                                 dtype=dtype,
                                 seed=self.seed)
    # Compute the qr factorization
    q, r = gen_linalg_ops.qr(a, full_matrices=False)
    # Make Q uniform
    d = array_ops.diag_part(r)
    q *= math_ops.sign(d)
    q = q[:shape[-2], :]
    q *= math_ops.cast(self.gain, dtype=dtype)
    if len(shape) == 3:
      weight = array_ops.scatter_nd([[(shape[0] - 1) // 2]],
                                    array_ops.expand_dims(q, 0), shape)
    elif len(shape) == 4:
      weight = array_ops.scatter_nd([[(shape[0] - 1) // 2,
                                      (shape[1] - 1) // 2]],
                                    array_ops.expand_dims(q, 0), shape)
    else:
      weight = array_ops.scatter_nd([[(shape[0] - 1) // 2, (shape[1] - 1) // 2,
                                      (shape[2] - 1) // 2]],
                                    array_ops.expand_dims(q, 0), shape)
    return weight

  def get_config(self):
    return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name}


class ConvolutionOrthogonal(Initializer):
  """Initializer that generates orthogonal kernels for ConvNets.

  Base class used to construct 1D, 2D and 3D orthogonal kernels for convolution.

  Args:
    gain: multiplicative factor to apply to the orthogonal matrix. Default is 1.
      The 2-norm of an input is multiplied by a factor of `gain` after applying
      this convolution.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.
  References:
      [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
      ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
  """

  def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
    self.gain = gain
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
    self.seed = seed

  def __call__(self, shape, dtype=None, partition_info=None):
    raise NotImplementedError

  def get_config(self):
    return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name}

  # Helper functions.
  def _orthogonal_matrix(self, n):
    """Construct an n x n orthogonal matrix.

    Args:
      n: Dimension.

    Returns:
      An n x n orthogonal matrix.
    """
    a = random_ops.random_normal([n, n], dtype=self.dtype, seed=self.seed)
    if self.seed:
      self.seed += 1
    q, r = gen_linalg_ops.qr(a)
    d = array_ops.diag_part(r)
    # Make q uniform
    q *= math_ops.sign(d)
    return q

  def _symmetric_projection(self, n):
    """Compute an n x n symmetric projection matrix.

    Args:
      n: Dimension.

    Returns:
      An n x n symmetric projection matrix, i.e. a matrix P s.t. P = P*P and
      P = P^T.
    """
    q = self._orthogonal_matrix(n)
    # Randomly zero out some columns.
    mask = math_ops.cast(
        random_ops.random_normal([n], seed=self.seed) > 0, self.dtype)
    if self.seed:
      self.seed += 1
    c = math_ops.multiply(q, mask)
    return math_ops.matmul(c, array_ops.matrix_transpose(c))


class ConvolutionOrthogonal2D(ConvolutionOrthogonal):
  """Initializer that generates a 2D orthogonal kernel for ConvNets.

  The shape of the tensor must have length 4. The number of input
  filters must not exceed the number of output filters.
  The orthogonality (== isometry) is exact when the inputs are circularly
  padded. There are finite-width effects with non-circular padding (e.g. zero
  padding). See algorithm 1 in (Xiao et al., 2018).

  Args:
    gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1.
      This has the effect of scaling the output 2-norm by a factor of `gain`.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.
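
  For illustration, a usage sketch (constructed directly, since this class is
  not currently exported as a public endpoint; the shape is arbitrary):

  ```python
  init = ConvolutionOrthogonal2D(gain=1.0, seed=5)
  # Kernel sizes must be equal, and in_filters must not exceed out_filters.
  kernel = init([3, 3, 16, 32])
  ```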
  References:
      [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
      ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
  """

  def __call__(self, shape, dtype=None, partition_info=None):
    if dtype is None:
      dtype = self.dtype
    if len(shape) != 4:
      raise ValueError("The tensor to initialize must be four-dimensional")

    if shape[-2] > shape[-1]:
      raise ValueError("In_filters cannot be greater than out_filters.")

    if shape[0] != shape[1]:
      raise ValueError("Kernel sizes must be equal.")

    kernel = self._orthogonal_kernel(shape[0], shape[2], shape[3])
    kernel *= math_ops.cast(self.gain, dtype=dtype)
    return kernel

  def _dict_to_tensor(self, x, k1, k2):
    """Convert a dictionary to a tensor.

    Args:
      x: A k1 * k2 dictionary.
      k1: First dimension of x.
      k2: Second dimension of x.

    Returns:
      A k1 * k2 tensor.
    """

    return array_ops.stack([array_ops.stack([x[i, j] for j in range(k2)])
                            for i in range(k1)])

  def _block_orth(self, p1, p2):
    """Construct a 2 x 2 kernel.

    Used to construct the orthogonal kernel.

    Args:
      p1: A symmetric projection matrix.
      p2: A symmetric projection matrix.

    Returns:
      A 2 x 2 kernel [[p1p2,         p1(1-p2)],
                      [(1-p1)p2, (1-p1)(1-p2)]].
    Raises:
      ValueError: If the dimensions of p1 and p2 are different.
    """
    if p1.shape.as_list() != p2.shape.as_list():
      raise ValueError("The dimension of the matrices must be the same.")
    n = p1.shape.as_list()[0]
    kernel2x2 = {}
    eye = linalg_ops_impl.eye(n, dtype=self.dtype)
    kernel2x2[0, 0] = math_ops.matmul(p1, p2)
    kernel2x2[0, 1] = math_ops.matmul(p1, (eye - p2))
    kernel2x2[1, 0] = math_ops.matmul((eye - p1), p2)
    kernel2x2[1, 1] = math_ops.matmul((eye - p1), (eye - p2))

    return kernel2x2

  def _matrix_conv(self, m1, m2):
    """Matrix convolution.

    Args:
      m1: A k x k dictionary, each element an n x n matrix.
      m2: An l x l dictionary, each element an n x n matrix.

    Returns:
      A (k + l - 1) x (k + l - 1) dictionary, each element an n x n matrix.
    Raises:
      ValueError: If the entries of m1 and m2 are of different dimensions.
    """

    n = (m1[0, 0]).shape.as_list()[0]
    if n != (m2[0, 0]).shape.as_list()[0]:
      raise ValueError("The entries in matrices m1 and m2 "
                       "must have the same dimensions!")
    k = int(np.sqrt(len(m1)))
    l = int(np.sqrt(len(m2)))
    result = {}
    size = k + l - 1
    # Compute matrix convolution between m1 and m2.
    for i in range(size):
      for j in range(size):
        result[i, j] = array_ops.zeros([n, n], self.dtype)
        for index1 in range(min(k, i + 1)):
          for index2 in range(min(k, j + 1)):
            if (i - index1) < l and (j - index2) < l:
              result[i, j] += math_ops.matmul(m1[index1, index2],
                                              m2[i - index1, j - index2])
    return result

  def _orthogonal_kernel(self, ksize, cin, cout):
    """Construct an orthogonal kernel for convolution.

    Args:
      ksize: Kernel size.
      cin: Number of input channels.
      cout: Number of output channels.

    Returns:
      A [ksize, ksize, cin, cout] orthogonal kernel.
    Raises:
      ValueError: If cin > cout.
    """
    if cin > cout:
      raise ValueError("The number of input channels cannot exceed "
                       "the number of output channels.")
    orth = self._orthogonal_matrix(cout)[0:cin, :]
    if ksize == 1:
      return array_ops.expand_dims(array_ops.expand_dims(orth, 0), 0)

    p = self._block_orth(
        self._symmetric_projection(cout), self._symmetric_projection(cout))
    for _ in range(ksize - 2):
      temp = self._block_orth(
          self._symmetric_projection(cout), self._symmetric_projection(cout))
      p = self._matrix_conv(p, temp)
    for i in range(ksize):
      for j in range(ksize):
        p[i, j] = math_ops.matmul(orth, p[i, j])

    return self._dict_to_tensor(p, ksize, ksize)


class ConvolutionOrthogonal1D(ConvolutionOrthogonal):
  """Initializer that generates a 1D orthogonal kernel for ConvNets.

  The shape of the tensor must have length 3. The number of input
  filters must not exceed the number of output filters.
  The orthogonality (== isometry) is exact when the inputs are circularly
  padded. There are finite-width effects with non-circular padding (e.g. zero
  padding). See algorithm 1 in (Xiao et al., 2018).

  Args:
    gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1.
      The 2-norm of an input is multiplied by a factor of `gain` after applying
      this convolution.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.
  References:
      [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
      ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
  """

  def __call__(self, shape, dtype=None, partition_info=None):
    if dtype is None:
      dtype = self.dtype
    if len(shape) != 3:
      raise ValueError("The tensor to initialize must be three-dimensional")

    if shape[-2] > shape[-1]:
      raise ValueError("In_filters cannot be greater than out_filters.")

    kernel = self._orthogonal_kernel(shape[0], shape[-2], shape[-1])
    kernel *= math_ops.cast(self.gain, dtype=dtype)
    return kernel

  def _dict_to_tensor(self, x, k):
    """Convert a dictionary to a tensor.

    Args:
      x: A dictionary of length k.
      k: Dimension of x.

    Returns:
      A tensor with the same dimension.
    """

    return array_ops.stack([x[i] for i in range(k)])

  def _block_orth(self, projection_matrix):
    """Construct a kernel.

    Used to construct the orthogonal kernel.

    Args:
      projection_matrix: A symmetric projection matrix of size n x n.

    Returns:
      [projection_matrix, (1 - projection_matrix)].
    """
    n = projection_matrix.shape.as_list()[0]
    kernel = {}
    eye = linalg_ops_impl.eye(n, dtype=self.dtype)
    kernel[0] = projection_matrix
    kernel[1] = eye - projection_matrix
    return kernel

  def _matrix_conv(self, m1, m2):
    """Matrix convolution.

    Args:
      m1: A dictionary of length k, each element an n x n matrix.
      m2: A dictionary of length l, each element an n x n matrix.

    Returns:
      A dictionary of length (k + l - 1), each element an n x n matrix.
    Raises:
      ValueError: If the entries of m1 and m2 are of different dimensions.
    """

    n = (m1[0]).shape.as_list()[0]
    if n != (m2[0]).shape.as_list()[0]:
      raise ValueError("The entries in matrices m1 and m2 "
                       "must have the same dimensions!")
    k = len(m1)
    l = len(m2)
    result = {}
    size = k + l - 1
    # Compute matrix convolution between m1 and m2.
    for i in range(size):
      result[i] = array_ops.zeros([n, n], self.dtype)
      for index in range(min(k, i + 1)):
        if (i - index) < l:
          result[i] += math_ops.matmul(m1[index], m2[i - index])
    return result

  def _orthogonal_kernel(self, ksize, cin, cout):
    """Construct an orthogonal kernel for convolution.

    Args:
      ksize: Kernel size.
      cin: Number of input channels.
      cout: Number of output channels.

    Returns:
      A [ksize, cin, cout] orthogonal kernel.
    Raises:
      ValueError: If cin > cout.
    """
    if cin > cout:
      raise ValueError("The number of input channels cannot exceed "
                       "the number of output channels.")
    orth = self._orthogonal_matrix(cout)[0:cin, :]
    if ksize == 1:
      return array_ops.expand_dims(orth, 0)

    p = self._block_orth(self._symmetric_projection(cout))
    for _ in range(ksize - 2):
      temp = self._block_orth(self._symmetric_projection(cout))
      p = self._matrix_conv(p, temp)
    for i in range(ksize):
      p[i] = math_ops.matmul(orth, p[i])

    return self._dict_to_tensor(p, ksize)


class ConvolutionOrthogonal3D(ConvolutionOrthogonal):
  """Initializer that generates a 3D orthogonal kernel for ConvNets.

  The shape of the tensor must have length 5. The number of input
  filters must not exceed the number of output filters.
  The orthogonality (== isometry) is exact when the inputs are circularly
  padded. There are finite-width effects with non-circular padding (e.g. zero
  padding). See algorithm 1 in (Xiao et al., 2018).

  Args:
    gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1.
      The 2-norm of an input is multiplied by a factor of `gain` after applying
      this convolution.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.
  References:
      [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
      ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
  """

  def __call__(self, shape, dtype=None, partition_info=None):
    if dtype is None:
      dtype = self.dtype
    if len(shape) != 5:
      raise ValueError("The tensor to initialize must be five-dimensional")

    if shape[-2] > shape[-1]:
      raise ValueError("In_filters cannot be greater than out_filters.")

    if shape[0] != shape[1] or shape[0] != shape[2]:
      raise ValueError("Kernel sizes must be equal.")

    kernel = self._orthogonal_kernel(shape[0], shape[-2], shape[-1])
    kernel *= math_ops.cast(self.gain, dtype=dtype)
    return kernel

  def _dict_to_tensor(self, x, k1, k2, k3):
    """Convert a dictionary to a tensor.

    Args:
      x: A k1 * k2 * k3 dictionary.
      k1: First dimension of x.
      k2: Second dimension of x.
      k3: Third dimension of x.

    Returns:
      A k1 * k2 * k3 tensor.
    """

    return array_ops.stack([array_ops.stack(
        [array_ops.stack([x[i, j, k] for k in range(k3)])
         for j in range(k2)]) for i in range(k1)])

  def _block_orth(self, p1, p2, p3):
    """Construct a 2 x 2 x 2 kernel.

    Used to construct the orthogonal kernel.

    Args:
      p1: A symmetric projection matrix.
      p2: A symmetric projection matrix.
      p3: A symmetric projection matrix.

    Returns:
      A 2 x 2 x 2 kernel.
    Raises:
      ValueError: If the dimensions of p1, p2 and p3 are different.
    """
    p1_shape = p1.shape.as_list()
    if p1_shape != p2.shape.as_list() or p1_shape != p3.shape.as_list():
      raise ValueError("The dimension of the matrices must be the same.")
    n = p1_shape[0]
    eye = linalg_ops_impl.eye(n, dtype=self.dtype)
    kernel2x2x2 = {}

    def matmul(p1, p2, p3):
      return math_ops.matmul(math_ops.matmul(p1, p2), p3)

    def cast(i, p):
      """Return p or (1-p)."""
      return i * p + (1 - i) * (eye - p)

    for i in [0, 1]:
      for j in [0, 1]:
        for k in [0, 1]:
          kernel2x2x2[i, j, k] = matmul(cast(i, p1), cast(j, p2), cast(k, p3))
    return kernel2x2x2

  def _matrix_conv(self, m1, m2):
    """Matrix convolution.

    Args:
      m1: A k x k x k dictionary, each element an n x n matrix.
      m2: An l x l x l dictionary, each element an n x n matrix.

    Returns:
      A (k + l - 1) x (k + l - 1) x (k + l - 1) dictionary, each element
      an n x n matrix.
    Raises:
      ValueError: If the entries of m1 and m2 are of different dimensions.
    """

    n = (m1[0, 0, 0]).shape.as_list()[0]
    if n != (m2[0, 0, 0]).shape.as_list()[0]:
      raise ValueError("The entries in matrices m1 and m2 "
                       "must have the same dimensions!")
    k = int(np.cbrt(len(m1)))
    l = int(np.cbrt(len(m2)))
    result = {}
    size = k + l - 1
    # Compute matrix convolution between m1 and m2.
    for i in range(size):
      for j in range(size):
        for r in range(size):
          result[i, j, r] = array_ops.zeros([n, n], self.dtype)
          for index1 in range(min(k, i + 1)):
            for index2 in range(min(k, j + 1)):
              for index3 in range(min(k, r + 1)):
                if (i - index1) < l and (j - index2) < l and (r - index3) < l:
                  result[i, j, r] += math_ops.matmul(
                      m1[index1, index2, index3],
                      m2[i - index1, j - index2, r - index3])
    return result

  def _orthogonal_kernel(self, ksize, cin, cout):
    """Construct an orthogonal kernel for convolution.

    Args:
      ksize: Kernel size.
      cin: Number of input channels.
      cout: Number of output channels.

    Returns:
      A [ksize, ksize, ksize, cin, cout] orthogonal kernel.
    Raises:
      ValueError: If cin > cout.
    """
    if cin > cout:
      raise ValueError("The number of input channels cannot exceed "
                       "the number of output channels.")
    orth = self._orthogonal_matrix(cout)[0:cin, :]
    if ksize == 1:
      return array_ops.expand_dims(
          array_ops.expand_dims(array_ops.expand_dims(orth, 0), 0), 0)

    p = self._block_orth(
        self._symmetric_projection(cout), self._symmetric_projection(cout),
        self._symmetric_projection(cout))
    for _ in range(ksize - 2):
      temp = self._block_orth(
          self._symmetric_projection(cout), self._symmetric_projection(cout),
          self._symmetric_projection(cout))
      p = self._matrix_conv(p, temp)
    for i in range(ksize):
      for j in range(ksize):
        for k in range(ksize):
          p[i, j, k] = math_ops.matmul(orth, p[i, j, k])

    return self._dict_to_tensor(p, ksize, ksize, ksize)


@tf_export(v1=["initializers.identity"])
@deprecation.deprecated_endpoints("initializers.identity")
class Identity(Initializer):
  """Initializer that generates the identity matrix.

  Only use for 2D matrices.

  Args:
    gain: Multiplicative factor to apply to the identity matrix.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.
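
  For illustration, a minimal usage sketch (variable name and shape are
  arbitrary):

  ```python
  init = tf.compat.v1.initializers.identity(gain=0.5)
  # v initializes to 0.5 * I, a 4 x 4 scaled identity matrix.
  v = tf.compat.v1.get_variable("w", shape=[4, 4], initializer=init)
  ```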
1179  """
1180
1181  @deprecated_args(None,
1182                   "Call initializer instance with the dtype argument instead "
1183                   "of passing it to the constructor", "dtype")
1184  def __init__(self, gain=1.0, dtype=dtypes.float32):
1185    self.gain = gain
1186    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
1187
1188  def __call__(self, shape, dtype=None, partition_info=None):
1189    full_shape = shape if partition_info is None else partition_info.full_shape
1190    if len(full_shape) != 2:
1191      raise ValueError(
1192          "Identity matrix initializer can only be used for 2D matrices.")
1193    if dtype is None:
1194      dtype = self.dtype
1195    if isinstance(full_shape, tensor_shape.TensorShape):
1196      full_shape = full_shape.as_list()
1197    initializer = linalg_ops_impl.eye(*full_shape, dtype=dtype)
1198    if partition_info is not None:
1199      initializer = array_ops.slice(initializer, partition_info.var_offset,
1200                                    shape)
1201    return self.gain * initializer
1202
1203  def get_config(self):
1204    return {"gain": self.gain, "dtype": self.dtype.name}
1205
1206
1207@tf_export(v1=["glorot_uniform_initializer", "initializers.glorot_uniform"])
1208@deprecation.deprecated_endpoints("glorot_uniform_initializer",
1209                                  "initializers.glorot_uniform")
1210class GlorotUniform(VarianceScaling):
1211  """The Glorot uniform initializer, also called Xavier uniform initializer.
1212
1213  It draws samples from a uniform distribution within [-limit, limit]
1214  where `limit` is `sqrt(6 / (fan_in + fan_out))`
1215  where `fan_in` is the number of input units in the weight tensor
1216  and `fan_out` is the number of output units in the weight tensor.
1217
1218  Args:
1219    seed: A Python integer. Used to create random seeds. See
1220      `tf.compat.v1.set_random_seed` for behavior.
1221    dtype: Default data type, used if no `dtype` argument is provided when
1222      calling the initializer. Only floating point types are supported.
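
  For example (an illustrative calculation of the formula above): a dense
  kernel of shape [100, 50] has `fan_in = 100` and `fan_out = 50`, so
  `limit = sqrt(6 / 150) = 0.2` and values are drawn from [-0.2, 0.2]:

  ```python
  init = tf.compat.v1.glorot_uniform_initializer(seed=11)
  values = init([100, 50])  # entries lie in [-0.2, 0.2]
  ```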
  References:
      [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html)
      ([pdf](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf))
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor", "dtype")
  def __init__(self, seed=None, dtype=dtypes.float32):
    super(GlorotUniform, self).__init__(
        scale=1.0, mode="fan_avg", distribution="uniform", seed=seed)

  def get_config(self):
    return {"seed": self.seed, "dtype": self.dtype.name}


@tf_export(v1=["glorot_normal_initializer", "initializers.glorot_normal"])
@deprecation.deprecated_endpoints("glorot_normal_initializer",
                                  "initializers.glorot_normal")
class GlorotNormal(VarianceScaling):
  """The Glorot normal initializer, also called Xavier normal initializer.

  It draws samples from a truncated normal distribution centered on 0
  with standard deviation (after truncation) given by
  `stddev = sqrt(2 / (fan_in + fan_out))` where `fan_in` is the number
  of input units in the weight tensor and `fan_out` is the number of
  output units in the weight tensor.

  Args:
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.
  References:
      [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html)
      ([pdf](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf))
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor", "dtype")
  def __init__(self, seed=None, dtype=dtypes.float32):
    super(GlorotNormal, self).__init__(
        scale=1.0, mode="fan_avg", distribution="truncated_normal", seed=seed)

  def get_config(self):
    return {"seed": self.seed, "dtype": self.dtype.name}


# Aliases.

# pylint: disable=invalid-name
zeros_initializer = Zeros
ones_initializer = Ones
constant_initializer = Constant
random_uniform_initializer = RandomUniform
random_normal_initializer = RandomNormal
truncated_normal_initializer = TruncatedNormal
uniform_unit_scaling_initializer = UniformUnitScaling
variance_scaling_initializer = VarianceScaling
glorot_uniform_initializer = GlorotUniform
glorot_normal_initializer = GlorotNormal
orthogonal_initializer = Orthogonal
identity_initializer = Identity
convolutional_delta_orthogonal = ConvolutionDeltaOrthogonal
convolutional_orthogonal_1d = ConvolutionOrthogonal1D
convolutional_orthogonal_2d = ConvolutionOrthogonal2D
convolutional_orthogonal_3d = ConvolutionOrthogonal3D
# pylint: enable=invalid-name

@tf_export(v1=["initializers.lecun_normal"])
def lecun_normal(seed=None):
  """LeCun normal initializer.

  It draws samples from a truncated normal distribution centered on 0
  with standard deviation (after truncation) given by
  `stddev = sqrt(1 / fan_in)` where `fan_in` is the number of
  input units in the weight tensor.

  Args:
      seed: A Python integer. Used to seed the random generator.

  Returns:
      An initializer.

  References:
      - Self-Normalizing Neural Networks,
      [Klambauer et al.,
      2017](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks)
      # pylint: disable=line-too-long
      ([pdf](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf))
      - Efficient Backprop,
      [Lecun et al., 1998](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
  """
  return VarianceScaling(
      scale=1., mode="fan_in", distribution="truncated_normal", seed=seed)


@tf_export(v1=["initializers.lecun_uniform"])
def lecun_uniform(seed=None):
  """LeCun uniform initializer.

  It draws samples from a uniform distribution within [-limit, limit]
  where `limit` is `sqrt(3 / fan_in)`
  where `fan_in` is the number of input units in the weight tensor.

  Args:
      seed: A Python integer. Used to seed the random generator.

  Returns:
      An initializer.

  References:
      - Self-Normalizing Neural Networks,
      [Klambauer et al.,
      2017](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks)
      # pylint: disable=line-too-long
      ([pdf](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf))
      - Efficient Backprop,
      [Lecun et al., 1998](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
  """
  return VarianceScaling(
      scale=1., mode="fan_in", distribution="uniform", seed=seed)


@tf_export(v1=["initializers.he_normal"])
def he_normal(seed=None):
  """He normal initializer.

  It draws samples from a truncated normal distribution centered on 0
  with standard deviation (after truncation) given by
  `stddev = sqrt(2 / fan_in)` where `fan_in` is the number of
  input units in the weight tensor.

  Args:
      seed: A Python integer. Used to seed the random generator.

  Returns:
      An initializer.

  References:
      [He et al., 2015]
      (https://www.cv-foundation.org/openaccess/content_iccv_2015/html/He_Delving_Deep_into_ICCV_2015_paper.html)
      # pylint: disable=line-too-long
      ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf))
  """
  return VarianceScaling(
      scale=2., mode="fan_in", distribution="truncated_normal", seed=seed)


@tf_export(v1=["initializers.he_uniform"])
def he_uniform(seed=None):
  """He uniform variance scaling initializer.

  It draws samples from a uniform distribution within [-limit, limit]
  where `limit` is `sqrt(6 / fan_in)`
  where `fan_in` is the number of input units in the weight tensor.

  Args:
      seed: A Python integer. Used to seed the random generator.

  Returns:
      An initializer.

  References:
      [He et al., 2015]
      (https://www.cv-foundation.org/openaccess/content_iccv_2015/html/He_Delving_Deep_into_ICCV_2015_paper.html)
      # pylint: disable=line-too-long
      ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf))
  """
  return VarianceScaling(
      scale=2., mode="fan_in", distribution="uniform", seed=seed)


# Utility functions.


def _compute_fans(shape):
  """Computes the number of input and output units for a weight shape.

  Args:
    shape: Integer shape tuple or TF tensor shape.

  Returns:
    A tuple of integer scalars (fan_in, fan_out).
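
  For example (illustrative): a conv kernel of shape (3, 3, 16, 32) has a
  receptive field size of 3 * 3 = 9, giving fan_in = 9 * 16 = 144 and
  fan_out = 9 * 32 = 288.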
1409  """
1410  if len(shape) < 1:  # Just to avoid errors for constants.
1411    fan_in = fan_out = 1
1412  elif len(shape) == 1:
1413    fan_in = fan_out = shape[0]
1414  elif len(shape) == 2:
1415    fan_in = shape[0]
1416    fan_out = shape[1]
1417  else:
1418    # Assuming convolution kernels (2D, 3D, or more).
1419    # kernel shape: (..., input_depth, depth)
1420    receptive_field_size = 1
1421    for dim in shape[:-2]:
1422      receptive_field_size *= dim
1423    fan_in = shape[-2] * receptive_field_size
1424    fan_out = shape[-1] * receptive_field_size
1425  return int(fan_in), int(fan_out)
1426
1427
1428def _assert_float_dtype(dtype):
1429  """Validate and return floating point type based on `dtype`.
1430
1431  `dtype` must be a floating point type.
1432
1433  Args:
1434    dtype: The data type to validate.
1435
1436  Returns:
1437    Validated type.
1438
1439  Raises:
1440    ValueError: if `dtype` is not a floating point type.
1441  """
1442  if not dtype.is_floating:
1443    raise ValueError("Expected floating point type, got %s." % dtype)
1444  return dtype
1445