• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
from mindspore.common import dtype as mstype
from mindspore import nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore import context
from mindspore.common.parameter import Parameter

# Run every test in compiled graph mode; do not dump IR graphs to disk.
context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
# NOTE(review): this module-level grad op appears unused in this file (the
# Grad cell below builds its own GradOperation); kept in case it is imported
# elsewhere — verify before removing.
grad_all = C.GradOperation(get_all=True)
27
28
class Grad(nn.Cell):
    """Cell wrapper that returns the gradients of `net` w.r.t. all inputs."""

    def __init__(self, net):
        super(Grad, self).__init__(auto_prefix=False)
        self.forward_net = net
        self.grad = C.GradOperation(get_all=True)

    def construct(self, *inputs):
        # Differentiate the wrapped network and evaluate at the given inputs.
        return self.grad(self.forward_net)(*inputs)
38
39
class ForBreakForwardNet(nn.Cell):
    """for-loop body mixing `continue`, early `return`, and `break`.

    Accumulates out += x * y on odd loop indices only (even indices are
    skipped via `continue`); returns early if out hits exactly 20 and
    breaks out of the loop once out exceeds 20.
    """

    def __init__(self, max_cycles=10):
        super(ForBreakForwardNet, self).__init__()
        # Python-int upper bound of the for-loop.
        self.max_cycles = max_cycles
        # Scalar int32 zero used as the accumulator's initial value.
        self.zero = Tensor(np.array(0), mstype.int32)

    def construct(self, x, y):
        out = self.zero
        for i in range(self.max_cycles):
            if i % 2 == 0:
                continue  # skip even iterations entirely
            out = x * y + out
            if out == 20:
                return out  # early return from inside the loop body
            if out > 20:
                break

        return out
58
59
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_break_forward():
    """Forward pass of ForBreakForwardNet(max_cycles=3) yields 3."""
    net = ForBreakForwardNet(max_cycles=3)
    in_x = Tensor(np.array(1), mstype.int32)
    in_y = Tensor(np.array(3), mstype.int32)
    result = net(in_x, in_y)
    assert result == Tensor(np.array(3), mstype.int32)
71
72
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for_break_backward():
    """Gradients of ForBreakForwardNet(max_cycles=3) w.r.t. (x, y)."""
    net = Grad(ForBreakForwardNet(max_cycles=3))
    in_x = Tensor(np.array(1), mstype.int32)
    in_y = Tensor(np.array(3), mstype.int32)
    grads = net(in_x, in_y)
    # out = x * y  ->  d/dx = y = 3, d/dy = x = 1
    assert grads == (Tensor(np.array(3), mstype.int32), Tensor(np.array(1), mstype.int32))
85
86
class WhileBreakForwardNet(nn.Cell):
    """while-loop mixing `continue`, `break`, and early `return`.

    Even counter values are skipped via `continue`; on odd values out is
    incremented by x * y, with a `break` above 20 and an early `return`
    at exactly 20.
    """

    def __init__(self, max_cycles=10):
        super(WhileBreakForwardNet, self).__init__()
        self.max_cycles = max_cycles
        # Tensor loop-counter start and accumulator start (int32 scalars).
        self.i = Tensor(np.array(0), mstype.int32)
        self.zero = Tensor(np.array(0), mstype.int32)

    def construct(self, x, y):
        i = self.i
        out = self.zero
        while i < self.max_cycles:
            if i % 2 == 0:
                i = i + 1
                continue  # skip even counter values
            out = x * y + out
            if out > 20:
                break
            if out == 20:
                return out  # early return from inside the loop
            i = i + 1
        return out
108
109
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_break_forward():
    """Forward: five odd-counter accumulations of x*y (=3) give 15."""
    net = WhileBreakForwardNet(max_cycles=10)
    a = Tensor(np.array(1), mstype.int32)
    b = Tensor(np.array(3), mstype.int32)
    out = net(a, b)
    assert out == Tensor(np.array(15))
121
122
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_break_backward():
    """Gradients of WhileBreakForwardNet(max_cycles=10) w.r.t. (x, y)."""
    # Fix: save_graphs=True was a debug leftover — it dumps IR graph files to
    # disk on every run and contradicts the module-level save_graphs=False.
    context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
    x = Tensor(np.array(1), mstype.int32)
    y = Tensor(np.array(3), mstype.int32)
    forward_net = WhileBreakForwardNet(max_cycles=10)
    backward_net = Grad(forward_net)
    graph_grads = backward_net(x, y)
    # out = 5 * (x * y)  ->  d/dx = 5*y = 15, d/dy = 5*x = 5
    assert graph_grads == (Tensor(np.array(15), mstype.int32), Tensor(np.array(5), mstype.int32))
135
136
class IfAfterIfInWhileBreakForwardNet(nn.Cell):
    """An `if` after an if-in-while with break/continue, plus Parameter writes.

    The while loop skips even counter values, accumulates x * y on odd
    ones while out <= 20, and breaks otherwise; a trailing `if` then
    adjusts the result.  `self.weight` is re-assigned inside `construct`
    and read back, exercising parameter side effects in control flow.
    """

    def __init__(self, max_cycles=10):
        super(IfAfterIfInWhileBreakForwardNet, self).__init__()
        self.max_cycles = max_cycles
        self.i = Tensor(np.array(0), mstype.int32)
        self.zero = Tensor(np.array(0), mstype.int32)
        # Written from the loop counter below and read in the even/odd test.
        self.weight = Parameter(Tensor(np.array(0), mstype.int32))

    def construct(self, x, y):
        i = self.i
        out = self.zero
        while i < self.max_cycles:
            self.weight = i  # parameter write inside the loop
            if self.weight % 2 == 0:
                i = i + 1
                continue
            if out <= 20:
                self.weight = i
                out = x * y + out
            else:
                break
            i = i + 1
        # Trailing branch after the loop: early return on the >= 30 path.
        if out >= 30:
            self.weight = out
            out = out - 30
            return out
        out = out + 1
        return out
165
166
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_after_if_in_while_break_forward():
    """Forward: five accumulations of 3 give 15, plus the +1 tail -> 16."""
    a = Tensor(np.array(1), mstype.int32)
    b = Tensor(np.array(3), mstype.int32)
    # Graph Mode
    context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
    net = IfAfterIfInWhileBreakForwardNet(max_cycles=10)
    out = net(a, b)
    assert out == Tensor(np.array(16), mstype.int32)
180
181
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_after_if_in_while_break_backward():
    """Gradients: out = 5*(x*y) + 1  ->  (5*y, 5*x) = (15, 5)."""
    a = Tensor(np.array(1), mstype.int32)
    b = Tensor(np.array(3), mstype.int32)
    # Graph Mode
    context.set_context(mode=context.GRAPH_MODE)
    net = Grad(IfAfterIfInWhileBreakForwardNet(max_cycles=10))
    grads = net(a, b)

    assert grads == (Tensor(np.array(15), mstype.int32), Tensor(np.array(5), mstype.int32))
197
198
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_after_for_in_if_break():
    """An `if` after a for-in-if with break/continue and Parameter updates."""
    class IfAfterForInIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
            self.param_b = Parameter(Tensor(4, mstype.int32), name='b')

        def construct(self, x):
            out = x + self.param_a
            if self.param_a > self.param_b:
                for _ in range(4):
                    self.param_a += 1
                    if self.param_b < 0:
                        continue
                    self.param_b -= 3
                    if self.param_a > 6:
                        break

            self.param_b += 15
            if x < self.param_b:
                out -= self.param_b
                return out  # early return on the x < param_b path
            out = self.param_b + out
            return out

    x = Tensor(2, mstype.int32)

    # graph mode
    # NOTE(review): this forward run relies on the module-level
    # context.set_context(mode=context.GRAPH_MODE, ...) call; the explicit
    # set_context below only precedes the backward run.

    forward_net = IfAfterForInIfNet()
    graph_forward_res = forward_net(x)

    context.set_context(mode=context.GRAPH_MODE)
    if_after_for_in_if_net = IfAfterForInIfNet()
    net = Grad(if_after_for_in_if_net)
    graph_backward_res = net(x)

    assert graph_forward_res == Tensor(-6, mstype.int32)
    assert graph_backward_res == (Tensor(1, mstype.int32),)
243
244
@pytest.mark.skip(reason="ME EvalCNode error.")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_after_for_in_for_break():
    """An `if` after nested for loops with break/continue/return.

    Currently skipped (ME EvalCNode error); the expected values are kept
    as commented asserts below until the frontend issue is fixed.
    """
    class IfAfterForInForNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.param_a = Parameter(Tensor(5, mstype.int32), name='a')
            self.param_b = Parameter(Tensor(2, mstype.int32), name='b')

        def construct(self, x):
            out = x + self.param_a
            for _ in range(0, 10):
                x *= 2
                if self.param_a % 2 == 0:
                    self.param_a += 1
                    continue
                for _ in range(0, 5):
                    self.param_a += 1
                    x += self.param_b
                    if x > 10:
                        break
                if x > 100:
                    return x  # early return from inside the outer loop
            if self.param_a > self.param_b:
                out += x
            return out

    x = Tensor(2, mstype.int32)

    # graph mode
    forward_net = IfAfterForInForNet()
    graph_forward_res = forward_net(x)

    if_after_for_in_for_net = IfAfterForInForNet()
    net = Grad(if_after_for_in_for_net)
    graph_backward_res = net(x)

    print("test_if_after_for_in_for_break graph_forward_res:", graph_forward_res)
    print("test_if_after_for_in_for_break graph_backward_res:", graph_backward_res)
    # assert graph_forward_res == Tensor(12285, mstype.int32)
    # assert graph_backward_res == (Tensor(1025, mstype.int32),)
290
291
class WhileAfterWhileInWhileBreakForwardNet(nn.Cell):
    """A while loop following a while-in-while, each inner loop with `break`."""

    def __init__(self, max_cycles=10):
        super(WhileAfterWhileInWhileBreakForwardNet, self).__init__()
        self.max_cycles = max_cycles
        # Scalar int32 accumulator start and loop-counter start.
        self.zero = Tensor(np.array(0), mstype.int32)
        self.i = Tensor(np.array(0), mstype.int32)

    def construct(self, x, y):
        out = self.zero
        i = self.i
        while i < self.max_cycles:
            j = self.i
            # Inner loop: accumulate x*y, bailing out once j exceeds 4.
            while j < self.max_cycles + 3:
                out = x * y + out
                j = j + 1
                if j > 4:
                    break
            i = i + 1
            # Outer loop runs at most 3 times before breaking.
            if i > 2:
                break
        i = self.i
        # Trailing while loop after the nested ones.
        while i < self.max_cycles:
            out = x * y + out
            i = i + 1
        return out
317
318
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_after_while_in_while_break_forward():
    """Forward: 18 accumulations of x*y (=3) give 54."""
    context.set_context(mode=context.GRAPH_MODE)
    net = WhileAfterWhileInWhileBreakForwardNet(max_cycles=3)
    a = Tensor(np.array(1), mstype.int32)
    b = Tensor(np.array(3), mstype.int32)
    out = net(a, b)

    assert out == Tensor(np.array(54), mstype.int32)
332
333
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_after_while_in_while_break_backward():
    """Gradients: out = 18*(x*y)  ->  (18*y, 18*x) = (54, 18)."""
    context.set_context(mode=context.GRAPH_MODE)
    net = Grad(WhileAfterWhileInWhileBreakForwardNet(max_cycles=3))
    a = Tensor(np.array(1), mstype.int32)
    b = Tensor(np.array(3), mstype.int32)
    grads = net(a, b)

    assert grads == (Tensor(np.array(54), mstype.int32), Tensor(np.array(18), mstype.int32))
348
349
class TwoBreakDeadForwardNet(nn.Cell):
    """while loop whose elif and else branches both break immediately.

    For inputs <= 3 the first iteration always breaks, so `x = x + 1`
    after the conditional chain is unreachable on that path.
    """

    def __init__(self):
        super(TwoBreakDeadForwardNet, self).__init__()
        # NOTE(review): self.zero is never used; candidate for removal.
        self.zero = Tensor(np.array(0), mstype.int32)

    def construct(self, x):
        while x < 5:
            if x > 3:
                x -= 2
            elif x == 3:
                break
            else:
                break
            x = x + 1
        return x
365
366
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_2break_dead_block():
    """x=1 takes the final `else: break` immediately, so x is unchanged."""
    net = TwoBreakDeadForwardNet()
    value = Tensor(np.array(1), mstype.int32)
    out = net(value)

    assert out == Tensor(np.array(1), mstype.int32)
378
379
class ForInFor2BreakForwardNet(nn.Cell):
    """Nested for loops with a conditional `break` in the inner loop."""

    def __init__(self):
        super(ForInFor2BreakForwardNet, self).__init__()
        self.relu = P.ReLU()
        # Fix: P.TensorAdd is the deprecated alias of P.Add; use P.Add for
        # consistency with the other nets in this file (WhileTrueBreakNet,
        # NetWork below both use P.Add).
        self.add = P.Add()

    def construct(self, x, y, z):
        out = z
        for _ in range(2):
            for _ in range(3):
                if 2 * x < y:
                    out = self.add(out, out)
                    x = x + 1
                    if x + 6 == y:
                        break  # only exits the inner loop
        out = self.relu(out)
        return out
397
398
@pytest.mark.skip(reason="Get wrong parent graph")
def test_for_in_for_break():
    """Smoke test for ForInFor2BreakForwardNet (currently skipped)."""
    a = Tensor(np.array(7), mstype.float32)
    b = Tensor(np.array(20), mstype.float32)
    c = Tensor(np.array(2), mstype.float32)
    net = ForInFor2BreakForwardNet()
    out = net(a, b, c)
    print("test_for_in_for_break graph out:", out)
407
408
# Inference raises an endless-loop exception for this case.
@pytest.mark.skip(reason="Infer raise a endless loop exception")
def test_while_true_break():
    """`while True` with continue and break (skipped: infer loops forever)."""
    # Fix: save_graphs=True was a debug leftover — it dumps IR graph files to
    # disk and contradicts the module-level save_graphs=False setting.
    context.set_context(save_graphs=False)

    class WhileTrueBreakNet(nn.Cell):
        def __init__(self, t):
            super(WhileTrueBreakNet, self).__init__()
            self.add = P.Add()
            self.mul = P.Mul()
            self.para = Parameter(Tensor(t, mstype.int32), name="a")

        def construct(self, x, y):
            out = self.mul(y, self.para)
            while True:
                if x == 5:
                    x = x - 3
                    continue
                if x == 2:
                    break
                out = self.add(out, out)
            return out

    t = np.array([1]).astype(np.int32)
    y = Tensor([1], mstype.int32)
    x = Tensor([5], mstype.int32)
    net = WhileTrueBreakNet(t)
    grad_net = Grad(net)
    grad_out = grad_net(x, y)
    print(grad_out)
439
440
# The VM backend gets stuck when running this case.
@pytest.mark.skip(reason="Stuck in vm backend")
def test_continue_stuck_in_vm():
    """`continue` at the tail of a nested while (skipped: stuck in vm backend)."""
    # Fix: save_graphs=True was a debug leftover — it dumps IR graph files to
    # disk and contradicts the module-level save_graphs=False setting.
    context.set_context(save_graphs=False)

    class NetWork(nn.Cell):
        def __init__(self, t):
            super().__init__()
            self.add = P.Add()
            self.mul = P.Mul()
            self.para = Parameter(Tensor(t, mstype.int32), name="a")

        def construct(self, x, y):
            out = self.mul(y, y)
            while x != 3:
                while self.para > 5:
                    # NOTE: the decrement must stay before the if-switch;
                    # placing it after produces wrong results.
                    self.para -= 1
                    x += 1
                    if x > 3:
                        self.para -= x
                        return out
                    out = self.add(out, y)
                continue
            out = self.mul(out, y)
            return out

    x = Tensor(2, mstype.int32)
    t = 8
    y = Tensor(1, mstype.int32)
    net = NetWork(t)
    grad_net = Grad(net)
    grad = grad_net(x, y)
    print(grad)
475