1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Operations often used for initializing tensors. 16 17All variable initializers returned by functions in this file should have the 18following signature: 19 20def _initializer(shape, dtype=dtypes.float32, partition_info=None): 21 Args: 22 shape: List of `int` representing the shape of the output `Tensor`. Some 23 initializers may also be able to accept a `Tensor`. 24 dtype: (Optional) Type of the output `Tensor`. 25 partition_info: (Optional) variable_scope._PartitionInfo object holding 26 additional information about how the variable is partitioned. May be 27 `None` if the variable is not partitioned. 28 29 Returns: 30 A `Tensor` of type `dtype` and `shape`. 31""" 32from __future__ import absolute_import 33from __future__ import division 34from __future__ import print_function 35 36import math 37 38import numpy as np 39 40from tensorflow.python.framework import constant_op 41from tensorflow.python.framework import dtypes 42from tensorflow.python.framework import tensor_shape 43from tensorflow.python.ops import array_ops 44from tensorflow.python.ops import gen_linalg_ops 45from tensorflow.python.ops import linalg_ops_impl 46from tensorflow.python.ops import math_ops 47from tensorflow.python.ops import random_ops 48from tensorflow.python.util import deprecation 49from tensorflow.python.util.deprecation import deprecated 50from tensorflow.python.util.deprecation import deprecated_arg_values 51from tensorflow.python.util.deprecation import deprecated_args 52from tensorflow.python.util.tf_export import tf_export 53 54 55class Initializer(object): 56 """Initializer base class: all initializers inherit from this class.""" 57 58 def __call__(self, shape, dtype=None, partition_info=None): 59 """Returns a tensor object initialized as specified by the initializer. 60 61 Args: 62 shape: Shape of the tensor. 63 dtype: Optional dtype of the tensor. If not provided use the initializer 64 dtype. 65 partition_info: Optional information about the possible partitioning of a 66 tensor. 67 """ 68 raise NotImplementedError 69 70 def get_config(self): 71 """Returns the configuration of the initializer as a JSON-serializable dict. 72 73 Returns: 74 A JSON-serializable Python dict. 75 """ 76 return {} 77 78 @classmethod 79 def from_config(cls, config): 80 """Instantiates an initializer from a configuration dictionary. 81 82 Example: 83 84 ```python 85 initializer = RandomUniform(-1, 1) 86 config = initializer.get_config() 87 initializer = RandomUniform.from_config(config) 88 ``` 89 90 Args: 91 config: A Python dictionary. It will typically be the output of 92 `get_config`. 93 94 Returns: 95 An Initializer instance. 96 """ 97 return cls(**config) 98 99 100@tf_export(v1=["initializers.zeros", "zeros_initializer"]) 101@deprecation.deprecated_endpoints("initializers.zeros") 102class Zeros(Initializer): 103 """Initializer that generates tensors initialized to 0.""" 104 105 @deprecated_args(None, 106 "Call initializer instance with the dtype argument instead " 107 "of passing it to the constructor", "dtype") 108 def __init__(self, dtype=dtypes.float32): 109 self.dtype = dtypes.as_dtype(dtype) 110 111 def __call__(self, shape, dtype=None, partition_info=None): 112 if dtype is None: 113 dtype = self.dtype 114 return array_ops.zeros(shape, dtype) 115 116 def get_config(self): 117 return {"dtype": self.dtype.name} 118 119 120@tf_export(v1=["initializers.ones", "ones_initializer"]) 121@deprecation.deprecated_endpoints("initializers.ones", "ones_initializer") 122class Ones(Initializer): 123 """Initializer that generates tensors initialized to 1.""" 124 125 @deprecated_args(None, 126 "Call initializer instance with the dtype argument instead " 127 "of passing it to the constructor", "dtype") 128 def __init__(self, dtype=dtypes.float32): 129 self.dtype = dtypes.as_dtype(dtype) 130 131 def __call__(self, shape, dtype=None, partition_info=None): 132 if dtype is None: 133 dtype = self.dtype 134 return array_ops.ones(shape, dtype) 135 136 def get_config(self): 137 return {"dtype": self.dtype.name} 138 139 140@tf_export(v1=["initializers.constant", "constant_initializer"]) 141@deprecation.deprecated_endpoints("constant_initializer") 142class Constant(Initializer): 143 """Initializer that generates tensors with constant values. 144 145 The resulting tensor is populated with values of type `dtype`, as 146 specified by arguments `value` following the desired `shape` of the 147 new tensor (see examples below). 148 149 The argument `value` can be a constant value, or a list of values of type 150 `dtype`. If `value` is a list, then the length of the list must be less 151 than or equal to the number of elements implied by the desired shape of the 152 tensor. In the case where the total number of elements in `value` is less 153 than the number of elements required by the tensor shape, the last element 154 in `value` will be used to fill the remaining entries. If the total number of 155 elements in `value` is greater than the number of elements required by the 156 tensor shape, the initializer will raise a `ValueError`. 157 158 Args: 159 value: A Python scalar, list or tuple of values, or a N-dimensional numpy 160 array. All elements of the initialized variable will be set to the 161 corresponding value in the `value` argument. 162 dtype: Default data type, used if no `dtype` argument is provided when 163 calling the initializer. 164 verify_shape: Boolean that enables verification of the shape of `value`. If 165 `True`, the initializer will throw an error if the shape of `value` is not 166 compatible with the shape of the initialized tensor. 167 168 Raises: 169 TypeError: If the input `value` is not one of the expected types. 170 171 Examples: 172 The following example can be rewritten using a numpy.ndarray instead 173 of the `value` list, even reshaped, as shown in the two commented lines 174 below the `value` list initialization. 175 176 >>> value = [0, 1, 2, 3, 4, 5, 6, 7] 177 >>> init = tf.compat.v1.constant_initializer(value) 178 >>> # fitting shape 179 >>> with tf.compat.v1.Session(): 180 ... x = tf.compat.v1.get_variable('x', shape=[2, 4], initializer=init) 181 ... x.initializer.run() 182 ... print(x.eval()) 183 [[0. 1. 2. 3.] 184 [4. 5. 6. 7.]] 185 >>> # Larger shape 186 >>> with tf.compat.v1.Session(): 187 ... y = tf.compat.v1.get_variable('y', shape=[3, 4], initializer=init) 188 ... y.initializer.run() 189 ... print(y.eval()) 190 [[0. 1. 2. 3.] 191 [4. 5. 6. 7.] 192 [7. 7. 7. 7.]] 193 >>> # Smaller shape 194 >>> with tf.compat.v1.Session(): 195 ... z = tf.compat.v1.get_variable('z', shape=[2, 3], initializer=init) 196 Traceback (most recent call last): 197 ... 198 ValueError: Too many elements provided. Needed at most 6, but received 8 199 >>> # Shape verification 200 >>> init_verify = tf.compat.v1.constant_initializer(value, verify_shape=True) 201 >>> with tf.compat.v1.Session(): 202 ... u = tf.compat.v1.get_variable('u', shape=[3, 4], 203 ... initializer=init_verify) 204 Traceback (most recent call last): 205 ... 206 TypeError: Expected Tensor's shape: (3, 4), got (8,). 207 """ 208 209 @deprecated_args(None, 210 "Call initializer instance with the dtype argument instead " 211 "of passing it to the constructor", "dtype") 212 @deprecated_args(None, "Objects must now be the required shape or no shape " 213 "can be specified", "verify_shape") 214 def __init__(self, value=0, dtype=dtypes.float32, verify_shape=False): 215 if not (np.isscalar(value) or isinstance(value, (list, tuple, np.ndarray))): 216 raise TypeError( 217 "Invalid type for initial value: %s (expected Python scalar, list or " 218 "tuple of values, or numpy.ndarray)." % type(value)) 219 220 self.value = value 221 self.dtype = dtypes.as_dtype(dtype) 222 self._verify_shape = verify_shape 223 224 def __call__(self, shape, dtype=None, partition_info=None, verify_shape=None): 225 if dtype is None: 226 dtype = self.dtype 227 if verify_shape is None: 228 verify_shape = self._verify_shape 229 return constant_op.constant_v1( 230 self.value, dtype=dtype, shape=shape, verify_shape=verify_shape) 231 232 def get_config(self): 233 # We don't include `verify_shape` for compatibility with Keras. 234 # `verify_shape` should be passed as an argument to `__call__` rather 235 # than as a constructor argument: conceptually it isn't a property 236 # of the initializer. 237 return {"value": self.value, "dtype": self.dtype.name} 238 239 240@tf_export(v1=["initializers.random_uniform", "random_uniform_initializer"]) 241@deprecation.deprecated_endpoints("initializers.random_uniform") 242class RandomUniform(Initializer): 243 """Initializer that generates tensors with a uniform distribution. 244 245 Args: 246 minval: A python scalar or a scalar tensor. Lower bound of the range of 247 random values to generate. 248 maxval: A python scalar or a scalar tensor. Upper bound of the range of 249 random values to generate. Defaults to 1 for float types. 250 seed: A Python integer. Used to create random seeds. See 251 `tf.compat.v1.set_random_seed` for behavior. 252 dtype: Default data type, used if no `dtype` argument is provided when 253 calling the initializer. 254 """ 255 256 @deprecated_args(None, 257 "Call initializer instance with the dtype argument instead " 258 "of passing it to the constructor", "dtype") 259 def __init__(self, minval=0, maxval=None, seed=None, dtype=dtypes.float32): 260 self.minval = minval 261 self.maxval = maxval 262 self.seed = seed 263 self.dtype = dtypes.as_dtype(dtype) 264 265 def __call__(self, shape, dtype=None, partition_info=None): 266 if dtype is None: 267 dtype = self.dtype 268 return random_ops.random_uniform( 269 shape, self.minval, self.maxval, dtype, seed=self.seed) 270 271 def get_config(self): 272 return { 273 "minval": self.minval, 274 "maxval": self.maxval, 275 "seed": self.seed, 276 "dtype": self.dtype.name 277 } 278 279 280@tf_export(v1=["initializers.random_normal", "random_normal_initializer"]) 281@deprecation.deprecated_endpoints("initializers.random_normal") 282class RandomNormal(Initializer): 283 """Initializer that generates tensors with a normal distribution. 284 285 Args: 286 mean: a python scalar or a scalar tensor. Mean of the random values to 287 generate. 288 stddev: a python scalar or a scalar tensor. Standard deviation of the random 289 values to generate. 290 seed: A Python integer. Used to create random seeds. See 291 `tf.compat.v1.set_random_seed` for behavior. 292 dtype: Default data type, used if no `dtype` argument is provided when 293 calling the initializer. Only floating point types are supported. 294 """ 295 296 @deprecated_args(None, 297 "Call initializer instance with the dtype argument instead " 298 "of passing it to the constructor", "dtype") 299 def __init__(self, mean=0.0, stddev=1.0, seed=None, dtype=dtypes.float32): 300 self.mean = mean 301 self.stddev = stddev 302 self.seed = seed 303 self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype)) 304 305 def __call__(self, shape, dtype=None, partition_info=None): 306 if dtype is None: 307 dtype = self.dtype 308 return random_ops.random_normal( 309 shape, self.mean, self.stddev, dtype, seed=self.seed) 310 311 def get_config(self): 312 return { 313 "mean": self.mean, 314 "stddev": self.stddev, 315 "seed": self.seed, 316 "dtype": self.dtype.name 317 } 318 319 320@tf_export(v1=["initializers.truncated_normal", "truncated_normal_initializer"]) 321@deprecation.deprecated_endpoints("initializers.truncated_normal", 322 "truncated_normal_initializer") 323class TruncatedNormal(Initializer): 324 """Initializer that generates a truncated normal distribution. 325 326 These values are similar to values from a `random_normal_initializer` 327 except that values more than two standard deviations from the mean 328 are discarded and re-drawn. This is the recommended initializer for 329 neural network weights and filters. 330 331 Args: 332 mean: a python scalar or a scalar tensor. Mean of the random values to 333 generate. 334 stddev: a python scalar or a scalar tensor. Standard deviation of the random 335 values to generate. 336 seed: A Python integer. Used to create random seeds. See 337 `tf.compat.v1.set_random_seed` for behavior. 338 dtype: Default data type, used if no `dtype` argument is provided when 339 calling the initializer. Only floating point types are supported. 340 """ 341 342 @deprecated_args(None, 343 "Call initializer instance with the dtype argument instead " 344 "of passing it to the constructor", "dtype") 345 def __init__(self, mean=0.0, stddev=1.0, seed=None, dtype=dtypes.float32): 346 self.mean = mean 347 self.stddev = stddev 348 self.seed = seed 349 self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype)) 350 351 def __call__(self, shape, dtype=None, partition_info=None): 352 if dtype is None: 353 dtype = self.dtype 354 return random_ops.truncated_normal( 355 shape, self.mean, self.stddev, dtype, seed=self.seed) 356 357 def get_config(self): 358 return { 359 "mean": self.mean, 360 "stddev": self.stddev, 361 "seed": self.seed, 362 "dtype": self.dtype.name 363 } 364 365 366@tf_export(v1=[ 367 "initializers.uniform_unit_scaling", "uniform_unit_scaling_initializer" 368]) 369@deprecation.deprecated_endpoints("uniform_unit_scaling_initializer", 370 "initializers.uniform_unit_scaling") 371class UniformUnitScaling(Initializer): 372 """Initializer that generates tensors without scaling variance. 373 374 When initializing a deep network, it is in principle advantageous to keep 375 the scale of the input variance constant, so it does not explode or diminish 376 by reaching the final layer. If the input is `x` and the operation `x * W`, 377 and we want to initialize `W` uniformly at random, we need to pick `W` from 378 379 [-sqrt(3) / sqrt(dim), sqrt(3) / sqrt(dim)] 380 381 to keep the scale intact, where `dim = W.shape[0]` (the size of the input). 382 A similar calculation for convolutional networks gives an analogous result 383 with `dim` equal to the product of the first 3 dimensions. When 384 nonlinearities are present, we need to multiply this by a constant `factor`. 385 See (Sussillo et al., 2014) for deeper motivation, experiments 386 and the calculation of constants. In section 2.3 there, the constants were 387 numerically computed: for a linear layer it's 1.0, relu: ~1.43, tanh: ~1.15. 388 389 Args: 390 factor: Float. A multiplicative factor by which the values will be scaled. 391 seed: A Python integer. Used to create random seeds. See 392 `tf.compat.v1.set_random_seed` for behavior. 393 dtype: Default data type, used if no `dtype` argument is provided when 394 calling the initializer. Only floating point types are supported. 395 References: 396 [Sussillo et al., 2014](https://arxiv.org/abs/1412.6558) 397 ([pdf](http://arxiv.org/pdf/1412.6558.pdf)) 398 """ 399 400 @deprecated_args(None, 401 "Call initializer instance with the dtype argument instead " 402 "of passing it to the constructor", "dtype") 403 @deprecated(None, 404 "Use tf.initializers.variance_scaling instead with distribution=" 405 "uniform to get equivalent behavior.") 406 def __init__(self, factor=1.0, seed=None, dtype=dtypes.float32): 407 self.factor = factor 408 self.seed = seed 409 self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype)) 410 411 def __call__(self, shape, dtype=None, partition_info=None): 412 if dtype is None: 413 dtype = self.dtype 414 scale_shape = shape 415 if partition_info is not None: 416 scale_shape = partition_info.full_shape 417 418 input_size = 1.0 419 # Estimating input size is not possible to do perfectly, but we try. 420 # The estimate, obtained by multiplying all dimensions but the last one, 421 # is the right thing for matrix multiply and convolutions (see above). 422 for dim in scale_shape[:-1]: 423 input_size *= float(dim) 424 # Avoid errors when initializing zero-size tensors. 425 input_size = max(input_size, 1.0) 426 max_val = math.sqrt(3 / input_size) * self.factor 427 return random_ops.random_uniform( 428 shape, -max_val, max_val, dtype, seed=self.seed) 429 430 def get_config(self): 431 return {"factor": self.factor, "seed": self.seed, "dtype": self.dtype.name} 432 433 434@tf_export(v1=["initializers.variance_scaling", "variance_scaling_initializer"]) 435@deprecation.deprecated_endpoints("initializers.variance_scaling", 436 "variance_scaling_initializer") 437class VarianceScaling(Initializer): 438 """Initializer capable of adapting its scale to the shape of weights tensors. 439 440 With `distribution="truncated_normal" or "untruncated_normal"`, 441 samples are drawn from a truncated/untruncated normal 442 distribution with a mean of zero and a standard deviation (after truncation, 443 if used) `stddev = sqrt(scale / n)` 444 where n is: 445 - number of input units in the weight tensor, if mode = "fan_in" 446 - number of output units, if mode = "fan_out" 447 - average of the numbers of input and output units, if mode = "fan_avg" 448 449 With `distribution="uniform"`, samples are drawn from a uniform distribution 450 within [-limit, limit], with `limit = sqrt(3 * scale / n)`. 451 452 Args: 453 scale: Scaling factor (positive float). 454 mode: One of "fan_in", "fan_out", "fan_avg". 455 distribution: Random distribution to use. One of "normal", "uniform". 456 seed: A Python integer. Used to create random seeds. See 457 `tf.compat.v1.set_random_seed` for behavior. 458 dtype: Default data type, used if no `dtype` argument is provided when 459 calling the initializer. Only floating point types are supported. 460 461 Raises: 462 ValueError: In case of an invalid value for the "scale", mode" or 463 "distribution" arguments. 464 """ 465 466 @deprecated_args(None, 467 "Call initializer instance with the dtype argument instead " 468 "of passing it to the constructor", "dtype") 469 @deprecated_arg_values( 470 None, 471 "`normal` is a deprecated alias for `truncated_normal`", 472 distribution="normal") 473 def __init__(self, 474 scale=1.0, 475 mode="fan_in", 476 distribution="truncated_normal", 477 seed=None, 478 dtype=dtypes.float32): 479 if scale <= 0.: 480 raise ValueError("`scale` must be positive float.") 481 if mode not in {"fan_in", "fan_out", "fan_avg"}: 482 raise ValueError("Invalid `mode` argument:", mode) 483 distribution = distribution.lower() 484 if distribution not in { 485 "normal", "uniform", "truncated_normal", "untruncated_normal" 486 }: 487 raise ValueError("Invalid `distribution` argument:", distribution) 488 self.scale = scale 489 self.mode = mode 490 self.distribution = distribution 491 self.seed = seed 492 self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype)) 493 494 def __call__(self, shape, dtype=None, partition_info=None): 495 if dtype is None: 496 dtype = self.dtype 497 scale = self.scale 498 scale_shape = shape 499 if partition_info is not None: 500 scale_shape = partition_info.full_shape 501 fan_in, fan_out = _compute_fans(scale_shape) 502 if self.mode == "fan_in": 503 scale /= max(1., fan_in) 504 elif self.mode == "fan_out": 505 scale /= max(1., fan_out) 506 else: 507 scale /= max(1., (fan_in + fan_out) / 2.) 508 if self.distribution == "normal" or self.distribution == "truncated_normal": 509 # constant taken from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) 510 stddev = math.sqrt(scale) / .87962566103423978 511 return random_ops.truncated_normal( 512 shape, 0.0, stddev, dtype, seed=self.seed) 513 elif self.distribution == "untruncated_normal": 514 stddev = math.sqrt(scale) 515 return random_ops.random_normal(shape, 0.0, stddev, dtype, seed=self.seed) 516 else: 517 limit = math.sqrt(3.0 * scale) 518 return random_ops.random_uniform( 519 shape, -limit, limit, dtype, seed=self.seed) 520 521 def get_config(self): 522 return { 523 "scale": self.scale, 524 "mode": self.mode, 525 "distribution": self.distribution, 526 "seed": self.seed, 527 "dtype": self.dtype.name 528 } 529 530 531@tf_export(v1=["initializers.orthogonal", "orthogonal_initializer"]) 532@deprecation.deprecated_endpoints("initializers.orthogonal", 533 "orthogonal_initializer") 534class Orthogonal(Initializer): 535 """Initializer that generates an orthogonal matrix. 536 537 If the shape of the tensor to initialize is two-dimensional, it is initialized 538 with an orthogonal matrix obtained from the QR decomposition of a matrix of 539 random numbers drawn from a normal distribution. 540 If the matrix has fewer rows than columns then the output will have orthogonal 541 rows. Otherwise, the output will have orthogonal columns. 542 543 If the shape of the tensor to initialize is more than two-dimensional, 544 a matrix of shape `(shape[0] * ... * shape[n - 2], shape[n - 1])` 545 is initialized, where `n` is the length of the shape vector. 546 The matrix is subsequently reshaped to give a tensor of the desired shape. 547 548 Args: 549 gain: multiplicative factor to apply to the orthogonal matrix 550 seed: A Python integer. Used to create random seeds. See 551 `tf.compat.v1.set_random_seed` for behavior. 552 dtype: Default data type, used if no `dtype` argument is provided when 553 calling the initializer. Only floating point types are supported. 554 References: 555 [Saxe et al., 2014](https://openreview.net/forum?id=_wzZwKpTDF_9C) 556 ([pdf](https://arxiv.org/pdf/1312.6120.pdf)) 557 """ 558 559 @deprecated_args(None, 560 "Call initializer instance with the dtype argument instead " 561 "of passing it to the constructor", "dtype") 562 def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32): 563 self.gain = gain 564 self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype)) 565 self.seed = seed 566 567 def __call__(self, shape, dtype=None, partition_info=None): 568 if dtype is None: 569 dtype = self.dtype 570 # Check the shape 571 if len(shape) < 2: 572 raise ValueError("The tensor to initialize must be " 573 "at least two-dimensional") 574 # Flatten the input shape with the last dimension remaining 575 # its original shape so it works for conv2d 576 num_rows = 1 577 for dim in shape[:-1]: 578 num_rows *= dim 579 num_rows = int(num_rows) 580 num_cols = int(shape[-1]) 581 if num_rows < num_cols: 582 flat_shape = (num_cols, num_rows) 583 else: 584 flat_shape = (num_rows, num_cols) 585 586 # Generate a random matrix 587 a = random_ops.random_normal(flat_shape, dtype=dtype, seed=self.seed) 588 # Compute the qr factorization 589 q, r = gen_linalg_ops.qr(a, full_matrices=False) 590 # Make Q uniform 591 d = array_ops.diag_part(r) 592 q *= math_ops.sign(d) 593 if num_rows < num_cols: 594 q = array_ops.matrix_transpose(q) 595 return self.gain * array_ops.reshape(q, shape) 596 597 def get_config(self): 598 return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name} 599 600 601# Note these haven't been ported to TF2.0. They are not currently visible and 602# the tests are non trivial to port 603class ConvolutionDeltaOrthogonal(Initializer): 604 """Initializer that generates a delta orthogonal kernel for ConvNets. 605 606 The shape of the tensor must have length 3, 4 or 5. The number of input 607 filters must not exceed the number of output filters. The center pixels of the 608 tensor form an orthogonal matrix. Other pixels are set to be zero. See 609 algorithm 2 in (Xiao et al., 2018). 610 611 612 Args: 613 gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1. 614 The 2-norm of an input is multiplied by a factor of `gain` after applying 615 this convolution. 616 seed: A Python integer. Used to create random seeds. See 617 `tf.compat.v1.set_random_seed` for behavior. 618 dtype: Default data type, used if no `dtype` argument is provided when 619 calling the initializer. Only floating point types are supported. 620 References: 621 [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html) 622 ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf)) 623 """ 624 625 def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32): 626 self.gain = gain 627 self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype)) 628 self.seed = seed 629 630 def __call__(self, shape, dtype=None, partition_info=None): 631 if dtype is None: 632 dtype = self.dtype 633 # Check the shape 634 if len(shape) < 3 or len(shape) > 5: 635 raise ValueError("The tensor to initialize must be at least " 636 "three-dimensional and at most five-dimensional") 637 638 if shape[-2] > shape[-1]: 639 raise ValueError("In_filters cannot be greater than out_filters.") 640 641 # Generate a random matrix 642 a = random_ops.random_normal([shape[-1], shape[-1]], 643 dtype=dtype, 644 seed=self.seed) 645 # Compute the qr factorization 646 q, r = gen_linalg_ops.qr(a, full_matrices=False) 647 # Make Q uniform 648 d = array_ops.diag_part(r) 649 q *= math_ops.sign(d) 650 q = q[:shape[-2], :] 651 q *= math_ops.cast(self.gain, dtype=dtype) 652 if len(shape) == 3: 653 weight = array_ops.scatter_nd([[(shape[0] - 1) // 2]], 654 array_ops.expand_dims(q, 0), shape) 655 elif len(shape) == 4: 656 weight = array_ops.scatter_nd([[(shape[0] - 1) // 2, 657 (shape[1] - 1) // 2]], 658 array_ops.expand_dims(q, 0), shape) 659 else: 660 weight = array_ops.scatter_nd([[(shape[0] - 1) // 2, (shape[1] - 1) // 2, 661 (shape[2] - 1) // 2]], 662 array_ops.expand_dims(q, 0), shape) 663 return weight 664 665 def get_config(self): 666 return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name} 667 668 669class ConvolutionOrthogonal(Initializer): 670 """Initializer that generates orthogonal kernel for ConvNets. 671 672 Base class used to construct 1D, 2D and 3D orthogonal kernels for convolution. 673 674 Args: 675 gain: multiplicative factor to apply to the orthogonal matrix. Default is 1. 676 The 2-norm of an input is multiplied by a factor of `gain` after applying 677 this convolution. 678 seed: A Python integer. Used to create random seeds. See 679 `tf.compat.v1.set_random_seed` for behavior. 680 dtype: Default data type, used if no `dtype` argument is provided when 681 calling the initializer. Only floating point types are supported. 682 References: 683 [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html) 684 ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf)) 685 """ 686 687 def __init__(self, gain=1.0, seed=None, dtype=dtypes.float32): 688 self.gain = gain 689 self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype)) 690 self.seed = seed 691 692 def __call__(self, shape, dtype=None, partition_info=None): 693 raise NotImplementedError 694 695 def get_config(self): 696 return {"gain": self.gain, "seed": self.seed, "dtype": self.dtype.name} 697 698 # Helper functions. 699 def _orthogonal_matrix(self, n): 700 """Construct an n x n orthogonal matrix. 701 702 Args: 703 n: Dimension. 704 705 Returns: 706 A n x n orthogonal matrix. 707 """ 708 a = random_ops.random_normal([n, n], dtype=self.dtype, seed=self.seed) 709 if self.seed: 710 self.seed += 1 711 q, r = gen_linalg_ops.qr(a) 712 d = array_ops.diag_part(r) 713 # make q uniform 714 q *= math_ops.sign(d) 715 return q 716 717 def _symmetric_projection(self, n): 718 """Compute a n x n symmetric projection matrix. 719 720 Args: 721 n: Dimension. 722 723 Returns: 724 A n x n symmetric projection matrix, i.e. a matrix P s.t. P=P*P, P=P^T. 725 """ 726 q = self._orthogonal_matrix(n) 727 # randomly zeroing out some columns 728 mask = math_ops.cast( 729 random_ops.random_normal([n], seed=self.seed) > 0, self.dtype) 730 if self.seed: 731 self.seed += 1 732 c = math_ops.multiply(q, mask) 733 return math_ops.matmul(c, array_ops.matrix_transpose(c)) 734 735 736class ConvolutionOrthogonal2D(ConvolutionOrthogonal): 737 """Initializer that generates a 2D orthogonal kernel for ConvNets. 738 739 The shape of the tensor must have length 4. The number of input 740 filters must not exceed the number of output filters. 741 The orthogonality(==isometry) is exact when the inputs are circular padded. 742 There are finite-width effects with non-circular padding (e.g. zero padding). 743 See algorithm 1 in (Xiao et al., 2018). 744 745 Args: 746 gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1. 747 This has the effect of scaling the output 2-norm by a factor of `gain`. 748 seed: A Python integer. Used to create random seeds. See 749 `tf.compat.v1.set_random_seed` for behavior. 750 dtype: Default data type, used if no `dtype` argument is provided when 751 calling the initializer. Only floating point types are supported. 752 References: 753 [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html) 754 ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf)) 755 """ 756 757 def __call__(self, shape, dtype=None, partition_info=None): 758 if dtype is None: 759 dtype = self.dtype 760 if len(shape) != 4: 761 raise ValueError("The tensor to initialize must be four-dimensional") 762 763 if shape[-2] > shape[-1]: 764 raise ValueError("In_filters cannot be greater than out_filters.") 765 766 if shape[0] != shape[1]: 767 raise ValueError("Kernel sizes must be equal.") 768 769 kernel = self._orthogonal_kernel(shape[0], shape[2], shape[3]) 770 kernel *= math_ops.cast(self.gain, dtype=dtype) 771 return kernel 772 773 def _dict_to_tensor(self, x, k1, k2): 774 """Convert a dictionary to a tensor. 775 776 Args: 777 x: A k1 * k2 dictionary. 778 k1: First dimension of x. 779 k2: Second dimension of x. 780 781 Returns: 782 A k1 * k2 tensor. 783 """ 784 785 return array_ops.stack([array_ops.stack([x[i, j] for j in range(k2)]) 786 for i in range(k1)]) 787 788 def _block_orth(self, p1, p2): 789 """Construct a 2 x 2 kernel. 790 791 Used to construct orthgonal kernel. 792 793 Args: 794 p1: A symmetric projection matrix. 795 p2: A symmetric projection matrix. 796 797 Returns: 798 A 2 x 2 kernel [[p1p2, p1(1-p2)], 799 [(1-p1)p2, (1-p1)(1-p2)]]. 800 Raises: 801 ValueError: If the dimensions of p1 and p2 are different. 802 """ 803 if p1.shape.as_list() != p2.shape.as_list(): 804 raise ValueError("The dimension of the matrices must be the same.") 805 n = p1.shape.as_list()[0] 806 kernel2x2 = {} 807 eye = linalg_ops_impl.eye(n, dtype=self.dtype) 808 kernel2x2[0, 0] = math_ops.matmul(p1, p2) 809 kernel2x2[0, 1] = math_ops.matmul(p1, (eye - p2)) 810 kernel2x2[1, 0] = math_ops.matmul((eye - p1), p2) 811 kernel2x2[1, 1] = math_ops.matmul((eye - p1), (eye - p2)) 812 813 return kernel2x2 814 815 def _matrix_conv(self, m1, m2): 816 """Matrix convolution. 817 818 Args: 819 m1: A k x k dictionary, each element is a n x n matrix. 820 m2: A l x l dictionary, each element is a n x n matrix. 821 822 Returns: 823 (k + l - 1) * (k + l - 1) dictionary each element is a n x n matrix. 824 Raises: 825 ValueError: if the entries of m1 and m2 are of different dimensions. 826 """ 827 828 n = (m1[0, 0]).shape.as_list()[0] 829 if n != (m2[0, 0]).shape.as_list()[0]: 830 raise ValueError("The entries in matrices m1 and m2 " 831 "must have the same dimensions!") 832 k = int(np.sqrt(len(m1))) 833 l = int(np.sqrt(len(m2))) 834 result = {} 835 size = k + l - 1 836 # Compute matrix convolution between m1 and m2. 837 for i in range(size): 838 for j in range(size): 839 result[i, j] = array_ops.zeros([n, n], self.dtype) 840 for index1 in range(min(k, i + 1)): 841 for index2 in range(min(k, j + 1)): 842 if (i - index1) < l and (j - index2) < l: 843 result[i, j] += math_ops.matmul(m1[index1, index2], 844 m2[i - index1, j - index2]) 845 return result 846 847 def _orthogonal_kernel(self, ksize, cin, cout): 848 """Construct orthogonal kernel for convolution. 849 850 Args: 851 ksize: Kernel size. 852 cin: Number of input channels. 853 cout: Number of output channels. 854 855 Returns: 856 An [ksize, ksize, cin, cout] orthogonal kernel. 857 Raises: 858 ValueError: If cin > cout. 859 """ 860 if cin > cout: 861 raise ValueError("The number of input channels cannot exceed " 862 "the number of output channels.") 863 orth = self._orthogonal_matrix(cout)[0:cin, :] 864 if ksize == 1: 865 return array_ops.expand_dims(array_ops.expand_dims(orth, 0), 0) 866 867 p = self._block_orth( 868 self._symmetric_projection(cout), self._symmetric_projection(cout)) 869 for _ in range(ksize - 2): 870 temp = self._block_orth( 871 self._symmetric_projection(cout), self._symmetric_projection(cout)) 872 p = self._matrix_conv(p, temp) 873 for i in range(ksize): 874 for j in range(ksize): 875 p[i, j] = math_ops.matmul(orth, p[i, j]) 876 877 return self._dict_to_tensor(p, ksize, ksize) 878 879 880class ConvolutionOrthogonal1D(ConvolutionOrthogonal): 881 """Initializer that generates a 1D orthogonal kernel for ConvNets. 882 883 The shape of the tensor must have length 3. The number of input 884 filters must not exceed the number of output filters. 885 The orthogonality(==isometry) is exact when the inputs are circular padded. 886 There are finite-width effects with non-circular padding (e.g. zero padding). 887 See algorithm 1 in (Xiao et al., 2018). 888 889 Args: 890 gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1. 891 The 2-norm of an input is multiplied by a factor of `gain` after applying 892 this convolution. 893 seed: A Python integer. Used to create random seeds. See 894 `tf.compat.v1.set_random_seed` for behavior. 895 dtype: Default data type, used if no `dtype` argument is provided when 896 calling the initializer. Only floating point types are supported. 897 References: 898 [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html) 899 ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf)) 900 """ 901 902 def __call__(self, shape, dtype=None, partition_info=None): 903 if dtype is None: 904 dtype = self.dtype 905 if len(shape) != 3: 906 raise ValueError("The tensor to initialize must be three-dimensional") 907 908 if shape[-2] > shape[-1]: 909 raise ValueError("In_filters cannot be greater than out_filters.") 910 911 kernel = self._orthogonal_kernel(shape[0], shape[-2], shape[-1]) 912 kernel *= math_ops.cast(self.gain, dtype=dtype) 913 return kernel 914 915 def _dict_to_tensor(self, x, k): 916 """Convert a dictionary to a tensor. 917 918 Args: 919 x: A dictionary of length k. 920 k: Dimension of x. 921 922 Returns: 923 A tensor with the same dimension. 924 """ 925 926 return array_ops.stack([x[i] for i in range(k)]) 927 928 def _block_orth(self, projection_matrix): 929 """Construct a kernel. 930 931 Used to construct orthgonal kernel. 932 933 Args: 934 projection_matrix: A symmetric projection matrix of size n x n. 935 936 Returns: 937 [projection_matrix, (1 - projection_matrix)]. 938 """ 939 n = projection_matrix.shape.as_list()[0] 940 kernel = {} 941 eye = linalg_ops_impl.eye(n, dtype=self.dtype) 942 kernel[0] = projection_matrix 943 kernel[1] = eye - projection_matrix 944 return kernel 945 946 def _matrix_conv(self, m1, m2): 947 """Matrix convolution. 948 949 Args: 950 m1: A dictionary of length k, each element is a n x n matrix. 951 m2: A dictionary of length l, each element is a n x n matrix. 952 953 Returns: 954 (k + l - 1) dictionary each element is a n x n matrix. 955 Raises: 956 ValueError: Ff the entries of m1 and m2 are of different dimensions. 957 """ 958 959 n = (m1[0]).shape.as_list()[0] 960 if n != (m2[0]).shape.as_list()[0]: 961 raise ValueError("The entries in matrices m1 and m2 " 962 "must have the same dimensions!") 963 k = len(m1) 964 l = len(m2) 965 result = {} 966 size = k + l - 1 967 # Compute matrix convolution between m1 and m2. 968 for i in range(size): 969 result[i] = array_ops.zeros([n, n], self.dtype) 970 for index in range(min(k, i + 1)): 971 if (i - index) < l: 972 result[i] += math_ops.matmul(m1[index], m2[i - index]) 973 return result 974 975 def _orthogonal_kernel(self, ksize, cin, cout): 976 """Construct orthogonal kernel for convolution. 977 978 Args: 979 ksize: Kernel size. 980 cin: Number of input channels. 981 cout: Number of output channels. 982 983 Returns: 984 An [ksize, ksize, cin, cout] orthogonal kernel. 985 Raises: 986 ValueError: If cin > cout. 987 """ 988 if cin > cout: 989 raise ValueError("The number of input channels cannot exceed " 990 "the number of output channels.") 991 orth = self._orthogonal_matrix(cout)[0:cin, :] 992 if ksize == 1: 993 return array_ops.expand_dims(orth, 0) 994 995 p = self._block_orth(self._symmetric_projection(cout)) 996 for _ in range(ksize - 2): 997 temp = self._block_orth(self._symmetric_projection(cout)) 998 p = self._matrix_conv(p, temp) 999 for i in range(ksize): 1000 p[i] = math_ops.matmul(orth, p[i]) 1001 1002 return self._dict_to_tensor(p, ksize) 1003 1004 1005class ConvolutionOrthogonal3D(ConvolutionOrthogonal): 1006 """Initializer that generates a 3D orthogonal kernel for ConvNets. 1007 1008 The shape of the tensor must have length 5. The number of input 1009 filters must not exceed the number of output filters. 1010 The orthogonality(==isometry) is exact when the inputs are circular padded. 1011 There are finite-width effects with non-circular padding (e.g. zero padding). 1012 See algorithm 1 (Xiao et al., 2018). 1013 1014 Args: 1015 gain: Multiplicative factor to apply to the orthogonal matrix. Default is 1. 1016 The 2-norm of an input is multiplied by a factor of `gain` after applying 1017 this convolution. 1018 seed: A Python integer. Used to create random seeds. See 1019 `tf.compat.v1.set_random_seed` for behavior. 1020 dtype: Default data type, used if no `dtype` argument is provided when 1021 calling the initializer. Only floating point types are supported. 1022 References: 1023 [Xiao et al., 2018](http://proceedings.mlr.press/v80/xiao18a.html) 1024 ([pdf](http://proceedings.mlr.press/v80/xiao18a/xiao18a.pdf)) 1025 """ 1026 1027 def __call__(self, shape, dtype=None, partition_info=None): 1028 if dtype is None: 1029 dtype = self.dtype 1030 if len(shape) != 5: 1031 raise ValueError("The tensor to initialize must be five-dimensional") 1032 1033 if shape[-2] > shape[-1]: 1034 raise ValueError("In_filters cannot be greater than out_filters.") 1035 1036 if shape[0] != shape[1] or shape[0] != shape[2]: 1037 raise ValueError("Kernel sizes must be equal.") 1038 1039 kernel = self._orthogonal_kernel(shape[0], shape[-2], shape[-1]) 1040 kernel *= math_ops.cast(self.gain, dtype=dtype) 1041 return kernel 1042 1043 def _dict_to_tensor(self, x, k1, k2, k3): 1044 """Convert a dictionary to a tensor. 1045 1046 Args: 1047 x: A k1 * k2 dictionary. 1048 k1: First dimension of x. 1049 k2: Second dimension of x. 1050 k3: Third dimension of x. 1051 1052 Returns: 1053 A k1 * k2 * k3 tensor. 1054 """ 1055 1056 return array_ops.stack([array_ops.stack( 1057 [array_ops.stack([x[i, j, k] for k in range(k3)]) 1058 for j in range(k2)]) for i in range(k1)]) 1059 1060 def _block_orth(self, p1, p2, p3): 1061 """Construct a 3 x 3 kernel. 1062 1063 Used to construct orthgonal kernel. 1064 1065 Args: 1066 p1: A symmetric projection matrix. 1067 p2: A symmetric projection matrix. 1068 p3: A symmetric projection matrix. 1069 1070 Returns: 1071 A 2 x 2 x 2 kernel. 1072 Raises: 1073 ValueError: If the dimensions of p1, p2 and p3 are different. 1074 """ 1075 p1_shape = p1.shape.as_list() 1076 if p1_shape != p2.shape.as_list() or p1_shape != p3.shape.as_list(): 1077 raise ValueError("The dimension of the matrices must be the same.") 1078 n = p1_shape[0] 1079 eye = linalg_ops_impl.eye(n, dtype=self.dtype) 1080 kernel2x2x2 = {} 1081 1082 def matmul(p1, p2, p3): 1083 return math_ops.matmul(math_ops.matmul(p1, p2), p3) 1084 1085 def cast(i, p): 1086 """Return p or (1-p).""" 1087 return i * p + (1 - i) * (eye - p) 1088 1089 for i in [0, 1]: 1090 for j in [0, 1]: 1091 for k in [0, 1]: 1092 kernel2x2x2[i, j, k] = matmul(cast(i, p1), cast(j, p2), cast(k, p3)) 1093 return kernel2x2x2 1094 1095 def _matrix_conv(self, m1, m2): 1096 """Matrix convolution. 1097 1098 Args: 1099 m1: is a k x k x k dictionary, each element is a n x n matrix. 1100 m2: is a l x l x l dictionary, each element is a n x n matrix. 1101 1102 Returns: 1103 (k + l - 1) x (k + l - 1) x (k + l - 1) dictionary each 1104 element is a n x n matrix. 1105 Raises: 1106 ValueError: if the entries of m1 and m2 are of different dimensions. 1107 """ 1108 1109 n = (m1[0, 0, 0]).shape.as_list()[0] 1110 if n != (m2[0, 0, 0]).shape.as_list()[0]: 1111 raise ValueError("The entries in matrices m1 and m2 " 1112 "must have the same dimensions!") 1113 k = int(np.cbrt(len(m1))) 1114 l = int(np.cbrt(len(m2))) 1115 result = {} 1116 size = k + l - 1 1117 # Compute matrix convolution between m1 and m2. 1118 for i in range(size): 1119 for j in range(size): 1120 for r in range(size): 1121 result[i, j, r] = array_ops.zeros([n, n], self.dtype) 1122 for index1 in range(min(k, i + 1)): 1123 for index2 in range(min(k, j + 1)): 1124 for index3 in range(min(k, r + 1)): 1125 if (i - index1) < l and (j - index2) < l and (r - index3) < l: 1126 result[i, j, r] += math_ops.matmul( 1127 m1[index1, index2, index3], 1128 m2[i - index1, j - index2, r - index3]) 1129 return result 1130 1131 def _orthogonal_kernel(self, ksize, cin, cout): 1132 """Construct orthogonal kernel for convolution. 1133 1134 Args: 1135 ksize: Kernel size. 1136 cin: Number of input channels. 1137 cout: Number of output channels. 1138 1139 Returns: 1140 An [ksize, ksize, ksize, cin, cout] orthogonal kernel. 1141 Raises: 1142 ValueError: If cin > cout. 1143 """ 1144 if cin > cout: 1145 raise ValueError("The number of input channels cannot exceed " 1146 "the number of output channels.") 1147 orth = self._orthogonal_matrix(cout)[0:cin, :] 1148 if ksize == 1: 1149 return array_ops.expand_dims( 1150 array_ops.expand_dims(array_ops.expand_dims(orth, 0), 0), 0) 1151 1152 p = self._block_orth( 1153 self._symmetric_projection(cout), self._symmetric_projection(cout), 1154 self._symmetric_projection(cout)) 1155 for _ in range(ksize - 2): 1156 temp = self._block_orth( 1157 self._symmetric_projection(cout), self._symmetric_projection(cout), 1158 self._symmetric_projection(cout)) 1159 p = self._matrix_conv(p, temp) 1160 for i in range(ksize): 1161 for j in range(ksize): 1162 for k in range(ksize): 1163 p[i, j, k] = math_ops.matmul(orth, p[i, j, k]) 1164 1165 return self._dict_to_tensor(p, ksize, ksize, ksize) 1166 1167 1168@tf_export(v1=["initializers.identity"]) 1169@deprecation.deprecated_endpoints("initializers.identity") 1170class Identity(Initializer): 1171 """Initializer that generates the identity matrix. 1172 1173 Only use for 2D matrices. 1174 1175 Args: 1176 gain: Multiplicative factor to apply to the identity matrix. 1177 dtype: Default data type, used if no `dtype` argument is provided when 1178 calling the initializer. Only floating point types are supported. 1179 """ 1180 1181 @deprecated_args(None, 1182 "Call initializer instance with the dtype argument instead " 1183 "of passing it to the constructor", "dtype") 1184 def __init__(self, gain=1.0, dtype=dtypes.float32): 1185 self.gain = gain 1186 self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype)) 1187 1188 def __call__(self, shape, dtype=None, partition_info=None): 1189 full_shape = shape if partition_info is None else partition_info.full_shape 1190 if len(full_shape) != 2: 1191 raise ValueError( 1192 "Identity matrix initializer can only be used for 2D matrices.") 1193 if dtype is None: 1194 dtype = self.dtype 1195 if isinstance(full_shape, tensor_shape.TensorShape): 1196 full_shape = full_shape.as_list() 1197 initializer = linalg_ops_impl.eye(*full_shape, dtype=dtype) 1198 if partition_info is not None: 1199 initializer = array_ops.slice(initializer, partition_info.var_offset, 1200 shape) 1201 return self.gain * initializer 1202 1203 def get_config(self): 1204 return {"gain": self.gain, "dtype": self.dtype.name} 1205 1206 1207@tf_export(v1=["glorot_uniform_initializer", "initializers.glorot_uniform"]) 1208@deprecation.deprecated_endpoints("glorot_uniform_initializer", 1209 "initializers.glorot_uniform") 1210class GlorotUniform(VarianceScaling): 1211 """The Glorot uniform initializer, also called Xavier uniform initializer. 1212 1213 It draws samples from a uniform distribution within [-limit, limit] 1214 where `limit` is `sqrt(6 / (fan_in + fan_out))` 1215 where `fan_in` is the number of input units in the weight tensor 1216 and `fan_out` is the number of output units in the weight tensor. 1217 1218 Args: 1219 seed: A Python integer. Used to create random seeds. See 1220 `tf.compat.v1.set_random_seed` for behavior. 1221 dtype: Default data type, used if no `dtype` argument is provided when 1222 calling the initializer. Only floating point types are supported. 1223 References: 1224 [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html) 1225 ([pdf](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf)) 1226 """ 1227 1228 @deprecated_args(None, 1229 "Call initializer instance with the dtype argument instead " 1230 "of passing it to the constructor", "dtype") 1231 def __init__(self, seed=None, dtype=dtypes.float32): 1232 super(GlorotUniform, self).__init__( 1233 scale=1.0, mode="fan_avg", distribution="uniform", seed=seed) 1234 1235 def get_config(self): 1236 return {"seed": self.seed, "dtype": self.dtype.name} 1237 1238 1239@tf_export(v1=["glorot_normal_initializer", "initializers.glorot_normal"]) 1240@deprecation.deprecated_endpoints("glorot_normal_initializer", 1241 "initializers.glorot_normal") 1242class GlorotNormal(VarianceScaling): 1243 """The Glorot normal initializer, also called Xavier normal initializer. 1244 1245 It draws samples from a truncated normal distribution centered on 0 1246 with standard deviation (after truncation) given by 1247 `stddev = sqrt(2 / (fan_in + fan_out))` where `fan_in` is the number 1248 of input units in the weight tensor and `fan_out` is the number of 1249 output units in the weight tensor. 1250 1251 Args: 1252 seed: A Python integer. Used to create random seeds. See 1253 `tf.compat.v1.set_random_seed` for behavior. 1254 dtype: Default data type, used if no `dtype` argument is provided when 1255 calling the initializer. Only floating point types are supported. 1256 References: 1257 [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html) 1258 ([pdf](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf)) 1259 """ 1260 1261 @deprecated_args(None, 1262 "Call initializer instance with the dtype argument instead " 1263 "of passing it to the constructor", "dtype") 1264 def __init__(self, seed=None, dtype=dtypes.float32): 1265 super(GlorotNormal, self).__init__( 1266 scale=1.0, mode="fan_avg", distribution="truncated_normal", seed=seed) 1267 1268 def get_config(self): 1269 return {"seed": self.seed, "dtype": self.dtype.name} 1270 1271 1272# Aliases. 1273 1274# pylint: disable=invalid-name 1275zeros_initializer = Zeros 1276ones_initializer = Ones 1277constant_initializer = Constant 1278random_uniform_initializer = RandomUniform 1279random_normal_initializer = RandomNormal 1280truncated_normal_initializer = TruncatedNormal 1281uniform_unit_scaling_initializer = UniformUnitScaling 1282variance_scaling_initializer = VarianceScaling 1283glorot_uniform_initializer = GlorotUniform 1284glorot_normal_initializer = GlorotNormal 1285orthogonal_initializer = Orthogonal 1286identity_initializer = Identity 1287convolutional_delta_orthogonal = ConvolutionDeltaOrthogonal 1288convolutional_orthogonal_1d = ConvolutionOrthogonal1D 1289convolutional_orthogonal_2d = ConvolutionOrthogonal2D 1290convolutional_orthogonal_3d = ConvolutionOrthogonal3D 1291# pylint: enable=invalid-name 1292 1293 1294@tf_export(v1=["initializers.lecun_normal"]) 1295def lecun_normal(seed=None): 1296 """LeCun normal initializer. 1297 1298 It draws samples from a truncated normal distribution centered on 0 1299 with standard deviation (after truncation) given by 1300 `stddev = sqrt(1 / fan_in)` where `fan_in` is the number of 1301 input units in the weight tensor. 1302 1303 Args: 1304 seed: A Python integer. Used to seed the random generator. 1305 1306 Returns: 1307 An initializer. 1308 1309 References: 1310 - Self-Normalizing Neural Networks, 1311 [Klambauer et al., 1312 2017](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks) 1313 # pylint: disable=line-too-long 1314 ([pdf](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)) 1315 - Efficient Backprop, 1316 [Lecun et al., 1998](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf) 1317 """ 1318 return VarianceScaling( 1319 scale=1., mode="fan_in", distribution="truncated_normal", seed=seed) 1320 1321 1322@tf_export(v1=["initializers.lecun_uniform"]) 1323def lecun_uniform(seed=None): 1324 """LeCun uniform initializer. 1325 1326 It draws samples from a uniform distribution within [-limit, limit] 1327 where `limit` is `sqrt(3 / fan_in)` 1328 where `fan_in` is the number of input units in the weight tensor. 1329 1330 Args: 1331 seed: A Python integer. Used to seed the random generator. 1332 1333 Returns: 1334 An initializer. 1335 1336 References: 1337 - Self-Normalizing Neural Networks, 1338 [Klambauer et al., 1339 2017](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks) 1340 # pylint: disable=line-too-long 1341 ([pdf](https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)) 1342 - Efficient Backprop, 1343 [Lecun et al., 1998](http://yann.lecun.com/exdb/publis/pdf/lecun-98b.pdf) 1344 """ 1345 return VarianceScaling( 1346 scale=1., mode="fan_in", distribution="uniform", seed=seed) 1347 1348 1349@tf_export(v1=["initializers.he_normal"]) 1350def he_normal(seed=None): 1351 """He normal initializer. 1352 1353 It draws samples from a truncated normal distribution centered on 0 1354 with standard deviation (after truncation) given by 1355 `stddev = sqrt(2 / fan_in)` where `fan_in` is the number of 1356 input units in the weight tensor. 1357 1358 Args: 1359 seed: A Python integer. Used to seed the random generator. 1360 1361 Returns: 1362 An initializer. 1363 1364 References: 1365 [He et al., 2015] 1366 (https://www.cv-foundation.org/openaccess/content_iccv_2015/html/He_Delving_Deep_into_ICCV_2015_paper.html) 1367 # pylint: disable=line-too-long 1368 ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf)) 1369 """ 1370 return VarianceScaling( 1371 scale=2., mode="fan_in", distribution="truncated_normal", seed=seed) 1372 1373 1374@tf_export(v1=["initializers.he_uniform"]) 1375def he_uniform(seed=None): 1376 """He uniform variance scaling initializer. 1377 1378 It draws samples from a uniform distribution within [-limit, limit] 1379 where `limit` is `sqrt(6 / fan_in)` 1380 where `fan_in` is the number of input units in the weight tensor. 1381 1382 Args: 1383 seed: A Python integer. Used to seed the random generator. 1384 1385 Returns: 1386 An initializer. 1387 1388 References: 1389 [He et al., 2015] 1390 (https://www.cv-foundation.org/openaccess/content_iccv_2015/html/He_Delving_Deep_into_ICCV_2015_paper.html) 1391 # pylint: disable=line-too-long 1392 ([pdf](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/He_Delving_Deep_into_ICCV_2015_paper.pdf)) 1393 """ 1394 return VarianceScaling( 1395 scale=2., mode="fan_in", distribution="uniform", seed=seed) 1396 1397 1398# Utility functions. 1399 1400 1401def _compute_fans(shape): 1402 """Computes the number of input and output units for a weight shape. 1403 1404 Args: 1405 shape: Integer shape tuple or TF tensor shape. 1406 1407 Returns: 1408 A tuple of integer scalars (fan_in, fan_out). 1409 """ 1410 if len(shape) < 1: # Just to avoid errors for constants. 1411 fan_in = fan_out = 1 1412 elif len(shape) == 1: 1413 fan_in = fan_out = shape[0] 1414 elif len(shape) == 2: 1415 fan_in = shape[0] 1416 fan_out = shape[1] 1417 else: 1418 # Assuming convolution kernels (2D, 3D, or more). 1419 # kernel shape: (..., input_depth, depth) 1420 receptive_field_size = 1 1421 for dim in shape[:-2]: 1422 receptive_field_size *= dim 1423 fan_in = shape[-2] * receptive_field_size 1424 fan_out = shape[-1] * receptive_field_size 1425 return int(fan_in), int(fan_out) 1426 1427 1428def _assert_float_dtype(dtype): 1429 """Validate and return floating point type based on `dtype`. 1430 1431 `dtype` must be a floating point type. 1432 1433 Args: 1434 dtype: The data type to validate. 1435 1436 Returns: 1437 Validated type. 1438 1439 Raises: 1440 ValueError: if `dtype` is not a floating point type. 1441 """ 1442 if not dtype.is_floating: 1443 raise ValueError("Expected floating point type, got %s." % dtype) 1444 return dtype 1445