• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import pytest
16import numpy as np
17
18import mindspore as ms
19from mindspore.nn import Cell
20from mindspore.common.parameter import Parameter
21from mindspore.common import ParameterTuple
22from mindspore import Tensor, context
23
24
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_parameter_1_1(mode):
    """
    Feature: Check the names of parameters and the names of inputs of construct.
    Description: construct takes an input named "name_a", which is also the name of the
        Parameter self.param_a. The clash must be resolved on the input side (a suffix is
        added to the input's name), not by renaming the Parameter.
    Expectation: No exception; the output is 1 + 2 - 3 == 0 and both Parameters keep the
        names given in __init__.
    """

    class ParamNet(Cell):
        def __init__(self):
            super().__init__()
            self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
            self.param_b = Parameter(Tensor([2], ms.float32), name="name_b")

        def construct(self, name_a):
            # The input name deliberately matches the name of self.param_a.
            return self.param_a + self.param_b - name_a

    context.set_context(mode=mode)
    network = ParamNet()
    output = network(Tensor([3], ms.float32))
    assert output == 0
    expected_names = {"param_a": "name_a", "param_b": "name_b"}
    for attr, name in expected_names.items():
        assert getattr(network, attr).name == name
51
52
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_parameter_1_2(mode):
    """
    Feature: Check the names of parameters and the names of inputs of construct.
    Description: construct takes an input named "name_b", which is also the name of the
        first Parameter inside the ParameterTuple self.param_b. The clash must be resolved
        on the input side (a suffix is added to the input's name).
    Expectation: No exception; the output is 1 + 2 - 3 == 0 and every Parameter, including
        those inside the ParameterTuple, keeps the name given in __init__.
    """

    class ParamNet(Cell):
        def __init__(self):
            super().__init__()
            self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
            tuple_head = Parameter(Tensor([2], ms.float32), name="name_b")
            # self.param_a is shared between the attribute and the tuple.
            self.param_b = ParameterTuple((tuple_head, self.param_a))

        def construct(self, name_b):
            # The input name deliberately matches the name of self.param_b[0].
            return self.param_a + self.param_b[0] - name_b

    context.set_context(mode=mode)
    network = ParamNet()
    output = network(Tensor([3], ms.float32))
    assert output == 0
    assert network.param_a.name == "name_a"
    assert [p.name for p in network.param_b] == ["name_b", "name_a"]
80
81
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_parameter_3(mode):
    """
    Feature: Check the names of parameters.
    Description: Parameters created in __init__ without an explicit name are assigned to
        Cell attributes; they should be named after those attributes.
    Expectation: No exception; the output is 1 + 2 == 3 and the Parameters are named
        "param_a" and "param_b".
    """

    class ParamNet(Cell):
        def __init__(self):
            super().__init__()
            # No name= argument: the attribute name is expected to become the name.
            self.param_a = Parameter(Tensor([1], ms.float32))
            self.param_b = Parameter(Tensor([2], ms.float32))

        def construct(self):
            return self.param_a + self.param_b

    context.set_context(mode=mode)
    network = ParamNet()
    output = network()
    assert output == 3
    for attr in ("param_a", "param_b"):
        assert getattr(network, attr).name == attr
108
109
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_parameter_5_2(mode):
    """
    Feature: Check the names of parameters.
    Description: self.param_a ("name_a") is captured in a ParameterTuple next to an unnamed
        Parameter, and the attribute is then rebound to a new Parameter ("name_b").
    Expectation: No exception; the output is 3 + 2 + (2 + 3) == 10, the attribute reports
        the new Parameter's name, the tuple still holds the original "name_a" Parameter,
        and the unnamed tuple member is given the name "Parameter$1".
    """

    class ParamNet(Cell):
        def __init__(self):
            super().__init__()
            # Statement order matters here: the tuple must capture the first param_a
            # before the attribute is rebound.
            self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
            self.res1 = ParameterTuple((Parameter(Tensor([2], ms.float32)), self.param_a))
            self.param_a = Parameter(Tensor([3], ms.float32), name="name_b")
            self.res2 = self.res1[0] + self.param_a

        def construct(self):
            return self.param_a + self.res1[0] + self.res2

    context.set_context(mode=mode)
    network = ParamNet()
    output = network()
    assert output == 10
    assert network.param_a.name == "name_b"
    unnamed_member, original_param_a = network.res1
    assert unnamed_member.name == "Parameter$1"
    assert original_param_a.name == "name_a"
139
140
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_parameter_list_tuple_no_name(mode):
    """
    Feature: Check the names of parameters.
    Description: Unnamed Parameters are stored in a plain tuple attribute and a plain list
        attribute of the Cell.
    Expectation: No exception; the output is 5 + 6 + 7 + 8 == 26 and the four Parameters
        are named "Parameter$1" .. "Parameter$4" in the order tuple[0], tuple[1], list[0],
        list[1].
    """

    class ParamNet(Cell):
        def __init__(self):
            super().__init__()
            self.param_tuple = (Parameter(Tensor([5], ms.float32)), Parameter(Tensor([6], ms.float32)))
            self.param_list = [Parameter(Tensor([7], ms.float32)), Parameter(Tensor([8], ms.float32))]

        def construct(self):
            return self.param_tuple[0] + self.param_tuple[1] + self.param_list[0] + self.param_list[1]

    context.set_context(mode=mode)
    network = ParamNet()
    output = network()
    assert output == 26
    observed = [p.name for p in (*network.param_tuple, *network.param_list)]
    assert observed == [f"Parameter${index}" for index in range(1, 5)]
169
170
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_parameter_in_tuple(mode):
    """
    Feature: Check the names of parameters.
    Description: Two explicitly named Parameters are Cell attributes and are also collected
        into a ParameterTuple attribute.
    Expectation: No exception; the output is 1 + 2 + 1 + 2 == 6 and the Parameters report
        their explicit names both as attributes and as tuple members.
    """

    class ParamNet(Cell):
        def __init__(self):
            super().__init__()
            self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
            self.param_b = Parameter(Tensor([2], ms.float32), name="name_b")
            self.param_tuple = ParameterTuple((self.param_a, self.param_b))

        def construct(self):
            return self.param_a + self.param_b + self.param_tuple[0] + self.param_tuple[1]

    context.set_context(mode=mode)
    network = ParamNet()
    output = network()
    assert output == 6
    expected = ["name_a", "name_b"]
    assert [network.param_a.name, network.param_b.name] == expected
    assert [p.name for p in network.param_tuple] == expected
200
201
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_parameter_parameter_tuple_2(mode):
    """
    Feature: Check the names of parameters.
    Description: The same named Parameter is an attribute of the Cell and appears three
        times in a ParameterTuple attribute.
    Expectation: No exception; the output is 1 + 1 + 1 + 1 == 4 and every reference to the
        Parameter reports the name "name_a".
    """

    class ParamNet(Cell):
        def __init__(self):
            super().__init__()
            self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
            # Three references to one and the same Parameter object.
            self.param_tuple = ParameterTuple((self.param_a,) * 3)

        def construct(self):
            return self.param_a + self.param_tuple[0] + self.param_tuple[1] + self.param_tuple[2]

    context.set_context(mode=mode)
    network = ParamNet()
    output = network()
    assert output == 4
    assert network.param_a.name == "name_a"
    assert [p.name for p in network.param_tuple] == ["name_a"] * 3
230
231
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_parameter(mode):
    """
    Feature: Check the names of parameters.
    Description: A Cell holds explicitly named Parameters, Parameters named only by their
        attribute, and unnamed Parameters inside a plain tuple and a plain list. Unnamed
        Parameters in a list or tuple must be given unique generated names.
    Expectation: No exception; the outputs are 21 and 11, explicit names are kept,
        attribute-only Parameters are named after their attribute, and the tuple/list
        members are named "Parameter$1" .. "Parameter$4".
    """

    class ParamNet(Cell):
        def __init__(self):
            super().__init__()
            self.param_a = Parameter(Tensor([1], ms.float32), name="name_a")
            self.param_b = Parameter(Tensor([2], ms.float32), name="name_b")
            self.param_c = Parameter(Tensor([3], ms.float32))
            self.param_d = Parameter(Tensor([4], ms.float32))
            # Not read in construct, but its members must still get generated names.
            self.param_tuple = (Parameter(Tensor([5], ms.float32)),
                                Parameter(Tensor([6], ms.float32)))
            self.param_list = [Parameter(Tensor([5], ms.float32)),
                               Parameter(Tensor([6], ms.float32))]

        def construct(self, x):
            res1 = self.param_a + self.param_b + self.param_c + self.param_d
            res1 = res1 - self.param_list[0] + self.param_list[1] + x
            res2 = self.param_list[0] + self.param_list[1]
            return res1, res2

    context.set_context(mode=mode)
    network = ParamNet()
    first, second = network(Tensor([10], ms.float32))
    assert first == Tensor(21, ms.float32)
    assert second == Tensor(11, ms.float32)
    expected_names = {
        "param_a": "name_a",
        "param_b": "name_b",
        "param_c": "param_c",
        "param_d": "param_d",
    }
    for attr, name in expected_names.items():
        assert getattr(network, attr).name == name
    assert [p.name for p in network.param_tuple] == ["Parameter$1", "Parameter$2"]
    assert [p.name for p in network.param_list] == ["Parameter$3", "Parameter$4"]
277
278
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_parameter_argument_and_fv(mode):
    """
    Feature: Parameter argument in top func graph.
    Description: One Parameter (x) is passed to construct as an input argument and another
        (y) is captured from the enclosing scope as a free variable (FV); construct assigns
        0 to both with Assign.
    Expectation: The Parameter used as an argument is updated exactly like the one used as
        an FV: both end up equal to 0.
    """
    y = Parameter(Tensor([1]))

    class Demo(Cell):
        def construct(self, x):
            ms.ops.Assign()(x, Tensor([0]))
            ms.ops.Assign()(y, Tensor([0]))
            return True

    context.set_context(mode=mode)
    x = Parameter(Tensor([1]))
    net = Demo()
    net(x)
    # `x == y` alone would also hold if neither Assign took effect (both would still be 1),
    # so pin the expected value of each Parameter as well.
    assert np.array_equal(x.asnumpy(), np.array([0]))
    assert np.array_equal(y.asnumpy(), np.array([0]))
    assert x == y
304
305
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_parameter_argument_grad(mode):
    """
    Feature: Parameter argument in top func graph.
    Description: Pass Parameters as input arguments of construct and update them in place
        with Assign. Run the net once directly, then run it on an identical second set of
        Parameters under GradOperation(get_by_list=True).
    Expectation: The direct run applies the expected in-place updates, and running under
        GradOperation updates its Parameter arguments in exactly the same way.
    """

    class ParameterArgumentCell(Cell):
        def __init__(self):
            super(ParameterArgumentCell, self).__init__()
            self.z = Parameter(Tensor(np.array([[1.0, 4.0], [-1, 8.0]]), ms.float32), name='z')

        def construct(self, param, x, y):
            # Order matters: param is rebound first, then x, then y reads the new param.
            ms.ops.Assign()(param, x * self.z)
            ms.ops.Assign()(x, x + y)
            ms.ops.Assign()(y, param)
            return param

    context.set_context(mode=mode)
    param = Parameter(Tensor(np.array([[0, 0], [0, 0]]), ms.float32), name='param')
    x = Parameter(Tensor(np.array([[4.0, -8.0], [-2.0, -5.0]]), ms.float32), name='x')
    y = Parameter(Tensor(np.array([[1, 0], [1, 1]]), ms.float32), name='y')
    net = ParameterArgumentCell()
    net(param, x, y)

    # Without these, the comparisons below would also pass if no Assign had any effect.
    # param = x * z; x = x + y (original y); y = the updated param.
    expected_param = np.array([[4.0, -32.0], [2.0, -40.0]], np.float32)
    expected_x = np.array([[5.0, -8.0], [-1.0, -4.0]], np.float32)
    assert np.array_equal(param.asnumpy(), expected_param)
    assert np.array_equal(x.asnumpy(), expected_x)
    assert np.array_equal(y.asnumpy(), expected_param)

    bparam = Parameter(Tensor(np.array([[0, 0], [0, 0]]), ms.float32), name='bparam')
    bx = Parameter(Tensor(np.array([[4.0, -8.0], [-2.0, -5.0]]), ms.float32), name='bx')
    by = Parameter(Tensor(np.array([[1, 0], [1, 1]]), ms.float32), name='by')
    grad_by_list = ms.ops.GradOperation(get_by_list=True)
    grad_by_list(net, ParameterTuple(net.trainable_params()))(bparam, bx, by)

    # The run under GradOperation must leave its arguments in the same state as the direct run.
    assert np.array_equal(param.asnumpy(), bparam.asnumpy())
    assert np.array_equal(x.asnumpy(), bx.asnumpy())
    assert np.array_equal(y.asnumpy(), by.asnumpy())
344