• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
16""" test parameter """
17import numpy as np
18import pytest
19
20from mindspore import context, Tensor, Parameter, ParameterTuple, nn
21from mindspore._checkparam import Validator
22from mindspore.common import dtype as mstype
23from mindspore.common.initializer import initializer
24
def test_parameter_init():
    """Construct a Parameter from a Tensor and check the constructor arguments are kept."""
    dat = np.array([[1, 2, 3], [2, 3, 4]])
    tensor = Tensor(dat)
    param = Parameter(tensor, name="testParameter", requires_grad=True, layerwise_parallel=False)
    # Construction alone proves little; also verify the arguments were stored.
    assert param.name == "testParameter"
    assert param.requires_grad
    assert not param.layerwise_parallel
29
30
def test_parameter_tuple_illegal():
    """ParameterTuple accepts a list/tuple of Parameters and raises TypeError for other inputs."""
    step_a = Parameter(initializer(0, [1], mstype.int32), name="global_step1")
    step_b = Parameter(initializer(0, [1], mstype.int32), name="global_step2")

    # Well-formed inputs: a list and a tuple made only of Parameters.
    ParameterTuple([step_a, step_b])
    ParameterTuple((step_a, step_b))

    # Malformed inputs; each must raise TypeError.
    bad_inputs = (
        step_a,           # a single Parameter, not a sequence
        [step_a, "str"],  # a list mixing a Parameter and a str
        ("2", "1"),       # a tuple of strings
        "[2,3]",          # a plain string
        3,                # a plain int
    )
    for bad in bad_inputs:
        with pytest.raises(TypeError):
            ParameterTuple(bad)
53
54
def test_parameter_init_illegal():
    """Check which values Parameter accepts for data, name, requires_grad and layerwise_parallel."""
    dat = np.array([[1, 2, 3], [2, 3, 4]])
    tensor = Tensor(dat)
    data_str = "nicai"
    data_bool = True
    data_int = 3
    data_list = [1, "2", True]
    data_tuple = (1, 2, 3)

    # data: a Tensor, an int and an ndarray are accepted; a bare bool raises ValueError.
    for good_data in (tensor, data_int, dat):
        Parameter(good_data, name=data_str)
    with pytest.raises(ValueError):
        Parameter(data_bool, name=data_str)

    # name: None is accepted; each of these non-str candidates raises ValueError.
    Parameter(tensor, name=None)
    for bad_name in (dat, tensor, data_bool, data_int, data_list, data_tuple):
        with pytest.raises(ValueError):
            Parameter(tensor, name=bad_name)

    # requires_grad: a bool is accepted; each non-bool candidate raises TypeError.
    Parameter(tensor, name=data_str, requires_grad=data_bool)
    for bad_flag in (None, dat, tensor, data_str, data_int, data_list, data_tuple):
        with pytest.raises(TypeError):
            Parameter(tensor, name=data_str, requires_grad=bad_flag)

    # layerwise_parallel: a bool is accepted; each non-bool candidate raises TypeError.
    Parameter(tensor, name=data_str, requires_grad=data_bool, layerwise_parallel=data_bool)
    for bad_flag in (dat, tensor, None, data_str, data_int, data_list, data_tuple):
        with pytest.raises(TypeError):
            Parameter(tensor, name=data_str, requires_grad=data_bool, layerwise_parallel=bad_flag)
118
119
def test_check_str_by_regular():
    """Validator.check_str_by_regular accepts well-formed names and raises ValueError otherwise."""
    accepted = ("12_sf.asdf_", "x12_sf.asdf.", "_x12_sf.asdf")
    rejected = (
        ".12_sf.asdf",   # starts with '.'
        "12_sf.a$sdf.",  # contains '$'
        "12+sf.asdf",    # contains '+'
    )
    for candidate in accepted:
        Validator.check_str_by_regular(candidate)
    for candidate in rejected:
        with pytest.raises(ValueError):
            Validator.check_str_by_regular(candidate)
136
def test_parameter_compute():
    """Parameters support elementwise + and * with other Parameters and with Tensors."""
    lhs = Parameter(initializer('ones', [1, 2, 3], mstype.int32), 'test1')
    rhs = Parameter(initializer('ones', [1, 2, 3], mstype.int32), 'test2')
    ones_tensor = Tensor(np.ones((1, 2, 3)))

    expected_ones = np.ones((1, 2, 3))
    expected_twos = expected_ones * 2

    # (computed result, expected ndarray) pairs, evaluated in the original order.
    checks = (
        (lhs + rhs, expected_twos),
        (lhs * rhs, expected_ones),
        (lhs + ones_tensor, expected_twos),
        (lhs * ones_tensor, expected_ones),
    )
    for result, expected in checks:
        assert np.array_equal(result.asnumpy(), expected)

    # A Parameter must also be usable wherever a Tensor is expected.
    assert isinstance(lhs, Tensor)
156
157
def test_scalar_parameter_update():
    """Check which values Parameter.set_data accepts for scalar- and Tensor-backed Parameters."""
    # Parameters built from Python scalars.
    fp = Parameter(0.5, 'fp')
    fp.set_data(0.8)
    assert np.array_equal(fp.data.asnumpy(), np.array(0.8, np.float32))
    # An int written into a float Parameter is read back as 1.0.
    fp.set_data(1)
    assert np.array_equal(fp.data.asnumpy(), np.array(1.0, np.float32))
    # Give the int Parameter its own name (it previously reused 'fp').
    int_ = Parameter(1, 'int_p')
    int_.set_data(2)
    assert np.array_equal(int_.data.asnumpy(), np.array(2, np.int32))
    # A float cannot be written into an int Parameter.
    with pytest.raises(TypeError):
        int_.set_data(1.2)

    # Tensors of several dtypes used as update sources.
    fp32 = Tensor(0.5, mstype.float32)
    int32 = Tensor(2, mstype.int32)
    fp16 = Tensor(0.6, mstype.float16)
    int16 = Tensor(3, mstype.int16)
    bool_ = Tensor(np.array(True, dtype=np.bool_))

    # update_by_tensor: a float32 Parameter accepts scalars and Tensors of each dtype below.
    fp32_p = Parameter(fp32, 'fp32')
    fp32_p.set_data(0.8)
    assert np.array_equal(fp32_p.data.asnumpy(), np.array(0.8, np.float32))
    fp32_p.set_data(1)
    fp32_p.set_data(int32)
    # The stored value follows the int32 source (compared by value, not dtype).
    assert np.array_equal(fp32_p.data.asnumpy(), np.array(2.0, np.float32))
    fp32_p.set_data(fp32)
    fp32_p.set_data(int16)
    fp32_p.set_data(fp16)
    fp32_p.set_data(bool_)

    # Rejected update: a float16 Parameter cannot take a float32 Tensor.
    fp16_p = Parameter(fp16, 'fp16')
    with pytest.raises(TypeError):
        fp16_p.set_data(fp32)
190
191
def test_parameter_lazy_init():
    """Exercise init_data()/set_data() on initializer-built Parameters under semi_auto_parallel.

    Before init_data() is called, set_data() with a concrete Tensor is expected to
    raise TypeError. After initialization, dtype and shape mismatches are rejected
    unless slice_shape=True is passed.
    """
    # Lazy init is supported in SEMI_AUTO_PARALLEL mode.
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8)
    try:
        # Call init_data() without calling set_data() first.
        para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test1')
        assert isinstance(para.data, Tensor)
        para = para.init_data()
        assert isinstance(para.data, Tensor)
        assert np.array_equal(para.data.asnumpy(), np.ones((1, 2, 3)))

        # Same flow with a complex dtype.
        para = Parameter(initializer('ones', [1, 2, 3], mstype.complex64), 'test1')
        assert isinstance(para.data, Tensor)
        para = para.init_data()
        assert isinstance(para.data, Tensor)
        assert np.array_equal(para.data.asnumpy(), np.ones((1, 2, 3)))

        # Call init_data() after trying set_data().
        para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test2')
        assert isinstance(para.data, Tensor)
        # Expect TypeError while the Parameter is not initialized.
        with pytest.raises(TypeError):
            para.set_data(Tensor(np.zeros((1, 2, 3))))
        # Initialize, then assign.
        para = para.init_data()
        # Dtype check: np.zeros defaults to float64, which the float32 Parameter rejects.
        with pytest.raises(TypeError):
            para.set_data(Tensor(np.zeros((1, 2, 3))))
        # Shape check.
        with pytest.raises(ValueError):
            para.set_data(Tensor(np.zeros((1, 2))))
        # Matching dtype and shape: the update succeeds.
        para.set_data(Tensor(np.zeros((1, 2, 3)).astype(np.float32)))
        assert np.array_equal(para.data.asnumpy(), np.zeros((1, 2, 3)))
        para.set_data(initializer('ones', [1, 2, 3], mstype.float32))
        assert isinstance(para.data, Tensor)
        # Same object, already initialized.
        assert np.array_equal(para.data.asnumpy(), np.ones((1, 2, 3)))
        # A second init_data() is expected to have no effect.
        para.init_data()
        assert np.array_equal(para.data.asnumpy(), np.ones((1, 2, 3)))
        # slice_shape=True allows replacing the data with a different shape.
        para.set_data(Tensor(np.zeros((1, 2)).astype(np.float32)), slice_shape=True)
        assert np.array_equal(para.data.asnumpy(), np.zeros((1, 2)))
        para.set_data(initializer('ones', [1, 2], mstype.float32), slice_shape=True)
        assert np.array_equal(para.data.asnumpy(), np.ones((1, 2)))
    finally:
        # Restore the global auto-parallel context even when an assertion fails,
        # so semi_auto_parallel mode cannot leak into later tests.
        context.reset_auto_parallel_context()
238
239
def test_parameter_as_output():
    """A Parameter that is initialized via init_data() and then updated can be a Cell's output."""
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    try:
        initial_input = initializer('One', shape=(2,), dtype=mstype.int32)
        updated_input = Tensor([2, 2], mstype.int32)

        class Net(nn.Cell):
            """Cell that returns a Parameter whose data was overwritten in __init__."""

            def __init__(self, initial, updated):
                super().__init__()
                self.initial = initial
                self.updated = updated
                self.p = Parameter(self.initial, name="weight")
                # Initialize the Parameter's data, then overwrite it with `updated`.
                self.new_p = self.p.init_data()
                self.new_p.set_data(self.updated)

            def construct(self):
                return self.new_p

        net = Net(initial_input, updated_input)
        output = net()
        assert np.array_equal(output.asnumpy(), np.array([2, 2], np.int32))
    finally:
        # Restore the global auto-parallel context even when the test fails,
        # so semi_auto_parallel mode cannot leak into later tests.
        context.reset_auto_parallel_context()
260