• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test_initializer """
16import math
17from functools import reduce
18import numpy as np
19import pytest as py
20from scipy import stats
21
22import mindspore as ms
23import mindspore.common.initializer as init
24import mindspore.nn as nn
25from mindspore import context
26from mindspore.common.parameter import Parameter
27from mindspore.common.tensor import Tensor
28from mindspore.nn import Conv2d
29from mindspore.ops import operations as P
30from ..ut_filter import non_graph_engine
31
32
33# pylint: disable=W0212
34# W0212: protected-access
35
class InitTwo(init.Initializer):
    """Test-only initializer that sets every element of the array to 2."""

    def _initialize(self, arr):
        """Write the constant 2 into ``arr`` via ``init._assignment``."""
        fill_value = 2
        init._assignment(arr, fill_value)
41
42
def _check_value(tensor, value_min, value_max):
    """Verify that every element of ``tensor`` lies in ``[value_min, value_max]``.

    Args:
        tensor: Object whose ``asnumpy()`` method returns a numpy array.
        value_min: Inclusive lower bound.
        value_max: Inclusive upper bound.

    Raises:
        ValueError: For the first element (in flattened order) that is not
            inside the closed interval.
    """
    for ele in tensor.asnumpy().flatten():
        if not value_min <= ele <= value_max:
            # Use %s rather than %d: %d truncates floating-point values
            # (0.25 would be reported as 0), which hides the offending value.
            raise ValueError('value_min = %s, ele = %s, value_max = %s'
                             % (value_min, ele, value_max))
50
51
def _check_uniform(tensor, boundary_a, boundary_b):
    """Run a Kolmogorov-Smirnov test of ``tensor`` against a uniform distribution.

    Args:
        tensor: Object whose ``asnumpy()`` method returns a numpy array.
        boundary_a: Lower end of the reference uniform distribution.
        boundary_b: Upper end of the reference uniform distribution.

    Returns:
        True when the p-value of the test exceeds 0.0001.
    """
    flat_samples = tensor.asnumpy().reshape((-1))
    # scipy's 'uniform' is parameterised as (loc, scale), i.e. start and width.
    loc = boundary_a
    width = boundary_b - boundary_a
    p_value = stats.kstest(flat_samples, 'uniform', (loc, width)).pvalue
    print("p-value is %f" % p_value)
    return p_value > 0.0001
57
58
def test_init_Initializer():
    """A user-defined Initializer subclass gives the requested shape and fill value."""
    expected_fill = 2
    weight = init.initializer(InitTwo(), [2, 2], ms.int32)
    assert weight.shape == (2, 2)
    materialized = weight.init_data()
    _check_value(materialized, expected_fill, expected_fill)
63
64
def test_init_tensor():
    """Passing an existing Tensor as the init argument keeps the requested shape."""
    dims = (1, 2, 3)
    source = ms.Tensor(np.zeros(dims))
    result = init.initializer(source, list(dims), ms.float32)
    assert result.shape == dims
69
70
def test_init_zero_default_dtype():
    """Zero() without an explicit dtype produces float32 zeros."""
    weight = init.initializer(init.Zero(), [2, 2])
    assert weight.dtype == ms.float32
    materialized = weight.init_data()
    _check_value(materialized, 0, 0)
75
76
def test_init_zero():
    """Zero() with an explicit float32 dtype fills the tensor with zeros."""
    zero_init = init.Zero()
    weight = init.initializer(zero_init, [2, 2], ms.float32)
    _check_value(weight.init_data(), 0, 0)
80
81
def test_init_zero_alias_default_dtype():
    """The 'zeros' alias without an explicit dtype produces float32 zeros."""
    weight = init.initializer('zeros', [1, 2])
    assert weight.dtype == ms.float32
    materialized = weight.init_data()
    _check_value(materialized, 0, 0)
86
87
def test_init_zero_alias():
    """The 'zeros' alias with an explicit float32 dtype fills the tensor with zeros."""
    weight = init.initializer('zeros', [1, 2], ms.float32)
    materialized = weight.init_data()
    _check_value(materialized, 0, 0)
91
92
def test_init_one():
    """One() fills the tensor with ones."""
    one_init = init.One()
    weight = init.initializer(one_init, [2, 2], ms.float32)
    _check_value(weight.init_data(), 1, 1)
96
97
def test_init_one_alias():
    """The 'ones' alias fills the tensor with ones."""
    weight = init.initializer('ones', [1, 2], ms.float32)
    materialized = weight.init_data()
    _check_value(materialized, 1, 1)
101
102
def test_init_constant():
    """Constant(1) fills the tensor with the given constant."""
    constant = 1
    weight = init.initializer(init.Constant(constant), [2, 2], ms.float32)
    _check_value(weight.init_data(), constant, constant)
106
107
def test_init_uniform():
    """Uniform(scale=s) only produces values inside [-s, s]."""
    bound = 10
    sampler = init.Uniform(scale=bound)
    weight = init.initializer(sampler, [5, 4], ms.float32)
    _check_value(weight.init_data(), -bound, bound)
112
113
def test_init_uniform_alias():
    """Values produced through the 'uniform' alias stay within [-100, 100]."""
    # NOTE(review): the alias is created without a scale, so this bound is
    # presumably much wider than the initializer's default range -- confirm.
    bound = 100
    weight = init.initializer('uniform', [5, 4], ms.float32)
    materialized = weight.init_data()
    _check_value(materialized, -bound, bound)
118
119
def test_init_normal():
    """initializer() with a Normal() instance returns a Tensor."""
    normal_init = init.Normal()
    result = init.initializer(normal_init, [5, 4], ms.float32)
    assert isinstance(result, Tensor), 'Normal init failed!'
123
124
def test_init_truncated_normal():
    """initializer() with a TruncatedNormal() instance returns a Tensor."""
    truncated_init = init.TruncatedNormal()
    result = init.initializer(truncated_init, [5, 4], ms.float32)
    assert isinstance(result, Tensor), 'TruncatedNormal init failed!'
128
129
def test_init_normal_alias():
    """initializer() with the 'normal' alias returns a Tensor."""
    alias = 'normal'
    result = init.initializer(alias, [5, 4], ms.float32)
    assert isinstance(result, Tensor), 'Normal init failed!'
133
134
def test_init_truncatednormal_alias():
    """initializer() with the 'truncatednormal' alias returns a Tensor."""
    alias = 'truncatednormal'
    result = init.initializer(alias, [5, 4], ms.float32)
    assert isinstance(result, Tensor), 'TruncatedNormal init failed!'
138
139
def test_init_abnormal():
    """A list is not an accepted init argument; a TypeError is expected."""
    unsupported_init = ['']
    with py.raises(TypeError):
        init.initializer(unsupported_init, [5, 4], ms.float32)
143
144
def test_initializer_reinit():
    """initializer() with the "XavierUniform" name and float16 returns a Tensor."""
    kernel_shape = (10, 1, 10, 10)
    result = init.initializer("XavierUniform", shape=kernel_shape, dtype=ms.float16)
    assert isinstance(result, Tensor), 'XavierUniform init failed!'
148
149
def test_init_xavier_uniform():
    """Check XavierUniform samples against the expected uniform range.

    For every case the expected bound is
    ``gain * sqrt(2 / (n_in + n_out)) * sqrt(3)``, where ``n_in`` and
    ``n_out`` are ``shape[1]`` and ``shape[0]`` scaled by the product of the
    trailing dimensions. Each materialized tensor must pass a KS test
    against ``uniform(-bound, bound)``.
    """
    gain = 1.2
    # (init argument, shape, gain used to compute the expected bound).
    # Kept as an ordered list rather than a dict keyed by tensors, so the
    # test does not depend on Tensor objects being usable as dict keys.
    # Cases created without an explicit gain are checked against gain 1.
    cases = [
        (init.XavierUniform(gain=gain), [20, 22], gain),
        (init.XavierUniform(), [20, 22], 1),
        (init.XavierUniform(gain=gain), [20, 22, 5, 5], gain),
        (init.XavierUniform(), [20, 22, 5, 5], 1),
        ('xavier_uniform', [20, 22, 5, 5], 1),
        ('xavier_uniform', [20, 22], 1),
    ]

    for init_arg, dims, gain_value in cases:
        tensor = init.initializer(init_arg, dims, ms.float32).init_data()
        shape = tensor.asnumpy().shape
        # Product of the dimensions after the first two; the initial value 1
        # covers 2-D shapes, where that slice is empty.
        s = reduce(lambda x, y: x * y, shape[2:], 1)
        n_in = shape[1] * s
        n_out = shape[0] * s
        std = gain_value * math.sqrt(2 / (n_in + n_out))
        boundary = std * math.sqrt(3)
        assert _check_uniform(tensor, -boundary, boundary)
174
175
def test_init_xavier_uniform_error():
    """XavierUniform on a one-dimensional shape is expected to raise ValueError."""
    one_dim_shape = [6]
    with py.raises(ValueError):
        lazy_weight = init.initializer(init.XavierUniform(), one_dim_shape, ms.float32)
        lazy_weight.init_data()
179
180
def test_init_he_uniform():
    """Check HeUniform samples against uniform(-b, b) with b = sqrt(2 / n_in) * sqrt(3)."""
    lazy_weights = (
        init.initializer(init.HeUniform(), [20, 22], ms.float32),
        init.initializer(init.HeUniform(), [20, 22, 5, 5], ms.float32),
        init.initializer('he_uniform', [20, 22, 5, 5], ms.float32),
        init.initializer('he_uniform', [20, 22], ms.float32),
    )
    materialized = [weight.init_data() for weight in lazy_weights]

    for weight in materialized:
        dims = weight.asnumpy().shape
        # Product of the dimensions after the first two (1 for 2-D shapes).
        trailing = reduce(lambda x, y: x * y, dims[2:]) if len(dims) > 2 else 1
        fan_in = dims[1] * trailing
        deviation = math.sqrt(2 / fan_in)
        limit = deviation * math.sqrt(3)
        assert _check_uniform(weight, -limit, limit)
199
200
def test_init_he_uniform_error():
    """HeUniform on a one-dimensional shape is expected to raise ValueError."""
    one_dim_shape = [6]
    with py.raises(ValueError):
        lazy_weight = init.initializer(init.HeUniform(), one_dim_shape, ms.float32)
        lazy_weight.init_data()
204
205
def test_conv2d_abnormal_kernel_negative():
    """Building a Conv2d model with a negative kernel_size is expected to raise ValueError."""
    weights = np.random.randn(64, 3, 7, 7).astype(np.float32)
    conv_kwargs = dict(in_channels=3, out_channels=64, kernel_size=-7, stride=3, padding=0)
    with py.raises(ValueError):
        layer = Conv2d(weight_init=ms.Tensor(weights), **conv_kwargs)
        ms.Model(layer)
212
213
@non_graph_engine
def test_conv2d_abnormal_kernel_normal():
    """Run predict on a Conv2d whose weights come from a random-normal numpy kernel."""
    weights = np.random.randn(64, 3, 7, 7).astype(np.float32)
    batch = np.random.randn(32, 3, 224, 112).astype(np.float32)
    context.set_context(mode=context.GRAPH_MODE)
    layer = Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=3,
                   padding=0, weight_init=ms.Tensor(weights))
    network = ms.Model(layer)
    network.predict(ms.Tensor(batch))
223
224
@non_graph_engine
def test_conv2d_abnormal_kernel_truncated_normal():
    """Run predict on a Conv2d initialised through the "truncatednormal" alias."""
    # The model input itself is also drawn from a TruncatedNormal initializer.
    lazy_input = init.initializer(init.TruncatedNormal(), [64, 3, 7, 7], ms.float32)
    batch = lazy_input.init_data()
    context.set_context(mode=context.GRAPH_MODE)
    layer = Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=3,
                   padding=0, weight_init="truncatednormal")
    network = ms.Model(layer)
    network.predict(batch)
233
234
class Net(nn.Cell):
    """Cell that adds two [5, 4] float32 parameters, "w1" and "w2", to its input."""

    def __init__(self):
        super().__init__()
        self.add = P.Add()
        uniform_weight = init.initializer('uniform', [5, 4], ms.float32)
        self.t1 = Parameter(uniform_weight, name="w1")
        truncated_weight = init.initializer(init.TruncatedNormal(), [5, 4], ms.float32)
        self.t2 = Parameter(truncated_weight, name="w2")

    def construct(self, x):
        # x + t1 first, then + t2.
        with_first = self.add(x, self.t1)
        return self.add(with_first, self.t2)
246
247
def test_weight_shape():
    """Run Net in graph mode on a (5, 4) float32 input and print the result."""
    context.set_context(mode=context.GRAPH_MODE)
    values = np.arange(20).reshape(5, 4)
    net_input = Tensor(values, dtype=ms.float32)
    result = Net()(net_input)
    print(result)
255