# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for initializers in init_ops."""

import numpy as np

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape as tensor_shape_lib
from tensorflow.python.framework import test_util
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class InitializersTest(test.TestCase):
  """Statistical and placement checks for the initializers in `init_ops`.

  Most tests build an initializer, evaluate it for a given shape (passed both
  as a plain tuple and as a `TensorShape`), and compare summary statistics of
  the produced values against expected targets via `_runner`.
  """

  def _runner(self,
              init,
              shape,
              target_mean=None,
              target_std=None,
              target_max=None,
              target_min=None):
    """Evaluates `init(shape)` and checks its shape and summary statistics.

    Args:
      init: Callable invoked as `init(shape)`; its result is passed to
        `self.evaluate`.
      shape: Shape handed to `init`; the evaluated output's `.shape` must
        compare equal to it.
      target_mean: If not None, expected value of `output.mean()`.
      target_std: If not None, expected value of `output.std()`.
      target_max: If not None, expected value of `output.max()`.
      target_min: If not None, expected value of `output.min()`.

    Each provided target must be matched with an absolute error strictly
    below `lim` (3e-2); targets left as None are not checked.
    """
    output = self.evaluate(init(shape))
    self.assertEqual(output.shape, shape)
    # Absolute tolerance shared by every statistic comparison below.
    lim = 3e-2
    if target_std is not None:
      self.assertGreater(lim, abs(output.std() - target_std))
    if target_mean is not None:
      self.assertGreater(lim, abs(output.mean() - target_mean))
    if target_max is not None:
      self.assertGreater(lim, abs(output.max() - target_max))
    if target_min is not None:
      self.assertGreater(lim, abs(output.min() - target_min))

  def test_uniform(self):
    """RandomUniform(-1, 1): checks mean near 0 and extrema near -1 and 1."""
    shape = (9, 6, 99)
    with self.cached_session():
      for tensor_shape in [shape, tensor_shape_lib.TensorShape(shape)]:
        self._runner(
            init_ops.RandomUniform(minval=-1, maxval=1, seed=124),
            tensor_shape,
            target_mean=0.,
            target_max=1,
            target_min=-1)

  def test_normal(self):
    """RandomNormal(mean=0, stddev=1): checks mean near 0 and std near 1."""
    shape = (8, 12, 99)
    with self.cached_session():
      for tensor_shape in [shape, tensor_shape_lib.TensorShape(shape)]:
        self._runner(
            init_ops.RandomNormal(mean=0, stddev=1, seed=153),
            tensor_shape,
            target_mean=0.,
            target_std=1)

  def test_truncated_normal(self):
    """TruncatedNormal(mean=0, stddev=1): checks mean ~0, extrema ~-2 and ~2."""
    shape = (12, 99, 7)
    with self.cached_session():
      for tensor_shape in [shape, tensor_shape_lib.TensorShape(shape)]:
        self._runner(
            init_ops.TruncatedNormal(mean=0, stddev=1, seed=126),
            tensor_shape,
            target_mean=0.,
            target_max=2,
            target_min=-2)

  def test_constant(self):
    """Constant(2): checks that mean, max and min are all near 2."""
    shape = (5, 6, 4)
    with self.cached_session():
      for tensor_shape in [shape, tensor_shape_lib.TensorShape(shape)]:
        self._runner(
            init_ops.Constant(2),
            tensor_shape,
            target_mean=2,
            target_max=2,
            target_min=2)

  def test_lecun_uniform(self):
    """lecun_uniform: checks mean near 0 and std near sqrt(1 / fan_in)."""
    shape = (5, 6, 4, 2)
    with self.cached_session():
      for tensor_shape in [shape, tensor_shape_lib.TensorShape(shape)]:
        fan_in, _ = init_ops._compute_fans(tensor_shape)
        std = np.sqrt(1. / fan_in)
        self._runner(
            init_ops.lecun_uniform(seed=123),
            tensor_shape,
            target_mean=0.,
            target_std=std)

  def test_glorot_uniform_initializer(self):
    """glorot_uniform: checks mean ~0 and std ~sqrt(2 / (fan_in + fan_out))."""
    shape = (5, 6, 4, 2)
    with self.cached_session():
      for tensor_shape in [shape, tensor_shape_lib.TensorShape(shape)]:
        fan_in, fan_out = init_ops._compute_fans(tensor_shape)
        std = np.sqrt(2. / (fan_in + fan_out))
        self._runner(
            init_ops.glorot_uniform_initializer(seed=123),
            tensor_shape,
            target_mean=0.,
            target_std=std)

  def test_he_uniform(self):
    """he_uniform: checks mean near 0 and std near sqrt(2 / fan_in)."""
    shape = (5, 6, 4, 2)
    with self.cached_session():
      for tensor_shape in [shape, tensor_shape_lib.TensorShape(shape)]:
        fan_in, _ = init_ops._compute_fans(tensor_shape)
        std = np.sqrt(2. / fan_in)
        self._runner(
            init_ops.he_uniform(seed=123),
            tensor_shape,
            target_mean=0.,
            target_std=std)

  def test_lecun_normal(self):
    """lecun_normal: checks mean near 0 and std near sqrt(1 / fan_in)."""
    shape = (5, 6, 4, 2)
    with self.cached_session():
      for tensor_shape in [shape, tensor_shape_lib.TensorShape(shape)]:
        fan_in, _ = init_ops._compute_fans(tensor_shape)
        std = np.sqrt(1. / fan_in)
        self._runner(
            init_ops.lecun_normal(seed=123),
            tensor_shape,
            target_mean=0.,
            target_std=std)

  def test_glorot_normal_initializer(self):
    """glorot_normal: checks mean ~0 and std ~sqrt(2 / (fan_in + fan_out))."""
    shape = (5, 6, 4, 2)
    with self.cached_session():
      for tensor_shape in [shape, tensor_shape_lib.TensorShape(shape)]:
        fan_in, fan_out = init_ops._compute_fans(tensor_shape)
        std = np.sqrt(2. / (fan_in + fan_out))
        self._runner(
            init_ops.glorot_normal_initializer(seed=123),
            tensor_shape,
            target_mean=0.,
            target_std=std)

  def test_he_normal(self):
    """he_normal: checks mean near 0 and std near sqrt(2 / fan_in)."""
    shape = (5, 6, 4, 2)
    with self.cached_session():
      for tensor_shape in [shape, tensor_shape_lib.TensorShape(shape)]:
        fan_in, _ = init_ops._compute_fans(tensor_shape)
        std = np.sqrt(2. / fan_in)
        self._runner(
            init_ops.he_normal(seed=123),
            tensor_shape,
            target_mean=0.,
            target_std=std)

  def test_Orthogonal(self):
    """Orthogonal on a 20x20 shape: checks only that the mean is near 0."""
    shape = (20, 20)
    with self.cached_session():
      for tensor_shape in [shape, tensor_shape_lib.TensorShape(shape)]:
        self._runner(
            init_ops.Orthogonal(seed=123), tensor_shape, target_mean=0.)

  @test.disable_with_predicate(
      pred=test.is_built_with_rocm,
      skip_message='Disable subtest on ROCm due to missing QR op support')
  @test_util.run_gpu_only
  def testVariablePlacementWithOrthogonalInitializer(self):
    """Initializes GPU-scoped variables in a session without soft placement.

    Two variables (Orthogonal- and RandomNormal-initialized) are created under
    an explicit `gpu:0` device scope in a fresh graph; the test passes if the
    global initializer runs without error.
    """
    with ops.Graph().as_default() as g:
      with ops.device('gpu:0'):
        variable_scope.get_variable(
            name='v', shape=[8, 2], initializer=init_ops.Orthogonal)
        variable_scope.get_variable(
            name='w', shape=[8, 2], initializer=init_ops.RandomNormal)
      run_metadata = config_pb2.RunMetadata()
      run_options = config_pb2.RunOptions(
          trace_level=config_pb2.RunOptions.FULL_TRACE)
      config = config_pb2.ConfigProto(
          allow_soft_placement=False, log_device_placement=True)

      # Note: allow_soft_placement=False will fail whenever we cannot satisfy
      # the colocation constraints.
      with session.Session(config=config, graph=g) as sess:
        sess.run(
            variables.global_variables_initializer(),
            options=run_options,
            run_metadata=run_metadata)

  @test_util.run_gpu_only
  def test_eager_orthogonal_gpu(self):
    """In eager mode, checks both variables' handle devices mention 'GPU'."""
    with context.eager_mode():
      v = variable_scope.get_variable(
          name='v', shape=[8, 2], initializer=init_ops.Orthogonal)
      w = variable_scope.get_variable(
          name='w', shape=[8, 2], initializer=init_ops.RandomNormal)
      # assertIn reports the actual device string on failure, unlike
      # assertTrue('GPU' in ...), which only reports "False is not true".
      self.assertIn('GPU', v.handle.device)
      self.assertIn('GPU', w.handle.device)

  def test_Identity(self):
    """Identity: expects ValueError for a 3-D shape, then checks a 3x3 shape."""
    with self.cached_session():
      # A rank-3 shape is expected to be rejected with ValueError.
      shape = (3, 4, 5)
      for tensor_shape in [shape, tensor_shape_lib.TensorShape(shape)]:
        with self.assertRaises(ValueError):
          self._runner(
              init_ops.Identity(),
              tensor_shape,
              target_mean=1. / int(tensor_shape[0]),
              target_max=1.)

      # For the square shape, the expected mean is 1 / shape[0] and max is 1.
      shape = (3, 3)
      for tensor_shape in [shape, tensor_shape_lib.TensorShape(shape)]:
        self._runner(
            init_ops.Identity(),
            tensor_shape,
            target_mean=1. / int(tensor_shape[0]),
            target_max=1.)

  def test_Zeros(self):
    """Zeros: checks that mean and max are both near 0."""
    shape = (4, 5)
    with self.cached_session():
      for tensor_shape in [shape, tensor_shape_lib.TensorShape(shape)]:
        self._runner(
            init_ops.Zeros(), tensor_shape, target_mean=0., target_max=0.)

  def test_Ones(self):
    """Ones: checks that mean and max are both near 1."""
    shape = (4, 5)
    with self.cached_session():
      for tensor_shape in [shape, tensor_shape_lib.TensorShape(shape)]:
        self._runner(
            init_ops.Ones(), tensor_shape, target_mean=1., target_max=1.)


if __name__ == '__main__':
  test.main()