1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
import numpy as np
import pytest

import mindspore as ms
import mindspore.ops.composite as C
from mindspore import context
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common.parameter import Parameter, ParameterTuple
from tests.security_utils import security_off_wrap

# Gradient w.r.t. both the network inputs (get_all) and an explicit parameter
# list (get_by_list): calling grad_all_list(net, params)(x, y) yields
# (input_grads, param_grads).
grad_all_list = C.GradOperation(get_all=True, get_by_list=True)
# Gradient w.r.t. only the given parameter list (inputs excluded).
grad_by_list = C.GradOperation(get_by_list=True)

# Every test in this file exercises graph-mode compilation (side-effect /
# auto-monad handling), so force GRAPH_MODE for the whole module.
context.set_context(mode=context.GRAPH_MODE)
32
33
def test_load_grad():
    """Grad of a net that only reads (Load) a Parameter, with no assignment.

    Smoke test: compiling and running the grad net must succeed in graph mode.
    """
    class LoadNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            # Pure read of self.z — no parameter mutation anywhere.
            x = x * y * self.z
            return x

    x = Tensor(np.array([2.0], np.float32))
    y = Tensor(np.array([3.0], np.float32))
    load_net = LoadNet()
    # Gradients w.r.t. the inputs (x, y) and the trainable parameter z.
    grad_net = grad_all_list(
        load_net, ParameterTuple(load_net.trainable_params()))
    print(grad_net(x, y))
50
51
def test_assign_only_grad():
    """Grad of a net whose construct assigns to a Parameter as a pure side effect.

    The assigned value of self.z is never read afterwards; compilation of the
    grad graph must still succeed.
    """
    class AssignOnlyNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, x, y):
            # Side-effecting parameter assignment; z is not used below.
            self.z = x
            x = x * y
            return x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            # NOTE: trainable_params() is called on the wrapper (self), which
            # gathers the inner net's parameter z through the sub-cell.
            self.parameter_tuple = ParameterTuple(self.trainable_params())

        def construct(self, x, y):
            return grad_all_list(self.net, self.parameter_tuple)(x, y)

    assign_net = AssignOnlyNet()
    net = GradNet(assign_net)
    x = Tensor(np.array([2.0], np.float32))
    y = Tensor(np.array([3.0], np.float32))
    print(net(x, y))
77
78
def test_load_assign_grad():
    """Grad of a net mixing a Load of a Parameter with an explicit P.Assign.

    self.z is read, overwritten via Assign, then read again, so the grad graph
    must order the Load/Assign side effects correctly.
    """
    class AssignNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
            self.assign = P.Assign()

        def construct(self, x, y):
            x = x * self.z           # read z (old value)
            self.assign(self.z, x)   # overwrite z with the product
            out = y * self.z         # read z again (new value)
            return out

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.parameter_tuple = ParameterTuple(net.trainable_params())

        def construct(self, x, y):
            return grad_all_list(self.net, self.parameter_tuple)(x, y)

    assign_net = AssignNet()
    net = GradNet(assign_net)
    x = Tensor(np.array([2.0], np.float32))
    y = Tensor(np.array([3.0], np.float32))
    print(net(x, y))
106
107
def test_insert_gradient_of():
    """Grad of a net using InsertGradientOf with a side-effecting hook.

    save_gradient mutates self.cov_step during the backward pass; the hook is
    installed twice on the same value to stress repeated application.
    """
    class InsertGradientNet(nn.Cell):
        def __init__(self):
            super(InsertGradientNet, self).__init__()
            self.gather = P.GatherV2()
            self.damping = Tensor(np.array([0.03, 0.03], np.float32))
            self.cov_step = Parameter(0, name="cov_step", requires_grad=False)
            self.freq = Tensor(278, ms.int32)
            # Runs save_gradient on the gradient flowing through getG.
            self.getG = P.InsertGradientOf(self.save_gradient)

        def save_gradient(self, dout):
            # Side effect in the backward graph: bump the step counter.
            self.cov_step = self.cov_step + self.freq
            return dout

        def construct(self, x):
            # Result deliberately discarded — exercises an isolated node that
            # still reads the mutable cov_step parameter.
            self.gather(self.damping, self.cov_step, 0)
            out = P.ReLU()(x)
            out = self.getG(out)
            out = self.getG(out)
            return out

    net = InsertGradientNet()
    input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype(np.float32)
    grad_net = grad_all_list(net, ParameterTuple(net.trainable_params()))
    print(grad_net(Tensor(input_data)))
133
134
@security_off_wrap
def test_user_defined_bprop():
    """Grad of a Cell with a user-defined bprop that calls P.Print (a side effect).

    The bprop returns (y, x), the analytic gradients of out = x * y.
    security_off_wrap is needed because the test uses P.Print — presumably
    disabled when MindSpore is built in security mode.
    """
    class UserDefinedNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.print = P.Print()

        def construct(self, x, y):
            out = x * y
            return out

        def bprop(self, x, y, out, dout):
            # Printing inside bprop puts side effects into the backward graph.
            self.print(out)
            out = x * y
            self.print(out)
            self.print(dout)
            return y, x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.parameter_tuple = ParameterTuple(net.trainable_params())

        def construct(self, x, y):
            return grad_all_list(self.net, self.parameter_tuple)(x, y)

    user_defined_net = UserDefinedNet()
    net = GradNet(user_defined_net)
    x = Tensor(np.array([2.0], np.float32))
    y = Tensor(np.array([3.0], np.float32))
    print(net(x, y))
167
168
# The user-defined bprop does not have the same number of parameters as the primal construct.
@security_off_wrap
def test_user_defined_bad_bprop():
    """A bprop whose signature omits a primal input (y) must raise TypeError.

    construct takes (x, y) but bprop only takes (x, out, dout), so compiling
    the grad graph is expected to fail with a TypeError.
    """
    class UserDefinedNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.print = P.Print()

        def construct(self, x, y):
            out = x * y
            return out

        # Deliberately malformed: missing the second primal input y.
        def bprop(self, x, out, dout):
            self.print(out)
            out = x
            self.print(out)
            self.print(dout)
            return x, x

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.parameter_tuple = ParameterTuple(net.trainable_params())

        def construct(self, x, y):
            return grad_all_list(self.net, self.parameter_tuple)(x, y)

    user_defined_net = UserDefinedNet()
    net = GradNet(user_defined_net)
    x = Tensor(np.array([2.0], np.float32))
    y = Tensor(np.array([3.0], np.float32))
    with pytest.raises(TypeError):
        net(x, y)
203
204
# Should compile successfully, with Print present in the final function graph.
@security_off_wrap
@pytest.mark.skip(reason="isolated nodes exception")
def test_unused_var():
    """Compile a net where a helper's Print output feeds nothing but must be kept.

    get_shape prints x (side effect) and returns only the channel dimension;
    the unpacked-but-unused shape elements become isolated nodes. Currently
    skipped: compilation raises an isolated-nodes exception.
    """
    class UnusedVar(nn.Cell):
        def __init__(self):
            super().__init__()
            self.print = P.Print()

        def construct(self, x, y):
            shape1 = self.get_shape(x)
            out = x
            # Loop bound comes from a tensor dimension (2 for the input below).
            for _ in range(shape1):
                out = out + y
            return out

        def get_shape(self, x):
            self.print(x)
            # Only c is used; the other unpacked values are unused on purpose.
            _, c, _, _ = F.shape(x)
            return c

    net = UnusedVar()
    x = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)
    y = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)
    print(net(x, y))
230
231
# Should compile successfully, with Print present in the final function graph.
@security_off_wrap
@pytest.mark.skip(reason="isolated nodes exception")
def test_hof_unused_var():
    """Same as test_unused_var, but get_shape is invoked through a higher-order call.

    hof_get_shape receives the bound method as a function argument, exercising
    side-effect tracking across a higher-order function. Currently skipped:
    compilation raises an isolated-nodes exception.
    """
    class UnusedVar(nn.Cell):
        def __init__(self):
            super().__init__()
            self.print = P.Print()

        def construct(self, x, y):
            # Pass self.get_shape as a first-class function.
            shape1 = self.hof_get_shape(self.get_shape, x)
            out = x
            for _ in range(shape1):
                out = out + y
            return out

        def hof_get_shape(self, hof, x):
            return hof(x)

        def get_shape(self, x):
            self.print(x)
            _, c, _, _ = F.shape(x)
            return c

    net = UnusedVar()
    x = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)
    y = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)
    print(net(x, y))
260
261
# Should compile successfully, with Print present in the final function graph.
@security_off_wrap
@pytest.mark.skip(reason="isolated nodes exception")
def test_partial_hof_unused_var():
    """Same as test_hof_unused_var, but the higher-order call goes through F.partial.

    hof_get_shape returns a partial application of get_shape, which construct
    then invokes — side effects must survive the Partial indirection.
    Currently skipped: compilation raises an isolated-nodes exception.
    """
    class UnusedVar(nn.Cell):
        def __init__(self):
            super().__init__()
            self.print = P.Print()

        def construct(self, x, y):
            # Call the partially-applied function returned by hof_get_shape.
            shape1 = self.hof_get_shape(x)()
            out = x
            for _ in range(shape1):
                out = out + y
            return out

        def hof_get_shape(self, x):
            return F.partial(self.get_shape, x)

        def get_shape(self, x):
            self.print(x)
            _, c, _, _ = F.shape(x)
            return c

    net = UnusedVar()
    x = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)
    y = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32)
    print(net(x, y))
290
291
# Should compile successfully, without an endless loop.
def test_while_if():
    """An if inside a while, both reading the same Parameter, must compile finitely.

    Regression test: compilation of the nested control flow should terminate
    (no endless loop in the compiler).
    """
    class WhileIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.zero = Tensor(np.zeros([1]).astype(np.float32))
            self.param = Parameter(Tensor(np.zeros([1]).astype(np.float32)))

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                # Both branches read self.param, one scaled by 2.
                if x < end:
                    out = out + self.param * 2
                else:
                    out = out + self.param
                idx = idx + 1
            return out

    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(5), dtype=ms.int32)
    x = Tensor(np.zeros([1]).astype(np.float32))
    m = WhileIfNet()
    m(idx, end, x)
315
316
# Should compile successfully without a zeros_like_tensor argument mismatch; the generated
# graph files should contain neither env_getitem nor env_setitem.
# The InsertGradientOf primitive causes the func_graph bprop_construct to have the
# BackPropAutoMonad flag set, so every graph it uses is checked for side effects; thus
# hyper_map_zeros_like takes U as a parameter, while the call site zeros_like(fv) does not
# pass a U argument.
def test_grad_fv_and_insert_gradient_of():
    """Combine a free-variable Parameter assignment with InsertGradientOf.

    Checks that taking grads over a net that both assigns a free-variable
    parameter (self.z) and installs a backward hook compiles without
    env_getitem/env_setitem appearing in the generated graphs.
    """
    class FvAndInsertGradientNet(nn.Cell):
        def __init__(self):
            super(FvAndInsertGradientNet, self).__init__()
            self.gather = P.GatherV2()
            self.damping = Tensor(np.array([0.03, 0.03], np.float32))
            self.cov_step = Parameter(0, name="cov_step", requires_grad=False)
            self.freq = Tensor(278, ms.int32)
            # Backward hook with a side effect (updates cov_step).
            self.getG = P.InsertGradientOf(self.save_gradient)

            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def save_gradient(self, dout):
            self.cov_step = self.cov_step + self.freq
            return dout

        def construct(self, *inputs):
            # self.z is a free variable assigned from construct's input.
            x, = inputs
            self.z = x

            # insert_gradient_of path; gather result intentionally unused.
            self.gather(self.damping, self.cov_step, 0)
            out = self.getG(x)
            return out

    net = FvAndInsertGradientNet()
    input_data = Tensor(np.array([1.0], np.float32))
    # If grad_all_list were used instead, the generated graph would contain
    # env_setitem, because the (constant-zero) input gradient would depend on
    # the result of grad.
    grad_net = grad_by_list(net, ParameterTuple(net.trainable_params()))
    print(grad_net(input_data))
354
355
# Should compile successfully, as a CNode with the Partial primitive will not bind an additional U monad.
def test_partial_parameter():
    """A method call on a Parameter held as an attribute must compile.

    The getattr-style call self.input.all(...) is lowered to a Partial; the
    test checks that compiling it succeeds.
    """
    z = Parameter(Tensor(np.array([True], np.bool_)), name='z')

    class PartialNet(nn.Cell):
        def __init__(self, input_z):
            super().__init__()
            self.input = input_z

        def construct(self):
            # The getattr of `all` will be converted to a Partial application.
            out = self.input.all(axis=())
            return out

    net = PartialNet(z)
    print(net())
372