# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Weight initializers for use with layers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import random_ops


__all__ = ['xavier_initializer', 'xavier_initializer_conv2d',
           'variance_scaling_initializer']


def xavier_initializer(uniform=True, seed=None, dtype=dtypes.float32):
  """Returns an initializer performing "Xavier" initialization for weights.

  This function implements the weight initialization from:

  Xavier Glorot and Yoshua Bengio (2010):
      [Understanding the difficulty of training deep feedforward neural
      networks. International conference on artificial intelligence and
      statistics.](
      http://www.jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf)

  This initializer is designed to keep the scale of the gradients roughly the
  same in all layers. With a uniform distribution, samples are drawn from the
  range `[-x, x]` with `x = sqrt(6. / (in + out))`; with a normal distribution,
  a standard deviation of `sqrt(2. / (in + out))` is used.

  Args:
    uniform: Whether to use uniform or normally distributed random
      initialization.
    seed: A Python integer. Used to create random seeds. See
      `tf.set_random_seed` for behavior.
    dtype: The data type. Only floating point types are supported.

  Returns:
    An initializer for a weight matrix.
  """
  return variance_scaling_initializer(factor=1.0, mode='FAN_AVG',
                                      uniform=uniform, seed=seed, dtype=dtype)

xavier_initializer_conv2d = xavier_initializer
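
# Example (an illustrative sketch, not part of the original module): the
# returned initializer is typically passed to a variable factory. Assuming
# this module is exposed as `tf.contrib.layers`:
#
#   weights = tf.get_variable(
#       'weights', shape=[784, 256],
#       initializer=tf.contrib.layers.xavier_initializer())
#
# With the default `uniform=True`, values are drawn from `[-x, x]` with
# `x = sqrt(6. / (784 + 256))`, i.e. stddev `sqrt(2. / (784 + 256))`.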


def variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False,
                                 seed=None, dtype=dtypes.float32):
  """Returns an initializer that generates tensors without scaling variance.

  When initializing a deep network, it is in principle advantageous to keep
  the scale of the input variance constant, so that it does not explode or
  diminish by the time it reaches the final layer. This initializer uses the
  following formula:

  ```python
    if mode == 'FAN_IN':     # Count only number of input connections.
      n = fan_in
    elif mode == 'FAN_OUT':  # Count only number of output connections.
      n = fan_out
    elif mode == 'FAN_AVG':  # Average number of inputs and output connections.
      n = (fan_in + fan_out) / 2.0

    truncated_normal(shape, 0.0, stddev=sqrt(factor / n))
  ```

  * To get [Delving Deep into Rectifiers](
    http://arxiv.org/pdf/1502.01852v1.pdf) (also known as the "MSRA
    initialization"), use (Default):<br/>
    `factor=2.0 mode='FAN_IN' uniform=False`
  * To get [Convolutional Architecture for Fast Feature Embedding](
    http://arxiv.org/abs/1408.5093), use:<br/>
    `factor=1.0 mode='FAN_IN' uniform=True`
  * To get [Understanding the difficulty of training deep feedforward neural
    networks](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf),
    use:<br/>
    `factor=1.0 mode='FAN_AVG' uniform=True`.
  * To get `xavier_initializer` use either:<br/>
    `factor=1.0 mode='FAN_AVG' uniform=True`, or<br/>
    `factor=1.0 mode='FAN_AVG' uniform=False`.

  Args:
    factor: Float. A multiplicative factor.
    mode: String. 'FAN_IN', 'FAN_OUT', 'FAN_AVG'.
    uniform: Whether to use uniform or normally distributed random
      initialization.
    seed: A Python integer. Used to create random seeds. See
      `tf.set_random_seed` for behavior.
    dtype: The data type. Only floating point types are supported.

  Returns:
    An initializer that generates tensors with stddev `sqrt(factor / n)`.

  Raises:
    TypeError: if `dtype` is not a floating point type, or if `mode` is not
      one of ['FAN_IN', 'FAN_OUT', 'FAN_AVG'].
  """
  if not dtype.is_floating:
    raise TypeError('Cannot create initializer for non-floating point type.')
  if mode not in ['FAN_IN', 'FAN_OUT', 'FAN_AVG']:
    raise TypeError('Unknown mode %s [FAN_IN, FAN_OUT, FAN_AVG]' % mode)

  # pylint: disable=unused-argument
  def _initializer(shape, dtype=dtype, partition_info=None):
    """Initializer function."""
    if not dtype.is_floating:
      raise TypeError('Cannot create initializer for non-floating point type.')
    # Estimating fan_in and fan_out is not possible to do perfectly, but we
    # try. This is the right thing for matrix multiply and convolutions.
    if shape:
      fan_in = float(shape[-2]) if len(shape) > 1 else float(shape[-1])
      fan_out = float(shape[-1])
    else:
      fan_in = 1.0
      fan_out = 1.0
    for dim in shape[:-2]:
      fan_in *= float(dim)
      fan_out *= float(dim)
    if mode == 'FAN_IN':
      # Count only number of input connections.
      n = fan_in
    elif mode == 'FAN_OUT':
      # Count only number of output connections.
      n = fan_out
    elif mode == 'FAN_AVG':
      # Average number of inputs and output connections.
      n = (fan_in + fan_out) / 2.0
    if uniform:
      # A uniform distribution on [-limit, limit] has stddev limit / sqrt(3),
      # so limit = sqrt(3 * factor / n) yields stddev = sqrt(factor / n).
      limit = math.sqrt(3.0 * factor / n)
      return random_ops.random_uniform(shape, -limit, limit,
                                       dtype, seed=seed)
    else:
      # Truncating at 2 stddevs shrinks the stddev by a factor of ~0.88, so
      # scale the variance by ~1.3 to recover stddev = sqrt(factor / n).
      trunc_stddev = math.sqrt(1.3 * factor / n)
      return random_ops.truncated_normal(shape, 0.0, trunc_stddev, dtype,
                                         seed=seed)
  # pylint: enable=unused-argument

  return _initializer
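
# A minimal usage sketch (an illustrative addition, not part of the original
# module): the returned closure can be called directly with a shape to build
# an initial-value tensor, or passed to `tf.get_variable` as `initializer=`.
if __name__ == '__main__':
  init = variance_scaling_initializer(factor=2.0, mode='FAN_IN')
  # For a conv-style [height, width, in_channels, out_channels] shape,
  # fan_in = 5 * 5 * 3 = 75, so this draws from a truncated normal with
  # base stddev sqrt(1.3 * 2.0 / 75) (effective stddev ~ sqrt(2.0 / 75)).
  initial_value = init([5, 5, 3, 64])
  print(initial_value)  # A float32 Tensor of shape [5, 5, 3, 64].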