# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations often used for initializing tensors.

All variable initializers returned by functions in this file should have the
following signature:

def _initializer(shape, dtype=dtypes.float32, partition_info=None):
  Args:
    shape: List of `int` representing the shape of the output `Tensor`. Some
      initializers may also be able to accept a `Tensor`.
    dtype: (Optional) Type of the output `Tensor`.
    partition_info: (Optional) variable_scope._PartitionInfo object holding
      additional information about how the variable is partitioned. May be
      `None` if the variable is not partitioned.

  Returns:
    A `Tensor` of type `dtype` and `shape`.
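
Example:
  A minimal sketch of an initializer conforming to this signature
  (illustrative only; not one of the initializers defined in this file):

  ```python
  def _ones_initializer(shape, dtype=dtypes.float32, partition_info=None):
    del partition_info  # Unused; assumes the variable is not partitioned.
    return array_ops.ones(shape, dtype)
  ```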
31"""
32from __future__ import absolute_import
33from __future__ import division
34from __future__ import print_function
35
36import math
37
38import numpy as np
39
40from tensorflow.python.framework import constant_op
41from tensorflow.python.framework import dtypes
42from tensorflow.python.framework import tensor_shape
43from tensorflow.python.ops import array_ops
44from tensorflow.python.ops import gen_linalg_ops
45from tensorflow.python.ops import linalg_ops_impl
46from tensorflow.python.ops import math_ops
47from tensorflow.python.ops import random_ops
48from tensorflow.python.util import deprecation
49from tensorflow.python.util.deprecation import deprecated
50from tensorflow.python.util.deprecation import deprecated_arg_values
51from tensorflow.python.util.deprecation import deprecated_args
52from tensorflow.python.util.tf_export import tf_export
53
54
55class Initializer(object):
56  """Initializer base class: all initializers inherit from this class."""
57
58  def __call__(self, shape, dtype=None, partition_info=None):
59    """Returns a tensor object initialized as specified by the initializer.
60
61    Args:
62      shape: Shape of the tensor.
63      dtype: Optional dtype of the tensor. If not provided use the initializer
64        dtype.
65      partition_info: Optional information about the possible partitioning of a
66        tensor.
67    """
68    raise NotImplementedError
69
70  def get_config(self):
71    """Returns the configuration of the initializer as a JSON-serializable dict.
72
73    Returns:
74      A JSON-serializable Python dict.
75    """
76    return {}
77
78  @classmethod
79  def from_config(cls, config):
80    """Instantiates an initializer from a configuration dictionary.
81
82    Example:
83
84    ```python
85    initializer = RandomUniform(-1, 1)
86    config = initializer.get_config()
87    initializer = RandomUniform.from_config(config)
88    ```
89
90    Args:
91      config: A Python dictionary. It will typically be the output of
92        `get_config`.
93
94    Returns:
95      An Initializer instance.
96    """
97    return cls(**config)
98
99
@tf_export(v1=["initializers.zeros", "zeros_initializer"])
@deprecation.deprecated_endpoints("initializers.zeros")
class Zeros(Initializer):
  """Initializer that generates tensors initialized to 0.

  @compatibility(TF2)
  `tf.compat.v1.zeros_initializer` is compatible with eager execution
  and `tf.function`.

  To migrate to TF2, please use `tf.zeros_initializer` instead. The `dtype`
  argument in `tf.compat.v1.zeros_initializer.__init__()` does not exist in
  `tf.zeros_initializer.__init__()`. However, you can specify the `dtype` in
  `__call__()` in both cases.

  #### Structural Mapping to Native TF2

  Before:

  ```python
  initializer = tf.compat.v1.zeros_initializer(dtype=tf.float32)
  variable = tf.Variable(initializer(shape=[3, 3]))
  ```

  After:

  ```python
  initializer = tf.zeros_initializer()
  variable = tf.Variable(initializer(shape=[3, 3], dtype=tf.float32))
  ```

  #### How to Map Arguments

  | TF1 Arg Name         | TF2 Arg Name     | Note                       |
  | :------------------- | :--------------- | :------------------------- |
  | `dtype`              | `dtype`          | In `__call__()` method     |
  | `partition_info`     | -                | (`__call__` arg in TF1) Not supported |


  #### Before & After Usage Example

  Before:

  >>> initializer = tf.compat.v1.zeros_initializer(dtype=tf.float32)
  >>> tf.Variable(initializer(shape=[3])).numpy()
  array([0., 0., 0.], dtype=float32)
  >>> tf.Variable(initializer(shape=[3, 3])).numpy()
  array([[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]], dtype=float32)
  >>> initializer = tf.compat.v1.zeros_initializer()
  >>> tf.Variable(initializer(shape=[3], dtype=tf.float32)).numpy()
  array([0., 0., 0.], dtype=float32)
  >>> tf.Variable(initializer(shape=[3, 3], dtype=tf.float32)).numpy()
  array([[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]], dtype=float32)

  After:

  >>> initializer = tf.zeros_initializer()
  >>> tf.Variable(initializer(shape=[3], dtype=tf.float32)).numpy()
  array([0., 0., 0.], dtype=float32)
  >>> tf.Variable(initializer(shape=[3, 3], dtype=tf.float32)).numpy()
  array([[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]], dtype=float32)

  @end_compatibility
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor", "dtype")
  def __init__(self, dtype=dtypes.float32):
    self.dtype = dtypes.as_dtype(dtype)

  def __call__(self, shape, dtype=None, partition_info=None):
    if dtype is None:
      dtype = self.dtype
    return array_ops.zeros(shape, dtype)

  def get_config(self):
    return {"dtype": self.dtype.name}


@tf_export(v1=["initializers.ones", "ones_initializer"])
@deprecation.deprecated_endpoints("initializers.ones", "ones_initializer")
class Ones(Initializer):
  """Initializer that generates tensors initialized to 1.

  @compatibility(TF2)
  This API is compatible with TF2 behavior and `tf.function`, and can be
  migrated immediately to `tf.keras.initializers.ones`.

  Before:
  >>> initializer = tf.compat.v1.keras.initializers.ones()
  >>> initializer((1, 1))
  <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[1.]], dtype=float32)>

  After:
  >>> initializer = tf.keras.initializers.ones()
  >>> initializer((1, 1))
  <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[1.]], dtype=float32)>

  @end_compatibility
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor", "dtype")
  def __init__(self, dtype=dtypes.float32):
    self.dtype = dtypes.as_dtype(dtype)

  def __call__(self, shape, dtype=None, partition_info=None):
    if dtype is None:
      dtype = self.dtype
    return array_ops.ones(shape, dtype)

  def get_config(self):
    return {"dtype": self.dtype.name}


@tf_export(v1=["initializers.constant", "constant_initializer"])
@deprecation.deprecated_endpoints("constant_initializer")
class Constant(Initializer):
  """Initializer that generates tensors with constant values.

  The resulting tensor is populated with values of type `dtype`, as
  specified by the argument `value`, following the desired `shape` of the
  new tensor (see examples below).

  The argument `value` can be a constant value, or a list of values of type
  `dtype`. If `value` is a list, then the length of the list must be less
  than or equal to the number of elements implied by the desired shape of the
  tensor. In the case where the total number of elements in `value` is less
  than the number of elements required by the tensor shape, the last element
  in `value` will be used to fill the remaining entries. If the total number of
  elements in `value` is greater than the number of elements required by the
  tensor shape, the initializer will raise a `ValueError`.

  Args:
    value: A Python scalar, list or tuple of values, or an N-dimensional numpy
      array. All elements of the initialized variable will be set to the
      corresponding value in the `value` argument.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer.
    verify_shape: Boolean that enables verification of the shape of `value`. If
      `True`, the initializer will throw an error if the shape of `value` is not
      compatible with the shape of the initialized tensor.

  Raises:
    TypeError: If the input `value` is not one of the expected types.

  Examples:
    The following example can also be written using a numpy.ndarray in
    place of the `value` list, optionally reshaped.

  >>> value = [0, 1, 2, 3, 4, 5, 6, 7]
  >>> init = tf.compat.v1.constant_initializer(value)
  >>> # Fitting shape
  >>> with tf.compat.v1.Session():
  ...   x = tf.compat.v1.get_variable('x', shape=[2, 4], initializer=init)
  ...   x.initializer.run()
  ...   print(x.eval())
  [[0. 1. 2. 3.]
   [4. 5. 6. 7.]]
  >>> # Larger shape
  >>> with tf.compat.v1.Session():
  ...   y = tf.compat.v1.get_variable('y', shape=[3, 4], initializer=init)
  ...   y.initializer.run()
  ...   print(y.eval())
  [[0.  1.  2.  3.]
   [4.  5.  6.  7.]
   [7.  7.  7.  7.]]
  >>> # Smaller shape
  >>> with tf.compat.v1.Session():
  ...   z = tf.compat.v1.get_variable('z', shape=[2, 3], initializer=init)
  Traceback (most recent call last):
  ...
  ValueError: Too many elements provided. Needed at most 6, but received 8
  >>> # Shape verification
  >>> init_verify = tf.compat.v1.constant_initializer(value, verify_shape=True)
  >>> with tf.compat.v1.Session():
  ...  u = tf.compat.v1.get_variable('u', shape=[3, 4],
  ...                                initializer=init_verify)
  Traceback (most recent call last):
  ...
  TypeError: Expected Tensor's shape: (3, 4), got (8,).

  @compatibility(TF2)
  Although it is a legacy API endpoint, `tf.compat.v1.constant_initializer`
  is compatible with eager execution and `tf.function`.

  To migrate to a non-legacy TF2 API, please use `tf.constant_initializer`
  instead. The `dtype`
  argument in `tf.compat.v1.constant_initializer.__init__()` does not exist in
  `tf.constant_initializer.__init__()`. However, you can specify the `dtype` in
  `__call__()` in both cases.

  In the `compat.v1` symbol, if `verify_shape` is set to `True`, an exception
  is raised when initializing a variable with a different shape from
  `value`. If set to `False`, `value` is reshaped to initialize the variable
  if necessary. An exception is raised only when the numbers of
  elements differ.

  The `verify_shape` argument is not supported in TF2. Using
  `tf.constant_initializer` is equivalent to setting `verify_shape` to `False`.

  #### Structural Mapping to Native TF2

  Before:

  ```python
  value = [0, 1, 2, 3, 4, 5, 6, 7]
  initializer = tf.compat.v1.constant_initializer(
      value=value,
      dtype=tf.float32,
      verify_shape=False)
  variable = tf.Variable(initializer(shape=[2, 4]))
  ```

  After:

  ```python
  value = [0, 1, 2, 3, 4, 5, 6, 7]
  initializer = tf.constant_initializer(value=value)
  tf.Variable(initializer(shape=[2, 4], dtype=tf.float32))
  ```

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name     | Note                        |
  | :-------------------- | :--------------- | :-------------------------- |
  | `value`               | `value`          | In constructor              |
  | `dtype`               | `dtype`          | In `__call__()` method      |
  | `verify_shape`        | Not supported    | Equivalent to setting it to `False` |
  | `partition_info`      | -                | (`__call__` arg in TF1) Not supported |


  #### Before & After Usage Example

  Before:

  >>> value = [1., 2., 3., 4.]
  >>> initializer = tf.compat.v1.constant_initializer(
  ...     value=value, dtype=tf.float32, verify_shape=True)
  >>> tf.Variable(initializer(shape=[2, 2])).numpy()
  Traceback (most recent call last):
  ...
  TypeError: Expected Tensor's shape: (2, 2), got (4,).
  >>> initializer = tf.compat.v1.constant_initializer(
  ...     value=value, dtype=tf.float32, verify_shape=False)
  >>> tf.Variable(initializer(shape=[2, 2])).numpy()
  array([[1., 2.],
         [3., 4.]], dtype=float32)

  After:

  >>> value = [1., 2., 3., 4.]
  >>> initializer = tf.constant_initializer(value=value)
  >>> tf.Variable(initializer(shape=[2, 2], dtype=tf.float32)).numpy()
  array([[1., 2.],
         [3., 4.]], dtype=float32)

  @end_compatibility
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor", "dtype")
  @deprecated_args(None, "Objects must now be the required shape or no shape "
                   "can be specified", "verify_shape")
  def __init__(self, value=0, dtype=dtypes.float32, verify_shape=False):
    if not (np.isscalar(value) or isinstance(value, (list, tuple, np.ndarray))):
      raise TypeError(
          "Invalid type for initial value: %s (expected Python scalar, list or "
          "tuple of values, or numpy.ndarray)." % type(value))

    self.value = value
    self.dtype = dtypes.as_dtype(dtype)
    self._verify_shape = verify_shape

  def __call__(self, shape, dtype=None, partition_info=None, verify_shape=None):
    if dtype is None:
      dtype = self.dtype
    if verify_shape is None:
      verify_shape = self._verify_shape
    return constant_op.constant_v1(
        self.value, dtype=dtype, shape=shape, verify_shape=verify_shape)

  def get_config(self):
    # We don't include `verify_shape` for compatibility with Keras.
    # `verify_shape` should be passed as an argument to `__call__` rather
    # than as a constructor argument: conceptually it isn't a property
    # of the initializer.
    return {"value": self.value, "dtype": self.dtype.name}


@tf_export(v1=["initializers.random_uniform", "random_uniform_initializer"])
@deprecation.deprecated_endpoints("initializers.random_uniform")
class RandomUniform(Initializer):
  """Initializer that generates tensors with a uniform distribution.

  Args:
    minval: A python scalar or a scalar tensor. Lower bound of the range of
      random values to generate.
    maxval: A python scalar or a scalar tensor. Upper bound of the range of
      random values to generate.  Defaults to 1 for float types.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer.

  @compatibility(TF2)
  Although it is a legacy compat.v1 API, this symbol is compatible with eager
  execution and `tf.function`.

  To switch to native TF2, use either
  `tf.initializers.RandomUniform` or `tf.keras.initializers.RandomUniform`
  (neither from `compat.v1`) and
  pass the dtype when calling the initializer. Keep in mind that
  the default minval, maxval and the behavior of fixed seeds have changed.

  Random seed behavior:
  Also be aware that if you pass a seed to the TF2 initializer
  API, it will reuse that same seed for every single initialization
  (unlike the TF1 initializer).

  #### Structural Mapping to Native TF2

  Before:

  ```python
  initializer = tf.compat.v1.random_uniform_initializer(
    minval=minval,
    maxval=maxval,
    seed=seed,
    dtype=dtype)

  weight_one = tf.Variable(initializer(shape_one))
  weight_two = tf.Variable(initializer(shape_two))
  ```

  After:

  ```python
  initializer = tf.initializers.RandomUniform(
    minval=minval,
    maxval=maxval,
    # seed=seed,  # Setting a seed in the native TF2 API
                  # causes it to produce the same initializations
                  # across multiple calls of the same initializer.
    )

  weight_one = tf.Variable(initializer(shape_one, dtype=dtype))
  weight_two = tf.Variable(initializer(shape_two, dtype=dtype))
  ```

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `minval`              | `minval`        | Default changes from 0 to -0.05 |
  | `maxval`              | `maxval`        | Default changes from 1.0 to 0.05 |
  | `seed`             | `seed` | Different random number generation |
  :                    :        : semantics (to change in a  :
  :                    :        : future version). If set, the TF2 version :
  :                    :        : will use stateless random number :
  :                    :        : generation which will produce the exact :
  :                    :        : same initialization even across multiple :
  :                    :        : calls of the initializer instance. The :
  :                    :        : `compat.v1` version will generate new :
  :                    :        : initializations each time. Do not set :
  :                    :        : a seed if you need different          :
  :                    :        : initializations each time. Instead    :
  :                    :        : either set a global tf seed with      :
  :                    :        : `tf.random.set_seed` if you need :
  :                    :        : determinism, or initialize each weight :
  :                    :        : with a separate initializer instance  :
  :                    :        : and a different seed.                 :
  | `dtype` | `dtype`   | The TF2 native API only takes it  |
  :                     :      : as a `__call__` arg, not a constructor arg. :
  | `partition_info`     | - |  (`__call__` arg in TF1) Not supported       |

  #### Example of fixed-seed behavior differences

  `compat.v1` Fixed seed behavior:

  >>> initializer = tf.compat.v1.random_uniform_initializer(seed=10)
  >>> a = initializer(shape=(2, 2))
  >>> b = initializer(shape=(2, 2))
  >>> tf.reduce_sum(a - b) == 0
  <tf.Tensor: shape=(), dtype=bool, numpy=False>

  After:

  >>> initializer = tf.initializers.RandomUniform(seed=10)
  >>> a = initializer(shape=(2, 2))
  >>> b = initializer(shape=(2, 2))
  >>> tf.reduce_sum(a - b) == 0
  <tf.Tensor: shape=(), dtype=bool, numpy=True>

  @end_compatibility
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor", "dtype")
  def __init__(self, minval=0, maxval=None, seed=None, dtype=dtypes.float32):
    self.minval = minval
    self.maxval = maxval
    self.seed = seed
    self.dtype = dtypes.as_dtype(dtype)

  def __call__(self, shape, dtype=None, partition_info=None):
    if dtype is None:
      dtype = self.dtype
    return random_ops.random_uniform(
        shape, self.minval, self.maxval, dtype, seed=self.seed)

  def get_config(self):
    return {
        "minval": self.minval,
        "maxval": self.maxval,
        "seed": self.seed,
        "dtype": self.dtype.name
    }


@tf_export(v1=["initializers.random_normal", "random_normal_initializer"])
@deprecation.deprecated_endpoints("initializers.random_normal")
class RandomNormal(Initializer):
  """Initializer that generates tensors with a normal distribution.

  Args:
    mean: a python scalar or a scalar tensor. Mean of the random values to
      generate.
    stddev: a python scalar or a scalar tensor. Standard deviation of the random
      values to generate.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.

  @compatibility(TF2)
  Although it is a legacy `compat.v1` API, this symbol is compatible with eager
  execution and `tf.function`.

  To switch to native TF2, use either
  `tf.initializers.RandomNormal` or `tf.keras.initializers.RandomNormal`
  (neither from `compat.v1`) and
  pass the dtype when calling the initializer. Keep in mind that
  the default stddev and the behavior of fixed seeds have changed.

  Random seed behavior:
  Also be aware that if you pass a seed to the TF2 initializer
  API, it will reuse that same seed for every single initialization
  (unlike the TF1 initializer).

  #### Structural Mapping to Native TF2

  Before:

  ```python
  initializer = tf.compat.v1.random_normal_initializer(
    mean=mean,
    stddev=stddev,
    seed=seed,
    dtype=dtype)

  weight_one = tf.Variable(initializer(shape_one))
  weight_two = tf.Variable(initializer(shape_two))
  ```

  After:

  ```python
  initializer = tf.initializers.RandomNormal(
    mean=mean,
    # seed=seed,  # Setting a seed in the native TF2 API
                  # causes it to produce the same initializations
                  # across multiple calls of the same initializer.
    stddev=stddev)

  weight_one = tf.Variable(initializer(shape_one, dtype=dtype))
  weight_two = tf.Variable(initializer(shape_two, dtype=dtype))
  ```

  #### How to Map Arguments

  | TF1 Arg Name       | TF2 Arg Name    | Note                       |
  | :----------------- | :-------------- | :------------------------- |
  | `mean`             | `mean`          | No change to defaults |
  | `stddev`           | `stddev`        | Default changes from 1.0 to 0.05 |
  | `seed`             | `seed` | Different random number generation |
  :                    :        : semantics (to change in a  :
  :                    :        : future version). If set, the TF2 version :
  :                    :        : will use stateless random number :
  :                    :        : generation which will produce the exact :
  :                    :        : same initialization even across multiple :
  :                    :        : calls of the initializer instance. The :
  :                    :        : `compat.v1` version will generate new :
  :                    :        : initializations each time. Do not set :
  :                    :        : a seed if you need different          :
  :                    :        : initializations each time. Instead    :
  :                    :        : either set a global tf seed with      :
  :                    :        : `tf.random.set_seed` if you need :
  :                    :        : determinism, or initialize each weight :
  :                    :        : with a separate initializer instance  :
  :                    :        : and a different seed.                 :
  | `dtype`            | `dtype`  | The TF2 native API only takes it as a |
  :                    :          : `__call__` arg, not a constructor arg. :
  | `partition_info`   | -     |  (`__call__` arg in TF1) Not supported.  |

  #### Example of fixed-seed behavior differences

  `compat.v1` Fixed seed behavior:

  >>> initializer = tf.compat.v1.random_normal_initializer(seed=10)
  >>> a = initializer(shape=(2, 2))
  >>> b = initializer(shape=(2, 2))
  >>> tf.reduce_sum(a - b) == 0
  <tf.Tensor: shape=(), dtype=bool, numpy=False>

  After:

  >>> initializer = tf.initializers.RandomNormal(seed=10)
  >>> a = initializer(shape=(2, 2))
  >>> b = initializer(shape=(2, 2))
  >>> tf.reduce_sum(a - b) == 0
  <tf.Tensor: shape=(), dtype=bool, numpy=True>

  @end_compatibility
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor", "dtype")
  def __init__(self, mean=0.0, stddev=1.0, seed=None, dtype=dtypes.float32):
    self.mean = mean
    self.stddev = stddev
    self.seed = seed
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))

  def __call__(self, shape, dtype=None, partition_info=None):
    if dtype is None:
      dtype = self.dtype
    return random_ops.random_normal(
        shape, self.mean, self.stddev, dtype, seed=self.seed)

  def get_config(self):
    return {
        "mean": self.mean,
        "stddev": self.stddev,
        "seed": self.seed,
        "dtype": self.dtype.name
    }


@tf_export(v1=["initializers.truncated_normal", "truncated_normal_initializer"])
@deprecation.deprecated_endpoints("initializers.truncated_normal",
                                  "truncated_normal_initializer")
class TruncatedNormal(Initializer):
  """Initializer that generates a truncated normal distribution.

  These values are similar to values from a `random_normal_initializer`
  except that values more than two standard deviations from the mean
  are discarded and re-drawn. This is the recommended initializer for
  neural network weights and filters.

  Args:
    mean: a python scalar or a scalar tensor. Mean of the random values to
      generate.
    stddev: a python scalar or a scalar tensor. Standard deviation of the random
      values to generate.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.

  @compatibility(TF2)
  Although it is a legacy compat.v1 API, this symbol is compatible with eager
  execution and `tf.function`.

  To switch to native TF2, use either
  `tf.initializers.truncated_normal` or `tf.keras.initializers.TruncatedNormal`
  (neither from `compat.v1`) and
  pass the dtype when calling the initializer. Keep in mind that
  the default stddev and the behavior of fixed seeds have changed.

  Random seed behavior:
  Also be aware that if you pass a seed to the TF2 initializer
  API, it will reuse that same seed for every single initialization
  (unlike the TF1 initializer).

  #### Structural Mapping to Native TF2

  Before:

  ```python
  initializer = tf.compat.v1.truncated_normal_initializer(
    mean=mean,
    stddev=stddev,
    seed=seed,
    dtype=dtype)

  weight_one = tf.Variable(initializer(shape_one))
  weight_two = tf.Variable(initializer(shape_two))
  ```

  After:

  ```python
  initializer = tf.initializers.truncated_normal(
    mean=mean,
    # seed=seed,  # Setting a seed in the native TF2 API
                  # causes it to produce the same initializations
                  # across multiple calls of the same initializer.
    stddev=stddev)

  weight_one = tf.Variable(initializer(shape_one, dtype=dtype))
  weight_two = tf.Variable(initializer(shape_two, dtype=dtype))
  ```

  #### How to Map Arguments

  | TF1 Arg Name          | TF2 Arg Name    | Note                       |
  | :-------------------- | :-------------- | :------------------------- |
  | `mean`                | `mean`          | No change to defaults |
  | `stddev`              | `stddev`        | Default changes from 1.0 to 0.05 |
  | `seed`             | `seed` | Different random number generation |
  :                    :        : semantics (to change in a  :
  :                    :        : future version). If set, the TF2 version :
  :                    :        : will use stateless random number :
  :                    :        : generation which will produce the exact :
  :                    :        : same initialization even across multiple :
  :                    :        : calls of the initializer instance. The :
  :                    :        : `compat.v1` version will generate new :
  :                    :        : initializations each time. Do not set :
  :                    :        : a seed if you need different          :
  :                    :        : initializations each time. Instead    :
  :                    :        : either set a global tf seed with      :
  :                    :        : `tf.random.set_seed` if you need :
  :                    :        : determinism, or initialize each weight :
  :                    :        : with a separate initializer instance  :
  :                    :        : and a different seed.                 :
  | `dtype` | `dtype`   | The TF2 native API only takes it  |
  :                     :      : as a `__call__` arg, not a constructor arg. :
  | `partition_info`     | - |  (`__call__` arg in TF1) Not supported       |

  #### Example of fixed-seed behavior differences

  `compat.v1` Fixed seed behavior:

  >>> initializer = tf.compat.v1.truncated_normal_initializer(seed=10)
  >>> a = initializer(shape=(2, 2))
  >>> b = initializer(shape=(2, 2))
  >>> tf.reduce_sum(a - b) == 0
  <tf.Tensor: shape=(), dtype=bool, numpy=False>

  After:

  >>> initializer = tf.initializers.truncated_normal(seed=10)
  >>> a = initializer(shape=(2, 2))
  >>> b = initializer(shape=(2, 2))
  >>> tf.reduce_sum(a - b) == 0
  <tf.Tensor: shape=(), dtype=bool, numpy=True>

  @end_compatibility
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor", "dtype")
  def __init__(self, mean=0.0, stddev=1.0, seed=None, dtype=dtypes.float32):
    self.mean = mean
    self.stddev = stddev
    self.seed = seed
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))

  def __call__(self, shape, dtype=None, partition_info=None):
    if dtype is None:
      dtype = self.dtype
    return random_ops.truncated_normal(
        shape, self.mean, self.stddev, dtype, seed=self.seed)

  def get_config(self):
    return {
        "mean": self.mean,
        "stddev": self.stddev,
        "seed": self.seed,
        "dtype": self.dtype.name
    }


@tf_export(v1=[
    "initializers.uniform_unit_scaling", "uniform_unit_scaling_initializer"
])
@deprecation.deprecated_endpoints("uniform_unit_scaling_initializer",
                                  "initializers.uniform_unit_scaling")
class UniformUnitScaling(Initializer):
  """Initializer that generates tensors without scaling variance.

  When initializing a deep network, it is in principle advantageous to keep
  the scale of the input variance constant, so that it does not explode or
  diminish by the time it reaches the final layer. If the input is `x` and the
  operation is `x * W`, and we want to initialize `W` uniformly at random, we
  need to pick `W` from

      [-sqrt(3) / sqrt(dim), sqrt(3) / sqrt(dim)]

  to keep the scale intact, where `dim = W.shape[0]` (the size of the input).
  A similar calculation for convolutional networks gives an analogous result
  with `dim` equal to the product of the first 3 dimensions.  When
  nonlinearities are present, we need to multiply this by a constant `factor`.
  See (Sussillo et al., 2014) for deeper motivation, experiments
  and the calculation of constants. In section 2.3 there, the constants were
  numerically computed: for a linear layer it's 1.0, relu: ~1.43, tanh: ~1.15.

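  For example, a sketch of the bound this initializer uses (values here are
  illustrative):

  ```python
  import math
  factor = 1.15  # e.g. the tanh constant from Sussillo et al., 2014
  dim = 256      # estimated input size (product of all but the last dim)
  max_val = math.sqrt(3.0 / dim) * factor
  # Weights are then drawn uniformly from [-max_val, max_val].
  ```
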
  Args:
    factor: Float.  A multiplicative factor by which the values will be scaled.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.
  References:
      [Sussillo et al., 2014](https://arxiv.org/abs/1412.6558)
      ([pdf](http://arxiv.org/pdf/1412.6558.pdf))
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor", "dtype")
  @deprecated(None,
              "Use tf.initializers.variance_scaling instead with distribution="
              "uniform to get equivalent behavior.")
  def __init__(self, factor=1.0, seed=None, dtype=dtypes.float32):
    self.factor = factor
    self.seed = seed
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))

  def __call__(self, shape, dtype=None, partition_info=None):
    if dtype is None:
      dtype = self.dtype
    scale_shape = shape
    if partition_info is not None:
      scale_shape = partition_info.full_shape

    input_size = 1.0
    # Estimating input size is not possible to do perfectly, but we try.
    # The estimate, obtained by multiplying all dimensions but the last one,
    # is the right thing for matrix multiply and convolutions (see above).
    for dim in scale_shape[:-1]:
      input_size *= float(dim)
    # Avoid errors when initializing zero-size tensors.
    input_size = max(input_size, 1.0)
    max_val = math.sqrt(3 / input_size) * self.factor
    return random_ops.random_uniform(
        shape, -max_val, max_val, dtype, seed=self.seed)

  def get_config(self):
    return {"factor": self.factor, "seed": self.seed, "dtype": self.dtype.name}


@tf_export(v1=["initializers.variance_scaling", "variance_scaling_initializer"])
@deprecation.deprecated_endpoints("initializers.variance_scaling",
                                  "variance_scaling_initializer")
class VarianceScaling(Initializer):
  """Initializer capable of adapting its scale to the shape of weights tensors.

  With `distribution="truncated_normal"` or `"untruncated_normal"`,
  samples are drawn from a truncated/untruncated normal
  distribution with a mean of zero and a standard deviation (after truncation,
  if used) `stddev = sqrt(scale / n)`,
  where n is:
    - number of input units in the weight tensor, if mode = "fan_in"
    - number of output units, if mode = "fan_out"
    - average of the numbers of input and output units, if mode = "fan_avg"

  With `distribution="uniform"`, samples are drawn from a uniform distribution
  within [-limit, limit], with `limit = sqrt(3 * scale / n)`.

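  For example, a sketch of the scaling rule (illustrative values; `n` is
  derived from the shape according to `mode`):

  ```python
  import math
  scale, n = 1.0, 256
  stddev = math.sqrt(scale / n)       # target stddev for the normal variants
  limit = math.sqrt(3.0 * scale / n)  # bound for distribution="uniform"
  ```
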
  Args:
    scale: Scaling factor (positive float).
    mode: One of "fan_in", "fan_out", "fan_avg".
    distribution: Random distribution to use. One of "truncated_normal",
      "untruncated_normal", "uniform", or "normal" (a deprecated alias for
      "truncated_normal").
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.

  Raises:
    ValueError: In case of an invalid value for the "scale", "mode" or
      "distribution" arguments.
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor", "dtype")
  @deprecated_arg_values(
      None,
      "`normal` is a deprecated alias for `truncated_normal`",
      distribution="normal")
  def __init__(self,
               scale=1.0,
               mode="fan_in",
               distribution="truncated_normal",
               seed=None,
               dtype=dtypes.float32):
    if scale <= 0.:
      raise ValueError("`scale` must be positive float.")
    if mode not in {"fan_in", "fan_out", "fan_avg"}:
      raise ValueError("Invalid `mode` argument:", mode)
    distribution = distribution.lower()
    if distribution not in {
        "normal", "uniform", "truncated_normal", "untruncated_normal"
    }:
      raise ValueError("Invalid `distribution` argument:", distribution)
    self.scale = scale
    self.mode = mode
    self.distribution = distribution
    self.seed = seed
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))

  def __call__(self, shape, dtype=None, partition_info=None):
    if dtype is None:
      dtype = self.dtype
    scale = self.scale
    scale_shape = shape
    if partition_info is not None:
      scale_shape = partition_info.full_shape
    fan_in, fan_out = _compute_fans(scale_shape)
    if self.mode == "fan_in":
      scale /= max(1., fan_in)
    elif self.mode == "fan_out":
      scale /= max(1., fan_out)
    else:
      scale /= max(1., (fan_in + fan_out) / 2.)
    if self.distribution == "normal" or self.distribution == "truncated_normal":
      # constant taken from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
      stddev = math.sqrt(scale) / .87962566103423978
      return random_ops.truncated_normal(
          shape, 0.0, stddev, dtype, seed=self.seed)
    elif self.distribution == "untruncated_normal":
      stddev = math.sqrt(scale)
      return random_ops.random_normal(shape, 0.0, stddev, dtype, seed=self.seed)
    else:
      limit = math.sqrt(3.0 * scale)
      return random_ops.random_uniform(
          shape, -limit, limit, dtype, seed=self.seed)

  def get_config(self):
    return {
        "scale": self.scale,
        "mode": self.mode,
        "distribution": self.distribution,
        "seed": self.seed,
        "dtype": self.dtype.name
    }


@tf_export(v1=["initializers.orthogonal", "orthogonal_initializer"])
@deprecation.deprecated_endpoints("initializers.orthogonal",
                                  "orthogonal_initializer")
class Orthogonal(Initializer):
  """Initializer that generates an orthogonal matrix.

  If the shape of the tensor to initialize is two-dimensional, it is initialized
  with an orthogonal matrix obtained from the QR decomposition of a matrix of
  random numbers drawn from a normal distribution.
  If the matrix has fewer rows than columns then the output will have orthogonal
  rows. Otherwise, the output will have orthogonal columns.

  If the shape of the tensor to initialize is more than two-dimensional,
  a matrix of shape `(shape[0] * ... * shape[n - 2], shape[n - 1])`
  is initialized, where `n` is the length of the shape vector.
  The matrix is subsequently reshaped to give a tensor of the desired shape.

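  For example (a sketch, assuming eager execution), a tall matrix produced by
  this initializer has orthonormal columns:

  ```python
  init = tf.compat.v1.orthogonal_initializer()
  w = init(shape=(64, 32))
  # tf.matmul(w, w, transpose_a=True) is approximately the 32 x 32 identity.
  ```
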
  Args:
    gain: Multiplicative factor to apply to the orthogonal matrix.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.
  References:
      [Saxe et al., 2014](https://openreview.net/forum?id=_wzZwKpTDF_9C)
      ([pdf](https://arxiv.org/pdf/1312.6120.pdf))
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor", "dtype")
  def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
    self.gain = gain
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
    self.seed = seed

  def __call__(self, shape, dtype=None, partition_info=None):
    if dtype is None:
      dtype = self.dtype
    # Check the shape
    if len(shape) < 2:
      raise ValueError("The tensor to initialize must be "
                       "at least two-dimensional")
    # Flatten the input shape with the last dimension remaining
    # its original shape so it works for conv2d
    num_rows = 1
    for dim in shape[:-1]:
      num_rows *= dim
    num_rows = int(num_rows)
    num_cols = int(shape[-1])
    if num_rows < num_cols:
      flat_shape = (num_cols, num_rows)
    else:
      flat_shape = (num_rows, num_cols)

    # Generate a random matrix
    a = random_ops.random_normal(flat_shape, dtype=dtype, seed=self.seed)
    # Compute the qr factorization
    q, r = gen_linalg_ops.qr(a, full_matrices=False)
    # Make Q uniform
    d = array_ops.diag_part(r)
    q *= math_ops.sign(d)
    if num_rows < num_cols:
      q = array_ops.matrix_transpose(q)
    return self.gain * array_ops.reshape(q, shape)

  def get_config(self):
    return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name}


# Note these haven't been ported to TF2.0. They are not currently visible and
# the tests are non-trivial to port.
class ConvolutionDeltaOrthogonal(Initializer):
  """Initializer that generates a delta orthogonal kernel for ConvNets.

  The shape of the tensor must have length 3, 4 or 5. The number of input
  filters must not exceed the number of output filters. The center pixels of the
  tensor form an orthogonal matrix. Other pixels are set to zero. See
  algorithm 2 in (Xiao et al., 2018).

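  For example (a sketch; this class is not exported, so it is constructed
  directly):

  ```python
  init = ConvolutionDeltaOrthogonal(gain=1.0)
  kernel = init(shape=[3, 3, 16, 32])
  # kernel[1, 1] is a 16 x 32 matrix with orthonormal rows; every other
  # spatial position is zero.
  ```
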
  Args:
    gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1.
      The 2-norm of an input is multiplied by a factor of `gain` after applying
      this convolution.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.
  References:
      [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
      ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
  """

  def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
    self.gain = gain
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
    self.seed = seed

  def __call__(self, shape, dtype=None, partition_info=None):
    if dtype is None:
      dtype = self.dtype
    # Check the shape
    if len(shape) < 3 or len(shape) > 5:
      raise ValueError("The tensor to initialize must be at least "
                       "three-dimensional and at most five-dimensional")

    if shape[-2] > shape[-1]:
      raise ValueError("In_filters cannot be greater than out_filters.")

    # Generate a random matrix
    a = random_ops.random_normal([shape[-1], shape[-1]],
                                 dtype=dtype,
                                 seed=self.seed)
    # Compute the qr factorization
    q, r = gen_linalg_ops.qr(a, full_matrices=False)
    # Make Q uniform
    d = array_ops.diag_part(r)
    q *= math_ops.sign(d)
    q = q[:shape[-2], :]
    q *= math_ops.cast(self.gain, dtype=dtype)
    if len(shape) == 3:
      weight = array_ops.scatter_nd([[(shape[0] - 1) // 2]],
                                    array_ops.expand_dims(q, 0), shape)
    elif len(shape) == 4:
      weight = array_ops.scatter_nd([[(shape[0] - 1) // 2,
                                      (shape[1] - 1) // 2]],
                                    array_ops.expand_dims(q, 0), shape)
    else:
      weight = array_ops.scatter_nd([[(shape[0] - 1) // 2, (shape[1] - 1) // 2,
                                      (shape[2] - 1) // 2]],
                                    array_ops.expand_dims(q, 0), shape)
    return weight

  def get_config(self):
    return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name}


class ConvolutionOrthogonal(Initializer):
  """Initializer that generates an orthogonal kernel for ConvNets.

  Base class used to construct 1D, 2D and 3D orthogonal kernels for convolution.

  Args:
    gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1.
      The 2-norm of an input is multiplied by a factor of `gain` after applying
      this convolution.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.
  References:
      [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
      ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
  """

  def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32):
    self.gain = gain
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))
    self.seed = seed

  def __call__(self, shape, dtype=None, partition_info=None):
    raise NotImplementedError

  def get_config(self):
    return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name}

  # Helper functions.
  def _orthogonal_matrix(self, n):
    """Construct an n x n orthogonal matrix.

    Args:
      n: Dimension.

    Returns:
      An n x n orthogonal matrix.
    """
    a = random_ops.random_normal([n, n], dtype=self.dtype, seed=self.seed)
    if self.seed:
      self.seed += 1
    q, r = gen_linalg_ops.qr(a)
    d = array_ops.diag_part(r)
    # Make Q uniform
    q *= math_ops.sign(d)
    return q

  def _symmetric_projection(self, n):
    """Compute an n x n symmetric projection matrix.

    Args:
      n: Dimension.

    Returns:
      An n x n symmetric projection matrix, i.e. a matrix P s.t. P = P*P and
      P = P^T.
    """
    q = self._orthogonal_matrix(n)
    # Randomly zero out some columns.
    mask = math_ops.cast(
        random_ops.random_normal([n], seed=self.seed) > 0, self.dtype)
    if self.seed:
      self.seed += 1
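    # The surviving columns of q are orthonormal, so c * c^T below is the
    # orthogonal projector onto their span: symmetric and idempotent.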
    c = math_ops.multiply(q, mask)
    return math_ops.matmul(c, array_ops.matrix_transpose(c))


class ConvolutionOrthogonal2D(ConvolutionOrthogonal):
  """Initializer that generates a 2D orthogonal kernel for ConvNets.

  The shape of the tensor must have length 4. The number of input
  filters must not exceed the number of output filters.
  The orthogonality (== isometry) is exact when the inputs are circular padded.
  There are finite-width effects with non-circular padding (e.g. zero padding).
  See algorithm 1 in (Xiao et al., 2018).

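  For example (a sketch; this class is not exported, so it is constructed
  directly):

  ```python
  init = ConvolutionOrthogonal2D(gain=1.0)
  kernel = init(shape=[3, 3, 16, 32])
  # With circular padding, convolving with this kernel preserves the 2-norm
  # of the input (scaled by `gain`).
  ```
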
  Args:
    gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1.
      This has the effect of scaling the output 2-norm by a factor of `gain`.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.
  References:
      [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
      ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
  """

  def __call__(self, shape, dtype=None, partition_info=None):
    if dtype is None:
      dtype = self.dtype
    if len(shape) != 4:
      raise ValueError("The tensor to initialize must be four-dimensional")

    if shape[-2] > shape[-1]:
      raise ValueError("In_filters cannot be greater than out_filters.")

    if shape[0] != shape[1]:
      raise ValueError("Kernel sizes must be equal.")

    kernel = self._orthogonal_kernel(shape[0], shape[2], shape[3])
    kernel *= math_ops.cast(self.gain, dtype=dtype)
    return kernel

  def _dict_to_tensor(self, x, k1, k2):
    """Convert a dictionary to a tensor.

    Args:
      x: A k1 * k2 dictionary.
      k1: First dimension of x.
      k2: Second dimension of x.

    Returns:
      A k1 * k2 tensor.
    """

    return array_ops.stack([array_ops.stack([x[i, j] for j in range(k2)])
                            for i in range(k1)])

  def _block_orth(self, p1, p2):
    """Construct a 2 x 2 kernel.

    Used to construct an orthogonal kernel.

    Args:
      p1: A symmetric projection matrix.
      p2: A symmetric projection matrix.

    Returns:
      A 2 x 2 kernel [[p1p2,         p1(1-p2)],
                      [(1-p1)p2, (1-p1)(1-p2)]].
    Raises:
      ValueError: If the dimensions of p1 and p2 are different.
    """
    if p1.shape.as_list() != p2.shape.as_list():
      raise ValueError("The dimension of the matrices must be the same.")
    n = p1.shape.as_list()[0]
    kernel2x2 = {}
    eye = linalg_ops_impl.eye(n, dtype=self.dtype)
    kernel2x2[0, 0] = math_ops.matmul(p1, p2)
    kernel2x2[0, 1] = math_ops.matmul(p1, (eye - p2))
    kernel2x2[1, 0] = math_ops.matmul((eye - p1), p2)
    kernel2x2[1, 1] = math_ops.matmul((eye - p1), (eye - p2))

    return kernel2x2

  def _matrix_conv(self, m1, m2):
    """Matrix convolution.

    Args:
      m1: A k x k dictionary, each element is an n x n matrix.
      m2: An l x l dictionary, each element is an n x n matrix.

    Returns:
      A (k + l - 1) x (k + l - 1) dictionary, each element of which is an
      n x n matrix.
    Raises:
      ValueError: If the entries of m1 and m2 are of different dimensions.
    """

    n = (m1[0, 0]).shape.as_list()[0]
    if n != (m2[0, 0]).shape.as_list()[0]:
      raise ValueError("The entries in matrices m1 and m2 "
                       "must have the same dimensions!")
    k = int(np.sqrt(len(m1)))
    l = int(np.sqrt(len(m2)))
    result = {}
    size = k + l - 1
    # Compute matrix convolution between m1 and m2.
    for i in range(size):
      for j in range(size):
        result[i, j] = array_ops.zeros([n, n], self.dtype)
        for index1 in range(min(k, i + 1)):
          for index2 in range(min(k, j + 1)):
            if (i - index1) < l and (j - index2) < l:
              result[i, j] += math_ops.matmul(m1[index1, index2],
                                              m2[i - index1, j - index2])
    return result

  def _orthogonal_kernel(self, ksize, cin, cout):
    """Construct an orthogonal kernel for convolution.

    Args:
      ksize: Kernel size.
      cin: Number of input channels.
      cout: Number of output channels.

    Returns:
      A [ksize, ksize, cin, cout] orthogonal kernel.
    Raises:
      ValueError: If cin > cout.
    """
    if cin > cout:
      raise ValueError("The number of input channels cannot exceed "
                       "the number of output channels.")
    orth = self._orthogonal_matrix(cout)[0:cin, :]
    if ksize == 1:
      return array_ops.expand_dims(array_ops.expand_dims(orth, 0), 0)

    p = self._block_orth(
        self._symmetric_projection(cout), self._symmetric_projection(cout))
    for _ in range(ksize - 2):
      temp = self._block_orth(
          self._symmetric_projection(cout), self._symmetric_projection(cout))
      p = self._matrix_conv(p, temp)
    for i in range(ksize):
      for j in range(ksize):
        p[i, j] = math_ops.matmul(orth, p[i, j])

    return self._dict_to_tensor(p, ksize, ksize)


1310class ConvolutionOrthogonal1D(ConvolutionOrthogonal):
1311  """Initializer that generates a 1D orthogonal kernel for ConvNets.
1312
1313  The shape of the tensor must have length 3. The number of input
1314  filters must not exceed the number of output filters.
1315  The orthogonality(==isometry) is exact when the inputs are circular padded.
1316  There are finite-width effects with non-circular padding (e.g. zero padding).
1317  See algorithm 1 in (Xiao et al., 2018).
1318
1319  Args:
1320    gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1.
1321      The 2-norm of an input is multiplied by a factor of `gain` after applying
1322      this convolution.
1323    seed: A Python integer. Used to create random seeds. See
1324      `tf.compat.v1.set_random_seed` for behavior.
1325    dtype: Default data type, used if no `dtype` argument is provided when
1326      calling the initializer. Only floating point types are supported.
1327  References:
1328      [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
1329      ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
1330  """
1331
1332  def __call__(self, shape, dtype=None, partition_info=None):
1333    if dtype is None:
1334      dtype = self.dtype
1335    if len(shape) != 3:
1336      raise ValueError("The tensor to initialize must be three-dimensional")
1337
1338    if shape[-2] > shape[-1]:
1339      raise ValueError("In_filters cannot be greater than out_filters.")
1340
1341    kernel = self._orthogonal_kernel(shape[0], shape[-2], shape[-1])
1342    kernel *= math_ops.cast(self.gain, dtype=dtype)
1343    return kernel
1344
1345  def _dict_to_tensor(self, x, k):
1346    """Convert a dictionary to a tensor.
1347
1348    Args:
1349      x: A dictionary of length k.
1350      k: Dimension of x.
1351
1352    Returns:
1353      A tensor with the same dimension.
1354    """
1355
1356    return array_ops.stack([x[i] for i in range(k)])
1357
1358  def _block_orth(self, projection_matrix):
1359    """Construct a kernel.
1360
1361    Used to construct orthgonal kernel.
1362
1363    Args:
1364      projection_matrix: A symmetric projection matrix of size n x n.
1365
1366    Returns:
1367      [projection_matrix, (1 - projection_matrix)].
1368    """
1369    n = projection_matrix.shape.as_list()[0]
1370    kernel = {}
1371    eye = linalg_ops_impl.eye(n, dtype=self.dtype)
1372    kernel[0] = projection_matrix
1373    kernel[1] = eye - projection_matrix
1374    return kernel
1375
  def _matrix_conv(self, m1, m2):
    """Matrix convolution.

    Args:
      m1: A dictionary of length k, each element an n x n matrix.
      m2: A dictionary of length l, each element an n x n matrix.

    Returns:
      A dictionary of length (k + l - 1), each element an n x n matrix.
    Raises:
      ValueError: If the entries of m1 and m2 are of different dimensions.
    """

    n = (m1[0]).shape.as_list()[0]
    if n != (m2[0]).shape.as_list()[0]:
      raise ValueError("The entries in matrices m1 and m2 "
                       "must have the same dimensions!")
    k = len(m1)
    l = len(m2)
    result = {}
    size = k + l - 1
    # Compute matrix convolution between m1 and m2.
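    # result[i] = sum over index of m1[index] * m2[i - index], i.e. the block
    # analogue of multiplying two polynomials with matrix coefficients.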
    for i in range(size):
      result[i] = array_ops.zeros([n, n], self.dtype)
      for index in range(min(k, i + 1)):
        if (i - index) < l:
          result[i] += math_ops.matmul(m1[index], m2[i - index])
    return result

  def _orthogonal_kernel(self, ksize, cin, cout):
    """Construct orthogonal kernel for convolution.

    Args:
      ksize: Kernel size.
      cin: Number of input channels.
      cout: Number of output channels.

    Returns:
      A [ksize, cin, cout] orthogonal kernel.
    Raises:
      ValueError: If cin > cout.
    """
    if cin > cout:
      raise ValueError("The number of input channels cannot exceed "
                       "the number of output channels.")
    orth = self._orthogonal_matrix(cout)[0:cin, :]
    if ksize == 1:
      return array_ops.expand_dims(orth, 0)

    p = self._block_orth(self._symmetric_projection(cout))
    for _ in range(ksize - 2):
      temp = self._block_orth(self._symmetric_projection(cout))
      p = self._matrix_conv(p, temp)
    for i in range(ksize):
      p[i] = math_ops.matmul(orth, p[i])

    return self._dict_to_tensor(p, ksize)


class ConvolutionOrthogonal3D(ConvolutionOrthogonal):
  """Initializer that generates a 3D orthogonal kernel for ConvNets.

  The shape of the tensor must have length 5. The number of input
  filters must not exceed the number of output filters.
  The orthogonality (i.e. isometry) is exact when the inputs are circularly
  padded. There are finite-width effects with non-circular padding
  (e.g. zero padding). See algorithm 1 in (Xiao et al., 2018).

  Args:
    gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1.
      The 2-norm of an input is multiplied by a factor of `gain` after applying
      this convolution.
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.
  References:
      [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html)
      ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf))
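
  Example (a minimal sketch; the shape and seed below are illustrative):

  ```python
  init = ConvolutionOrthogonal3D(gain=1.0, seed=42)
  # A 3 x 3 x 3 kernel mapping 8 input channels to 16 output channels.
  kernel = init(shape=[3, 3, 3, 8, 16])
  ```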
  """

  def __call__(self, shape, dtype=None, partition_info=None):
    if dtype is None:
      dtype = self.dtype
    if len(shape) != 5:
      raise ValueError("The tensor to initialize must be five-dimensional")

    if shape[-2] > shape[-1]:
      raise ValueError("In_filters cannot be greater than out_filters.")

    if shape[0] != shape[1] or shape[0] != shape[2]:
      raise ValueError("Kernel sizes must be equal.")

    kernel = self._orthogonal_kernel(shape[0], shape[-2], shape[-1])
    kernel *= math_ops.cast(self.gain, dtype=dtype)
    return kernel

  def _dict_to_tensor(self, x, k1, k2, k3):
    """Convert a dictionary to a tensor.

    Args:
      x: A k1 * k2 * k3 dictionary.
      k1: First dimension of x.
      k2: Second dimension of x.
      k3: Third dimension of x.

    Returns:
      A k1 * k2 * k3 tensor.
    """

    return array_ops.stack([array_ops.stack(
        [array_ops.stack([x[i, j, k] for k in range(k3)])
         for j in range(k2)]) for i in range(k1)])

  def _block_orth(self, p1, p2, p3):
    """Construct a 2 x 2 x 2 kernel.

    Used to construct an orthogonal kernel.

    Args:
      p1: A symmetric projection matrix.
      p2: A symmetric projection matrix.
      p3: A symmetric projection matrix.

    Returns:
      A 2 x 2 x 2 kernel.
    Raises:
      ValueError: If the dimensions of p1, p2 and p3 are different.
    """
    p1_shape = p1.shape.as_list()
    if p1_shape != p2.shape.as_list() or p1_shape != p3.shape.as_list():
      raise ValueError("The dimension of the matrices must be the same.")
    n = p1_shape[0]
    eye = linalg_ops_impl.eye(n, dtype=self.dtype)
    kernel2x2x2 = {}

    def matmul(p1, p2, p3):
      return math_ops.matmul(math_ops.matmul(p1, p2), p3)

    def cast(i, p):
      """Return p if i == 1, else (eye - p)."""
      return i * p + (1 - i) * (eye - p)

    for i in [0, 1]:
      for j in [0, 1]:
        for k in [0, 1]:
          kernel2x2x2[i, j, k] = matmul(cast(i, p1), cast(j, p2), cast(k, p3))
    return kernel2x2x2

  def _matrix_conv(self, m1, m2):
    """Matrix convolution.

    Args:
      m1: A k x k x k dictionary, each element an n x n matrix.
      m2: An l x l x l dictionary, each element an n x n matrix.

    Returns:
      A (k + l - 1) x (k + l - 1) x (k + l - 1) dictionary, each
      element an n x n matrix.
    Raises:
      ValueError: If the entries of m1 and m2 are of different dimensions.
    """

    n = (m1[0, 0, 0]).shape.as_list()[0]
    if n != (m2[0, 0, 0]).shape.as_list()[0]:
      raise ValueError("The entries in matrices m1 and m2 "
                       "must have the same dimensions!")
    k = int(np.cbrt(len(m1)))
    l = int(np.cbrt(len(m2)))
    result = {}
    size = k + l - 1
    # Compute matrix convolution between m1 and m2.
    for i in range(size):
      for j in range(size):
        for r in range(size):
          result[i, j, r] = array_ops.zeros([n, n], self.dtype)
          for index1 in range(min(k, i + 1)):
            for index2 in range(min(k, j + 1)):
              for index3 in range(min(k, r + 1)):
                if (i - index1) < l and (j - index2) < l and (r - index3) < l:
                  result[i, j, r] += math_ops.matmul(
                      m1[index1, index2, index3],
                      m2[i - index1, j - index2, r - index3])
    return result

  def _orthogonal_kernel(self, ksize, cin, cout):
    """Construct orthogonal kernel for convolution.

    Args:
      ksize: Kernel size.
      cin: Number of input channels.
      cout: Number of output channels.

    Returns:
      A [ksize, ksize, ksize, cin, cout] orthogonal kernel.
    Raises:
      ValueError: If cin > cout.
    """
    if cin > cout:
      raise ValueError("The number of input channels cannot exceed "
                       "the number of output channels.")
    orth = self._orthogonal_matrix(cout)[0:cin, :]
    if ksize == 1:
      return array_ops.expand_dims(
          array_ops.expand_dims(array_ops.expand_dims(orth, 0), 0), 0)

    p = self._block_orth(
        self._symmetric_projection(cout), self._symmetric_projection(cout),
        self._symmetric_projection(cout))
    for _ in range(ksize - 2):
      temp = self._block_orth(
          self._symmetric_projection(cout), self._symmetric_projection(cout),
          self._symmetric_projection(cout))
      p = self._matrix_conv(p, temp)
    for i in range(ksize):
      for j in range(ksize):
        for k in range(ksize):
          p[i, j, k] = math_ops.matmul(orth, p[i, j, k])

    return self._dict_to_tensor(p, ksize, ksize, ksize)


@tf_export(v1=["initializers.identity"])
@deprecation.deprecated_endpoints("initializers.identity")
class Identity(Initializer):
  """Initializer that generates the identity matrix.

  Only use for 2D matrices.

  Args:
    gain: Multiplicative factor to apply to the identity matrix.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.
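
  Example (a minimal sketch; the shape below is illustrative):

  ```python
  init = Identity(gain=1.0)
  # A 64 x 64 identity matrix, scaled by `gain`.
  kernel = init(shape=[64, 64])
  ```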
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor", "dtype")
  def __init__(self, gain=1.0, dtype=dtypes.float32):
    self.gain = gain
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))

  def __call__(self, shape, dtype=None, partition_info=None):
    full_shape = shape if partition_info is None else partition_info.full_shape
    if len(full_shape) != 2:
      raise ValueError(
          "Identity matrix initializer can only be used for 2D matrices.")
    if dtype is None:
      dtype = self.dtype
    if isinstance(full_shape, tensor_shape.TensorShape):
      full_shape = full_shape.as_list()
    initializer = linalg_ops_impl.eye(*full_shape, dtype=dtype)
    if partition_info is not None:
      initializer = array_ops.slice(initializer, partition_info.var_offset,
                                    shape)
    return self.gain * initializer

  def get_config(self):
    return {"gain": self.gain, "dtype": self.dtype.name}


@tf_export(v1=["glorot_uniform_initializer", "initializers.glorot_uniform"])
@deprecation.deprecated_endpoints("glorot_uniform_initializer",
                                  "initializers.glorot_uniform")
class GlorotUniform(VarianceScaling):
  """The Glorot uniform initializer, also called Xavier uniform initializer.

  It draws samples from a uniform distribution within [-limit, limit]
  where `limit` is `sqrt(6 / (fan_in + fan_out))`
  where `fan_in` is the number of input units in the weight tensor
  and `fan_out` is the number of output units in the weight tensor.

  Args:
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.
  References:
      [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html)
      ([pdf](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf))
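
  For example (an illustrative worked case): a dense layer with
  `fan_in = 256` and `fan_out = 128` gives
  `limit = sqrt(6 / (256 + 128)) = sqrt(6 / 384) = 0.125`.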
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor", "dtype")
  def __init__(self, seed=None, dtype=dtypes.float32):
    super(GlorotUniform, self).__init__(
        scale=1.0, mode="fan_avg", distribution="uniform", seed=seed,
        dtype=dtype)

  def get_config(self):
    return {"seed": self.seed, "dtype": self.dtype.name}


@tf_export(v1=["glorot_normal_initializer", "initializers.glorot_normal"])
@deprecation.deprecated_endpoints("glorot_normal_initializer",
                                  "initializers.glorot_normal")
class GlorotNormal(VarianceScaling):
  """The Glorot normal initializer, also called Xavier normal initializer.

  It draws samples from a truncated normal distribution centered on 0
  with standard deviation (after truncation) given by
  `stddev = sqrt(2 / (fan_in + fan_out))` where `fan_in` is the number
  of input units in the weight tensor and `fan_out` is the number of
  output units in the weight tensor.

  Args:
    seed: A Python integer. Used to create random seeds. See
      `tf.compat.v1.set_random_seed` for behavior.
    dtype: Default data type, used if no `dtype` argument is provided when
      calling the initializer. Only floating point types are supported.
  References:
      [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html)
      ([pdf](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf))
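
  For example (an illustrative worked case): a square 100 x 100 weight
  matrix gives `stddev = sqrt(2 / (100 + 100)) = 0.1`.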
  """

  @deprecated_args(None,
                   "Call initializer instance with the dtype argument instead "
                   "of passing it to the constructor", "dtype")
  def __init__(self, seed=None, dtype=dtypes.float32):
    super(GlorotNormal, self).__init__(
        scale=1.0, mode="fan_avg", distribution="truncated_normal", seed=seed,
        dtype=dtype)

  def get_config(self):
    return {"seed": self.seed, "dtype": self.dtype.name}


# Aliases.

# pylint: disable=invalid-name
zeros_initializer = Zeros
ones_initializer = Ones
constant_initializer = Constant
random_uniform_initializer = RandomUniform
random_normal_initializer = RandomNormal
truncated_normal_initializer = TruncatedNormal
uniform_unit_scaling_initializer = UniformUnitScaling
variance_scaling_initializer = VarianceScaling
glorot_uniform_initializer = GlorotUniform
glorot_normal_initializer = GlorotNormal
orthogonal_initializer = Orthogonal
identity_initializer = Identity
convolutional_delta_orthogonal = ConvolutionDeltaOrthogonal
convolutional_orthogonal_1d = ConvolutionOrthogonal1D
convolutional_orthogonal_2d = ConvolutionOrthogonal2D
convolutional_orthogonal_3d = ConvolutionOrthogonal3D
# pylint: enable=invalid-name


@tf_export(v1=["initializers.lecun_normal"])
def lecun_normal(seed=None):
  """LeCun normal initializer.

  It draws samples from a truncated normal distribution centered on 0
  with standard deviation (after truncation) given by
  `stddev = sqrt(1 / fan_in)` where `fan_in` is the number of
  input units in the weight tensor.

  Args:
      seed: A Python integer. Used to seed the random generator.

  Returns:
      An initializer.

  References:
      - Self-Normalizing Neural Networks,
      [Klambauer et al.,
      2017](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks)
      # pylint: disable=line-too-long
      ([pdf](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf))
      - Efficient Backprop,
      [LeCun et al., 1998](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
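
  For example (an illustrative worked case): with `fan_in = 400`,
  `stddev = sqrt(1 / 400) = 0.05`.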
  """
  return VarianceScaling(
      scale=1., mode="fan_in", distribution="truncated_normal", seed=seed)


@tf_export(v1=["initializers.lecun_uniform"])
def lecun_uniform(seed=None):
  """LeCun uniform initializer.

  It draws samples from a uniform distribution within [-limit, limit]
  where `limit` is `sqrt(3 / fan_in)`
  where `fan_in` is the number of input units in the weight tensor.

  Args:
      seed: A Python integer. Used to seed the random generator.

  Returns:
      An initializer.

  References:
      - Self-Normalizing Neural Networks,
      [Klambauer et al.,
      2017](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks)
      # pylint: disable=line-too-long
      ([pdf](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf))
      - Efficient Backprop,
      [LeCun et al., 1998](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf)
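
  For example (an illustrative worked case): with `fan_in = 300`,
  `limit = sqrt(3 / 300) = 0.1`.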
  """
  return VarianceScaling(
      scale=1., mode="fan_in", distribution="uniform", seed=seed)


@tf_export(v1=["initializers.he_normal"])
def he_normal(seed=None):
  """He normal initializer.

  It draws samples from a truncated normal distribution centered on 0
  with standard deviation (after truncation) given by
  `stddev = sqrt(2 / fan_in)` where `fan_in` is the number of
  input units in the weight tensor.

  Args:
      seed: A Python integer. Used to seed the random generator.

  Returns:
      An initializer.

  References:
      [He et al., 2015]
      (https://www.cv-foundation.org/openaccess/content_iccv_2015/html/He_Delving_Deep_into_ICCV_2015_paper.html)
      # pylint: disable=line-too-long
      ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf))
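
  For example (an illustrative worked case): with `fan_in = 200`,
  `stddev = sqrt(2 / 200) = 0.1`.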
  """
  return VarianceScaling(
      scale=2., mode="fan_in", distribution="truncated_normal", seed=seed)


@tf_export(v1=["initializers.he_uniform"])
def he_uniform(seed=None):
  """He uniform variance scaling initializer.

  It draws samples from a uniform distribution within [-limit, limit]
  where `limit` is `sqrt(6 / fan_in)`
  where `fan_in` is the number of input units in the weight tensor.

  Args:
      seed: A Python integer. Used to seed the random generator.

  Returns:
      An initializer.

  References:
      [He et al., 2015]
      (https://www.cv-foundation.org/openaccess/content_iccv_2015/html/He_Delving_Deep_into_ICCV_2015_paper.html)
      # pylint: disable=line-too-long
      ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf))
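
  For example (an illustrative worked case): with `fan_in = 24`,
  `limit = sqrt(6 / 24) = 0.5`.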
  """
  return VarianceScaling(
      scale=2., mode="fan_in", distribution="uniform", seed=seed)


# Utility functions.


def _compute_fans(shape):
  """Computes the number of input and output units for a weight shape.

  Args:
    shape: Integer shape tuple or TF tensor shape.

  Returns:
    A tuple of integer scalars (fan_in, fan_out).
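
  For example (an illustrative worked case): a conv kernel of shape
  (3, 3, 64, 128) has receptive_field_size = 3 * 3 = 9, so
  fan_in = 64 * 9 = 576 and fan_out = 128 * 9 = 1152.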
  """
  if len(shape) < 1:  # Just to avoid errors for constants.
    fan_in = fan_out = 1
  elif len(shape) == 1:
    fan_in = fan_out = shape[0]
  elif len(shape) == 2:
    fan_in = shape[0]
    fan_out = shape[1]
  else:
    # Assuming convolution kernels (2D, 3D, or more).
    # kernel shape: (..., input_depth, depth)
    receptive_field_size = 1
    for dim in shape[:-2]:
      receptive_field_size *= dim
    fan_in = shape[-2] * receptive_field_size
    fan_out = shape[-1] * receptive_field_size
  return int(fan_in), int(fan_out)


def _assert_float_dtype(dtype):
  """Validate and return floating point type based on `dtype`.

  `dtype` must be a floating point type.

  Args:
    dtype: The data type to validate.

  Returns:
    Validated type.

  Raises:
    ValueError: if `dtype` is not a floating point type.
  """
  if not dtype.is_floating:
    raise ValueError("Expected floating point type, got %s." % dtype)
  return dtype