# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_initializer """
import math
from functools import reduce
import numpy as np
import pytest as py
from scipy import stats

import mindspore as ms
import mindspore.common.initializer as init
import mindspore.nn as nn
from mindspore import context
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.nn import Conv2d
from mindspore.ops import operations as P
from ..ut_filter import non_graph_engine


# pylint: disable=W0212
# W0212: protected-access

class InitTwo(init.Initializer):
    """Initialize the array to two."""

    def _initialize(self, arr):
        init._assignment(arr, 2)


def _check_value(tensor, value_min, value_max):
    """Raise ValueError unless every element of `tensor` lies in [value_min, value_max].

    Args:
        tensor: object exposing `asnumpy()` (e.g. the result of `init_data()`).
        value_min: inclusive lower bound.
        value_max: inclusive upper bound.

    Raises:
        ValueError: reporting the first out-of-range element in flattened
            (row-major) order.
    """
    nd = tensor.asnumpy()
    # Negate the in-range test instead of testing `< min | > max` so that NaN
    # elements still count as out of range.
    out_of_range = nd[~((nd >= value_min) & (nd <= value_max))]
    if out_of_range.size:
        # `%s` rather than `%d`: float values must not be truncated in the report.
        raise ValueError('value_min = %s, ele = %s, value_max = %s'
                         % (value_min, out_of_range[0], value_max))


def _check_uniform(tensor, boundary_a, boundary_b):
    """Return True if `tensor` is plausibly uniform on [boundary_a, boundary_b].

    Runs a one-sample Kolmogorov-Smirnov test against scipy's 'uniform'
    distribution (loc=boundary_a, scale=boundary_b - boundary_a). The tensor
    is accepted unless the p-value is at or below 1e-4.
    """
    samples = tensor.asnumpy().reshape(-1)
    _, p = stats.kstest(samples, 'uniform', (boundary_a, boundary_b - boundary_a))
    print("p-value is %f" % p)
    return p > 0.0001


def test_init_Initializer():
    tensor = init.initializer(InitTwo(), [2, 2], ms.int32)
    assert tensor.shape == (2, 2)
    _check_value(tensor.init_data(), 2, 2)


def test_init_tensor():
    tensor = ms.Tensor(np.zeros([1, 2, 3]))
    tensor = init.initializer(tensor, [1, 2, 3], ms.float32)
    assert tensor.shape == (1, 2, 3)


def test_init_zero_default_dtype():
    tensor = init.initializer(init.Zero(), [2, 2])
    assert tensor.dtype == ms.float32
    _check_value(tensor.init_data(), 0, 0)


def test_init_zero():
    tensor = init.initializer(init.Zero(), [2, 2], ms.float32)
    _check_value(tensor.init_data(), 0, 0)


def test_init_zero_alias_default_dtype():
    tensor = init.initializer('zeros', [1, 2])
    assert tensor.dtype == ms.float32
    _check_value(tensor.init_data(), 0, 0)


def test_init_zero_alias():
    tensor = init.initializer('zeros', [1, 2], ms.float32)
    _check_value(tensor.init_data(), 0, 0)


def test_init_one():
    tensor = init.initializer(init.One(), [2, 2], ms.float32)
    _check_value(tensor.init_data(), 1, 1)


def test_init_one_alias():
    tensor = init.initializer('ones', [1, 2], ms.float32)
    _check_value(tensor.init_data(), 1, 1)


def test_init_constant():
    tensor = init.initializer(init.Constant(1), [2, 2], ms.float32)
    _check_value(tensor.init_data(), 1, 1)


def test_init_uniform():
    scale = 10
    tensor = init.initializer(init.Uniform(scale=scale), [5, 4], ms.float32)
    _check_value(tensor.init_data(), -scale, scale)


def test_init_uniform_alias():
    # NOTE(review): 100 is a very loose bound. The 'uniform' alias presumably
    # uses Uniform's default scale; confirm that value and tighten this check.
    scale = 100
    tensor = init.initializer('uniform', [5, 4], ms.float32)
    _check_value(tensor.init_data(), -scale, scale)


def test_init_normal():
    tensor = init.initializer(init.Normal(), [5, 4], ms.float32)
    assert isinstance(tensor, Tensor), 'Normal init failed!'


def test_init_truncated_normal():
    tensor = init.initializer(init.TruncatedNormal(), [5, 4], ms.float32)
    assert isinstance(tensor, Tensor), 'TruncatedNormal init failed!'


def test_init_normal_alias():
    tensor = init.initializer('normal', [5, 4], ms.float32)
    assert isinstance(tensor, Tensor), 'Normal init failed!'


def test_init_truncatednormal_alias():
    tensor = init.initializer('truncatednormal', [5, 4], ms.float32)
    assert isinstance(tensor, Tensor), 'TruncatedNormal init failed!'


def test_init_abnormal():
    with py.raises(TypeError):
        init.initializer([''], [5, 4], ms.float32)


def test_initializer_reinit():
    weights = init.initializer("XavierUniform", shape=(10, 1, 10, 10), dtype=ms.float16)
    assert isinstance(weights, Tensor), 'XavierUniform init failed!'


def test_init_xavier_uniform():
    """Check XavierUniform samples pass a KS test on +/- gain * sqrt(6 / (fan_in + fan_out))."""
    gain = 1.2
    # (tensor, expected gain) pairs. This is a list, not a dict keyed by Tensor,
    # so the test does not depend on Tensor hashing. Initializers built without
    # an explicit gain are checked against a gain of 1.
    cases = [
        (init.initializer(init.XavierUniform(gain=gain), [20, 22], ms.float32).init_data(), gain),
        (init.initializer(init.XavierUniform(), [20, 22], ms.float32).init_data(), 1),
        (init.initializer(init.XavierUniform(gain=gain), [20, 22, 5, 5], ms.float32).init_data(), gain),
        (init.initializer(init.XavierUniform(), [20, 22, 5, 5], ms.float32).init_data(), 1),
        (init.initializer('xavier_uniform', [20, 22, 5, 5], ms.float32).init_data(), 1),
        (init.initializer('xavier_uniform', [20, 22], ms.float32).init_data(), 1),
    ]

    for tensor, gain_value in cases:
        shape = tensor.asnumpy().shape
        # Product of trailing (spatial) dims; the initial value 1 covers 2-D weights.
        s = reduce(lambda x, y: x * y, shape[2:], 1)
        n_in = shape[1] * s
        n_out = shape[0] * s
        std = gain_value * math.sqrt(2 / (n_in + n_out))
        boundary = std * math.sqrt(3)
        assert _check_uniform(tensor, -boundary, boundary)


def test_init_xavier_uniform_error():
    with py.raises(ValueError):
        init.initializer(init.XavierUniform(), [6], ms.float32).init_data()

def test_init_he_uniform():
    """HeUniform samples must pass a KS test on +/- sqrt(3) * sqrt(2 / fan_in)."""
    weights = [
        init.initializer(init.HeUniform(), [20, 22], ms.float32),
        init.initializer(init.HeUniform(), [20, 22, 5, 5], ms.float32),
        init.initializer('he_uniform', [20, 22, 5, 5], ms.float32),
        init.initializer('he_uniform', [20, 22], ms.float32),
    ]
    # Materialize every weight only after all four initializers exist.
    materialized = [w.init_data() for w in weights]

    for data in materialized:
        dims = data.asnumpy().shape
        # Trailing-dimension product; starting from 1 handles plain 2-D weights.
        receptive = reduce(lambda a, b: a * b, dims[2:], 1)
        fan_in = dims[1] * receptive
        bound = math.sqrt(2 / fan_in) * math.sqrt(3)
        assert _check_uniform(data, -bound, bound)


def test_init_he_uniform_error():
    with py.raises(ValueError):
        lazy = init.initializer(init.HeUniform(), [6], ms.float32)
        lazy.init_data()


def test_conv2d_abnormal_kernel_negative():
    weight = np.random.randn(64, 3, 7, 7).astype(np.float32)
    with py.raises(ValueError):
        conv = Conv2d(in_channels=3, out_channels=64, kernel_size=-7, stride=3,
                      padding=0, weight_init=ms.Tensor(weight))
        ms.Model(conv)


@non_graph_engine
def test_conv2d_abnormal_kernel_normal():
    weight = np.random.randn(64, 3, 7, 7).astype(np.float32)
    images = np.random.randn(32, 3, 224, 112).astype(np.float32)
    context.set_context(mode=context.GRAPH_MODE)
    conv = Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=3,
                  padding=0, weight_init=ms.Tensor(weight))
    ms.Model(conv).predict(ms.Tensor(images))


@non_graph_engine
def test_conv2d_abnormal_kernel_truncated_normal():
    images = init.initializer(init.TruncatedNormal(), [64, 3, 7, 7], ms.float32).init_data()
    context.set_context(mode=context.GRAPH_MODE)
    conv = Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=3,
                  padding=0, weight_init="truncatednormal")
    ms.Model(conv).predict(images)


class Net(nn.Cell):
    """Add two initializer-backed [5, 4] float32 parameters to the input."""

    def __init__(self):
        super().__init__()
        self.add = P.Add()
        self.t1 = Parameter(init.initializer('uniform', [5, 4], ms.float32), name="w1")
        self.t2 = Parameter(init.initializer(init.TruncatedNormal(), [5, 4], ms.float32), name="w2")

    def construct(self, x):
        z = self.add(x, self.t1)
        z = self.add(z, self.t2)
        return z


def test_weight_shape():
    """Run Net on a (5, 4) float32 input and check the output keeps that shape."""
    context.set_context(mode=context.GRAPH_MODE)
    a = np.arange(20).reshape(5, 4)
    t = Tensor(a, dtype=ms.float32)
    net = Net()
    out = net(t)
    # Assert on the result instead of only printing it, so a shape mismatch
    # between the input and the initializer-created parameters fails the test.
    assert out.shape == (5, 4)