• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15import os
16import re
17import time
18import pytest
19import numpy as np
20import mindspore as ms
21import mindspore.ops.operations as P
22import mindspore.nn as nn
23from mindspore.nn import Cell
24from mindspore.nn import ReLU, BatchNorm2d, Conv2d, ParameterUpdate
25from mindspore.nn import Momentum, SoftmaxCrossEntropyWithLogits
26from mindspore import context, Tensor
27from mindspore.common.parameter import Parameter
28from mindspore.common.initializer import initializer
29from mindspore.ops.primitive import constexpr
30from capture import Capture, capture, check_output
31from tests.security_utils import security_off_wrap
32
# Module-wide execution defaults: every test here targets Ascend and runs in
# graph mode unless it switches modes itself.
context.set_context(mode=context.GRAPH_MODE,
                    device_target="Ascend")
34
35
@pytest.fixture(name="pynative_save_graphs")
def _pynative_save_graphs():
    """Switch to PyNative mode with IR graph saving enabled for one test.

    Teardown switches back to graph mode, disables graph saving and then
    removes the IR files from the current directory via ``clean_all_ir_files``.
    """
    ir_dir = './'
    context.set_context(mode=context.PYNATIVE_MODE, save_graphs=True)
    yield
    # Restore the module defaults before cleaning up the dumped files.
    context.set_context(mode=context.GRAPH_MODE, save_graphs=False)
    clean_all_ir_files(ir_dir)
42
43
@pytest.fixture(name="with_save_graphs")
def _with_save_graphs():
    """Enable IR graph saving (in the current mode) for one test.

    Teardown disables graph saving and removes the IR files from the current
    directory via ``clean_all_ir_files``.
    """
    ir_dir = './'
    context.set_context(save_graphs=True)
    yield
    context.set_context(save_graphs=False)
    clean_all_ir_files(ir_dir)
50
51
@security_off_wrap
def test_print():
    """Print two labelled scalar tensors and check the captured output."""

    class PrintNet(Cell):
        """Prints both inputs and returns the first one."""

        def __init__(self):
            super().__init__()
            self.printer = P.Print()

        def construct(self, x, y):
            self.printer("input_x:", x, "input_y:", y)
            return x

    cap = Capture()
    with capture(cap):
        first = Tensor(3, dtype=ms.int32)
        second = Tensor(4, dtype=ms.int32)
        PrintNet()(first, second)
        # Short pause, presumably so the print output reaches the capture
        # before the context closes.
        time.sleep(0.1)

    expected_output = {
        'input_x:\nTensor(shape=[], dtype=Int32, value=3)\n'
        'input_y:\nTensor(shape=[], dtype=Int32, value=4)'
    }
    check_output(cap.output, expected_output)
74
75
@security_off_wrap
def test_print_add():
    """Print the sum produced by Add next to the unchanged second operand."""

    class PrintAddNet(Cell):
        def __init__(self):
            super().__init__()
            self.printer = P.Print()
            self.add_op = P.Add()

        def construct(self, x, y):
            # x is rebound to x + y first, so the printed "input_x" is the sum.
            x = self.add_op(x, y)
            self.printer("input_x:", x, "input_y:", y)
            return x

    cap = Capture()
    with capture(cap):
        lhs = Tensor(3, dtype=ms.int32)
        rhs = Tensor(4, dtype=ms.int32)
        want = Tensor(7, dtype=ms.int32)
        result = PrintAddNet()(lhs, rhs)
        time.sleep(0.1)
        np.testing.assert_array_equal(result.asnumpy(), want.asnumpy())

    expected_output = {
        'input_x:\nTensor(shape=[], dtype=Int32, value=7)\n'
        'input_y:\nTensor(shape=[], dtype=Int32, value=4)'
    }
    check_output(cap.output, expected_output)
102
103
@security_off_wrap
def test_print_assign():
    """Prints around a Parameter assignment must observe old and new values."""

    class PrintAssignNet(Cell):
        def __init__(self):
            super().__init__()
            self.printer = P.Print()
            self.para = Parameter(Tensor(1, dtype=ms.int32), name='para')

        def construct(self, x):
            self.printer("before:", self.para)
            self.para = x
            self.printer("after:", self.para)
            return self.para

    cap = Capture()
    with capture(cap):
        value = Tensor(3, dtype=ms.int32)
        want = Tensor(3, dtype=ms.int32)
        result = PrintAssignNet()(value)
        time.sleep(0.1)
        np.testing.assert_array_equal(result.asnumpy(), want.asnumpy())

    expected_output = {
        'before:\nTensor(shape=[], dtype=Int32, value=1)',
        'after:\nTensor(shape=[], dtype=Int32, value=3)',
    }
    check_output(cap.output, expected_output)
130
131
@security_off_wrap
def test_print_assign_add():
    """Assign a Parameter between two prints, then add to the assigned value."""

    class PrintAssignAddNet(Cell):
        def __init__(self):
            super().__init__()
            self.printer = P.Print()
            self.add_op = P.Add()
            self.para = Parameter(Tensor(1, dtype=ms.int32), name='para')

        def construct(self, x, y):
            self.printer("before:", self.para)
            self.para = x
            self.printer("after:", self.para)
            x = self.add_op(self.para, y)
            return x

    cap = Capture()
    with capture(cap):
        lhs = Tensor(3, dtype=ms.int32)
        rhs = Tensor(4, dtype=ms.int32)
        want = Tensor(7, dtype=ms.int32)
        result = PrintAssignAddNet()(lhs, rhs)
        time.sleep(0.1)
        np.testing.assert_array_equal(result.asnumpy(), want.asnumpy())

    expected_output = {
        'before:\nTensor(shape=[], dtype=Int32, value=1)',
        'after:\nTensor(shape=[], dtype=Int32, value=3)',
    }
    check_output(cap.output, expected_output)
161
162
@security_off_wrap
def test_print_while():
    """A print inside a while loop must fire once per iteration, in order."""

    class PrintWhileNet(Cell):
        def __init__(self):
            super().__init__()
            self.printer = P.Print()

        def construct(self, x, y):
            self.printer("input_x before:", x, "input_y before:", y)
            while x < y:
                self.printer("input_x after:", x, "input_y after:", y)
                x = x + 1
            return x

    cap = Capture()
    with capture(cap):
        start = Tensor(1, dtype=ms.int32)
        limit = Tensor(4, dtype=ms.int32)
        want = Tensor(4, dtype=ms.int32)
        result = PrintWhileNet()(start, limit)
        time.sleep(0.1)
        np.testing.assert_array_equal(result.asnumpy(), want.asnumpy())

    # One "before" line, then one "after" line for x = 1, 2, 3.
    expected_output = {
        'input_x before:\nTensor(shape=[], dtype=Int32, value=1)\n'
        'input_y before:\nTensor(shape=[], dtype=Int32, value=4)',
        'input_x after:\nTensor(shape=[], dtype=Int32, value=1)\n'
        'input_y after:\nTensor(shape=[], dtype=Int32, value=4)',
        'input_x after:\nTensor(shape=[], dtype=Int32, value=2)\n'
        'input_y after:\nTensor(shape=[], dtype=Int32, value=4)',
        'input_x after:\nTensor(shape=[], dtype=Int32, value=3)\n'
        'input_y after:\nTensor(shape=[], dtype=Int32, value=4)',
    }
    check_output(cap.output, expected_output)
196
197
@security_off_wrap
def test_print_if():
    """A print inside a taken if-branch must appear after the leading print."""

    class PrintIfNet(Cell):
        def __init__(self):
            super().__init__()
            self.printer = P.Print()

        def construct(self, x, y):
            self.printer("input_x before:", x, "input_y before:", y)
            if x < y:
                self.printer("input_x after:", x, "input_y after:", y)
                x = x + 1
            return x

    cap = Capture()
    with capture(cap):
        lhs = Tensor(3, dtype=ms.int32)
        rhs = Tensor(4, dtype=ms.int32)
        want = Tensor(4, dtype=ms.int32)
        result = PrintIfNet()(lhs, rhs)
        time.sleep(0.1)
        np.testing.assert_array_equal(result.asnumpy(), want.asnumpy())

    expected_output = {
        'input_x before:\nTensor(shape=[], dtype=Int32, value=3)\n'
        'input_y before:\nTensor(shape=[], dtype=Int32, value=4)',
        'input_x after:\nTensor(shape=[], dtype=Int32, value=3)\n'
        'input_y after:\nTensor(shape=[], dtype=Int32, value=4)',
    }
    check_output(cap.output, expected_output)
227
228
@security_off_wrap
def test_print_assign_while():
    """Parameter assignment and print interleaved inside a while loop."""

    class PrintAssignWhileNet(Cell):
        def __init__(self):
            super().__init__()
            self.printer = P.Print()
            self.para = Parameter(Tensor(0, dtype=ms.int32), name='para')

        def construct(self, x, y):
            self.printer("input_x before:", x, "input_y before:",
                         y, "para before:", self.para)
            while x < y:
                self.para = x
                x = self.para + 1
                self.printer("input_x after:", x, "input_y after:",
                             y, "para after:", self.para)
            return x

    cap = Capture()
    with capture(cap):
        start = Tensor(1, dtype=ms.int32)
        limit = Tensor(4, dtype=ms.int32)
        want = Tensor(4, dtype=ms.int32)
        result = PrintAssignWhileNet()(start, limit)
        time.sleep(0.1)
        np.testing.assert_array_equal(result.asnumpy(), want.asnumpy())

    # Each iteration prints the incremented x together with the para value
    # assigned just before the increment.
    expected_output = {
        'input_x before:\nTensor(shape=[], dtype=Int32, value=1)\n'
        'input_y before:\nTensor(shape=[], dtype=Int32, value=4)\n'
        'para before:\nTensor(shape=[], dtype=Int32, value=0)',
        'input_x after:\nTensor(shape=[], dtype=Int32, value=2)\n'
        'input_y after:\nTensor(shape=[], dtype=Int32, value=4)\n'
        'para after:\nTensor(shape=[], dtype=Int32, value=1)',
        'input_x after:\nTensor(shape=[], dtype=Int32, value=3)\n'
        'input_y after:\nTensor(shape=[], dtype=Int32, value=4)\n'
        'para after:\nTensor(shape=[], dtype=Int32, value=2)',
        'input_x after:\nTensor(shape=[], dtype=Int32, value=4)\n'
        'input_y after:\nTensor(shape=[], dtype=Int32, value=4)\n'
        'para after:\nTensor(shape=[], dtype=Int32, value=3)',
    }
    check_output(cap.output, expected_output)
271
272
@security_off_wrap
def test_print_assign_if():
    """Parameter assignment before an if-branch that reads it and prints."""

    class PrintAssignIfNet(Cell):
        def __init__(self):
            super().__init__()
            self.printer = P.Print()
            self.para = Parameter(Tensor(1, dtype=ms.int32), name='para')

        def construct(self, x, y):
            self.printer("input_x before:", x, "input_y before:",
                         y, "para before:", self.para)
            self.para = x
            if x < y:
                x = self.para + 1
                self.printer("input_x after:", x, "input_y after:",
                             y, "para after:", self.para)
            return x

    cap = Capture()
    with capture(cap):
        lhs = Tensor(3, dtype=ms.int32)
        rhs = Tensor(4, dtype=ms.int32)
        want = Tensor(4, dtype=ms.int32)
        result = PrintAssignIfNet()(lhs, rhs)
        time.sleep(0.1)
        np.testing.assert_array_equal(result.asnumpy(), want.asnumpy())

    expected_output = {
        'input_x before:\nTensor(shape=[], dtype=Int32, value=3)\n'
        'input_y before:\nTensor(shape=[], dtype=Int32, value=4)\n'
        'para before:\nTensor(shape=[], dtype=Int32, value=1)',
        'input_x after:\nTensor(shape=[], dtype=Int32, value=4)\n'
        'input_y after:\nTensor(shape=[], dtype=Int32, value=4)\n'
        'para after:\nTensor(shape=[], dtype=Int32, value=3)',
    }
    check_output(cap.output, expected_output)
309
310
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_assign():
    """Assigning an input to a Parameter attribute and returning it."""

    class AssignNet(Cell):
        def __init__(self):
            super().__init__()
            self.para = Parameter(Tensor(1, dtype=ms.int32), name='para')

        def construct(self, value):
            self.para = value
            return self.para

    result = AssignNet()(Tensor(3, dtype=ms.int32))
    np.testing.assert_array_equal(result.asnumpy(), np.int32(3))
330
331
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_assign_implicit():
    """Assign an int32 tensor to a float32 Parameter; expect a float32 result."""

    class AssignImplicitNet(Cell):
        def __init__(self):
            super().__init__()
            self.b = Parameter(initializer(1, [5], ms.float32),
                               name="global_step")

        def construct(self, w):
            self.b = w
            return self.b

    int_ones = Tensor(np.ones([5]).astype(np.int32))
    result = AssignImplicitNet()(int_ones)
    assert result.dtype == ms.float32
351
352
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_assign_write_after_read():
    """The Sub must read the Parameter only after the second assignment."""

    class AssignWARNet(Cell):
        def __init__(self):
            super().__init__()
            self.assign = P.Assign()
            self.sub = P.Sub()
            self.add = P.Add()
            self.para = Parameter(Tensor(1, dtype=ms.int32), name='para')
            self.weight = Parameter(Tensor(5, dtype=ms.int32), name='weight')

        def construct(self, x, y):
            # without auto_monad, execute order is wrong: Add - Assign - Sub - Assign
            # expected execute order: Add - Assign - Assign - Sub
            self.para = self.add(y, x)
            self.assign(self.para, y)
            return self.sub(self.para, self.weight)

    net = AssignWARNet()
    result = net(Tensor(3, dtype=ms.int32), Tensor(4, dtype=ms.int32))
    # para ends up as y (4), so 4 - weight(5) == -1.
    np.testing.assert_array_equal(result.asnumpy(), np.int32(-1))
380
381
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_assign_read_after_write():
    """AssignAdd followed by a plain assignment; the last write must win."""

    class AssignRAWNet(Cell):
        def __init__(self):
            super().__init__()
            self.assign_add = P.AssignAdd()
            self.greater = P.Greater()
            self.add = P.Add()
            self.para = Parameter(Tensor(1, dtype=ms.int32), name='para')

        def construct(self, x, y):
            # without auto_monad, execute order is wrong: Add - Assign - Greater - AssignAdd
            # expected execute order: AssignAdd - Add - Assign
            self.greater(x, y)
            self.assign_add(self.para, x)
            self.para = self.add(x, y)
            return self.para

    net = AssignRAWNet()
    result = net(Tensor(3, dtype=ms.int32), Tensor(4, dtype=ms.int32))
    np.testing.assert_array_equal(result.asnumpy(), np.int32(7))
409
410
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_assign_if():
    """Assign a Parameter in either branch of an if/else; return the smaller input."""

    class AssignIfNet(Cell):
        def __init__(self):
            super().__init__()
            self.para = Parameter(Tensor(1, dtype=ms.int32), name='para')

        def construct(self, x, y):
            if x < y:
                self.para = x
            else:
                self.para = y
            return self.para

    net = AssignIfNet()
    result = net(Tensor(3, dtype=ms.int32), Tensor(4, dtype=ms.int32))
    np.testing.assert_array_equal(result.asnumpy(), np.int32(3))
434
435
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if():
    """Pure control flow: Sub when x > y, otherwise Add."""

    class IfNet(Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()

        def construct(self, x, y):
            if x > y:
                x = self.sub(x, y)
            else:
                x = self.add(x, y)
            return x

    net = IfNet()
    result = net(Tensor(3, dtype=ms.int32), Tensor(4, dtype=ms.int32))
    # 3 > 4 is false, so the else branch computes 3 + 4.
    np.testing.assert_array_equal(result.asnumpy(), np.int32(7))
460
461
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while():
    """Count x up to y + 4, then add 3."""

    class WhileNet(Cell):
        def construct(self, x, y):
            y = y + 4
            while x < y:
                x = x + 1
            x = x + 3
            return x

    result = WhileNet()(Tensor(2, dtype=ms.int32), Tensor(14, dtype=ms.int32))
    # The loop stops at 18; 18 + 3 == 21.
    np.testing.assert_array_equal(result.asnumpy(), np.int32(21))
481
482
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_assign_while():
    """Assign a Parameter inside a while loop, then once more after it."""

    class AssignWhileNet(Cell):
        def __init__(self):
            super().__init__()
            self.para = Parameter(Tensor(1, dtype=ms.int32), name='para')

        def construct(self, x, y):
            y = y + 4
            while x < y:
                x = x + 1
                self.para = x
            self.para = x - 1
            return self.para

    net = AssignWhileNet()
    result = net(Tensor(2, dtype=ms.int32), Tensor(14, dtype=ms.int32))
    # x stops at 18; the final assignment stores 18 - 1.
    np.testing.assert_array_equal(result.asnumpy(), np.int32(17))
507
508
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_for():
    """A fixed-count for loop that increments (x + y) twenty times."""

    class ForNet(Cell):
        def construct(self, x, y):
            y = x + y
            for _ in range(20):
                y = y + 1
            return y

    result = ForNet()(Tensor(2, dtype=ms.int32), Tensor(4, dtype=ms.int32))
    np.testing.assert_array_equal(result.asnumpy(), np.int32(26))
527
528
@security_off_wrap
def test_print_for():
    """A print inside a fixed-count for loop fires once per iteration."""

    class PrintForNet(Cell):
        def __init__(self):
            super().__init__()
            self.printer = P.Print()

        def construct(self, x, y):
            y = x + y
            self.printer("input_x before:", x, "input_y before:", y)
            for _ in range(3):
                y = y + 1
                self.printer("input_x after:", x, "input_y after:", y)
            return y

    cap = Capture()
    with capture(cap):
        lhs = Tensor(2, dtype=ms.int32)
        rhs = Tensor(4, dtype=ms.int32)
        want = Tensor(9, dtype=ms.int32)
        result = PrintForNet()(lhs, rhs)
        time.sleep(0.1)
        np.testing.assert_array_equal(result.asnumpy(), want.asnumpy())

    expected_output = {
        'input_x before:\nTensor(shape=[], dtype=Int32, value=2)\n'
        'input_y before:\nTensor(shape=[], dtype=Int32, value=6)',
        'input_x after:\nTensor(shape=[], dtype=Int32, value=2)\n'
        'input_y after:\nTensor(shape=[], dtype=Int32, value=7)',
        'input_x after:\nTensor(shape=[], dtype=Int32, value=2)\n'
        'input_y after:\nTensor(shape=[], dtype=Int32, value=8)',
        'input_x after:\nTensor(shape=[], dtype=Int32, value=2)\n'
        'input_y after:\nTensor(shape=[], dtype=Int32, value=9)',
    }
    check_output(cap.output, expected_output)
564
565
@security_off_wrap
def test_print_assign_for():
    """Parameter assignment and print interleaved inside a for loop."""

    class PrintAssignForNet(Cell):
        def __init__(self):
            super().__init__()
            self.printer = P.Print()
            self.para = Parameter(Tensor(1, dtype=ms.int32), name='para')

        def construct(self, x, y):
            y = x + y
            self.printer("input_x before:", x, "input_y before:",
                         y, "para before:", self.para)
            for _ in range(3):
                y = y + 1
                self.para = x + y
                self.printer("input_x after:", x, "input_y after:",
                             y, "para after:", self.para)
            return y

    cap = Capture()
    with capture(cap):
        lhs = Tensor(2, dtype=ms.int32)
        rhs = Tensor(4, dtype=ms.int32)
        want = Tensor(9, dtype=ms.int32)
        result = PrintAssignForNet()(lhs, rhs)
        time.sleep(0.1)
        np.testing.assert_array_equal(result.asnumpy(), want.asnumpy())

    # para tracks x + y after each increment: 2 + 7, 2 + 8, 2 + 9.
    expected_output = {
        'input_x before:\nTensor(shape=[], dtype=Int32, value=2)\n'
        'input_y before:\nTensor(shape=[], dtype=Int32, value=6)\n'
        'para before:\nTensor(shape=[], dtype=Int32, value=1)',
        'input_x after:\nTensor(shape=[], dtype=Int32, value=2)\n'
        'input_y after:\nTensor(shape=[], dtype=Int32, value=7)\n'
        'para after:\nTensor(shape=[], dtype=Int32, value=9)',
        'input_x after:\nTensor(shape=[], dtype=Int32, value=2)\n'
        'input_y after:\nTensor(shape=[], dtype=Int32, value=8)\n'
        'para after:\nTensor(shape=[], dtype=Int32, value=10)',
        'input_x after:\nTensor(shape=[], dtype=Int32, value=2)\n'
        'input_y after:\nTensor(shape=[], dtype=Int32, value=9)\n'
        'para after:\nTensor(shape=[], dtype=Int32, value=11)',
    }
    check_output(cap.output, expected_output)
609
610
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_assign_for():
    """Assign a running sum to a Parameter on every for-loop iteration."""

    class AssignForNet(Cell):
        def __init__(self):
            super().__init__()
            self.para = Parameter(Tensor(1, dtype=ms.int32), name='para')

        def construct(self, x, y):
            y = y + 4
            for _ in range(5):
                x = x + y
                self.para = x
            return self.para

    net = AssignForNet()
    result = net(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32))
    # 2 + 5 * (3 + 4) == 37
    np.testing.assert_array_equal(result.asnumpy(), np.int32(37))
634
635
@constexpr
def _check_shape(shape):
    """Raise ValueError unless ``shape`` has exactly one dimension.

    Marked ``constexpr`` so it can be called from a graph-mode ``construct``
    on the (constant) shape value.
    """
    if len(shape) == 1:
        return
    raise ValueError(f"Invalid shape {shape}")
640
641
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_constexpr_check():
    """A constexpr shape check accepts 1-D inputs and raises ValueError otherwise."""

    class ConstexprCheck(Cell):
        def __init__(self):
            super().__init__()
            self.shape = P.Shape()

        def construct(self, x, y):
            s = self.shape(x)
            _check_shape(s)
            x = x + y
            return x

    x = Tensor([2], dtype=ms.int32)
    y = Tensor([3], dtype=ms.int32)
    expect = Tensor(5, dtype=ms.int32)
    net = ConstexprCheck()
    # Input with valid shape.
    out = net(x, y)
    np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy())
    # Input with wrong (2-D) shape, exception is expected. The tensor is built
    # outside pytest.raises so that only the network call can satisfy the
    # expected ValueError.
    wrong_x = Tensor(np.ones((2, 2)), dtype=ms.int32)
    with pytest.raises(ValueError):
        net(wrong_x, y)
670
671
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_lambda():
    """A lambda defined and applied to a Parameter inside an if-branch."""

    class IfLambdaNet(Cell):
        def __init__(self):
            super().__init__()
            self.para = Parameter(Tensor(1, dtype=ms.int32), name='para')

        def construct(self, x, y):
            out = x
            if x < y:
                x2 = (lambda a: a + a)
                out = x2(self.para)
                out = out + y
            return out

    net = IfLambdaNet()
    result = net(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32))
    # Branch taken: (1 + 1) + 3 == 5.
    np.testing.assert_array_equal(result.asnumpy(), np.int32(5))
696
697
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_multi_assign():
    """Three Assign ops; the returned sum must see all three writes."""

    class MultiAssignNet(Cell):
        def __init__(self):
            super().__init__()
            self.assign = P.Assign()
            self.para1 = Parameter(Tensor(1, dtype=ms.int32), name='para1')
            self.para2 = Parameter(Tensor(2, dtype=ms.int32), name='para2')
            self.para3 = Parameter(Tensor(3, dtype=ms.int32), name='para3')

        def construct(self, x, y, z):
            # Only the last Assign's output is kept in `a`; the first two are
            # kept alive purely by their side effects.
            a = self.assign(self.para1, x)
            a = self.assign(self.para2, y)
            a = self.assign(self.para3, z)
            return self.para1 + self.para2 + a

    net = MultiAssignNet()
    result = net(Tensor(4, dtype=ms.int32),
                 Tensor(5, dtype=ms.int32),
                 Tensor(6, dtype=ms.int32))
    np.testing.assert_array_equal(result.asnumpy(), np.int32(15))
724
725
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_multi_assign_addn():
    """AddN reads of Parameters interleaved with Assign writes to them."""

    class MultiAssignAddnNet(Cell):
        def __init__(self):
            super().__init__()
            self.addn = P.AddN()
            self.assign = P.Assign()
            self.para1 = Parameter(Tensor(1.0, dtype=ms.float32), name='para1')
            self.para2 = Parameter(Tensor(3.0, dtype=ms.float32), name='para2')

        def construct(self, inputs):
            self.assign(self.para1, inputs)
            out = self.addn((inputs, self.para1, self.para2))
            self.assign(self.para2, inputs)
            out = self.addn((out, self.para1, self.para2))
            return out

    result = MultiAssignAddnNet()(Tensor(9.0, dtype=ms.float32))
    # First AddN: 9 + 9 + 3 = 21; second: 21 + 9 + 9 = 39.
    np.testing.assert_almost_equal(result.asnumpy(), np.float32(39.0))
751
752
@security_off_wrap
def test_multi_assign_print():
    """Assign two Parameters, print them, and check the stored values."""

    class Multi_Assign_Print(Cell):
        def __init__(self):
            super().__init__()
            self.pow = P.Pow()
            self.print = P.Print()
            self.assign = P.Assign()
            self.exponent = Tensor([2.0], ms.float32)
            self.para1 = Parameter(Tensor(1.0, dtype=ms.float32), name='para1')
            self.para2 = Parameter(Tensor(3.0, dtype=ms.float32), name='para2')

        def construct(self, inputs):
            self.assign(self.para1, inputs)
            self.assign(self.para2, self.pow(inputs, self.exponent))
            self.print(inputs)
            self.print(self.para1)
            self.print(self.para2)
            return inputs

    x = Tensor(9.0, dtype=ms.float32)
    expect = Tensor(9.0, dtype=ms.float32)
    expect_para1 = Tensor(9.0, dtype=ms.float32)
    net = Multi_Assign_Print()
    out = net(x)
    np.testing.assert_almost_equal(out.asnumpy(), expect.asnumpy())
    np.testing.assert_almost_equal(
        net.para1.data.asnumpy(), expect_para1.asnumpy())
    # Pow is not guaranteed to return exactly 81.0 for 9 ** 2 in float32
    # (a backend may be off by an ulp), so compare against the exact value
    # with a relative tolerance instead of pinning one backend's rounding.
    np.testing.assert_allclose(
        net.para2.data.asnumpy(), 81.0, rtol=1e-5)
784
785
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_matmul_assign_biasadd():
    """Graph mode and PyNative mode must agree on an Assign/MatMul/BiasAdd chain."""

    class MatmulAssignBiasaddNet(Cell):
        def __init__(self):
            super().__init__()
            weight_init = np.array([[1, 1], [1, 1]])
            self.parameter1 = Parameter(
                Tensor(weight_init, ms.float32), name="parameter1")
            bias_init = np.array([0, -1])
            self.parameter2 = Parameter(
                Tensor(bias_init, ms.float32), name="biasadd")
            self.assign = P.Assign()
            self.matmul = P.MatMul()
            self.biasadd = P.BiasAdd()

        def construct(self, x):
            self.assign(self.parameter1, x)
            x = self.matmul(x, self.parameter1)
            self.assign(self.parameter1, x)
            x = self.biasadd(x, self.parameter2)
            return x

    data = np.array([[1, 2], [3, 4]])
    graph_net = MatmulAssignBiasaddNet()
    graph_out = graph_net(Tensor(data, ms.float32))
    # A fresh network, so the PyNative run starts from the same initial
    # parameter values as the graph-mode run.
    pynative_net = MatmulAssignBiasaddNet()
    try:
        context.set_context(mode=context.PYNATIVE_MODE)
        pynative_out = pynative_net(Tensor(data, ms.float32))
        np.testing.assert_almost_equal(graph_out.asnumpy(),
                                       pynative_out.asnumpy())
    finally:
        # Always restore the module-wide graph mode.
        context.set_context(mode=context.GRAPH_MODE)
821
822
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_assign_while_if():
    """Conditions on a Parameter that is assigned inside if and while bodies."""

    class AssignWhileIfNet(Cell):
        def __init__(self):
            super().__init__()
            self.mul = P.Mul()
            self.addn = P.AddN()
            self.assign = P.Assign()
            self.assign_sub = P.AssignSub()
            self.para = Parameter(Tensor(1.0, dtype=ms.float32), name='para')

        def construct(self, x, y, z, w):
            self.assign(self.para, x)
            if self.para > y:
                self.assign(self.para, y)
                x = self.mul(x, x)
            while self.para > z:
                x = self.addn((x, self.para))
                self.assign_sub(self.para, w)
            return x

    args = (Tensor(99.0, dtype=ms.float32),
            Tensor(44.0, dtype=ms.float32),
            Tensor(11.0, dtype=ms.float32),
            Tensor(1.0, dtype=ms.float32))
    result = AssignWhileIfNet()(*args)
    # 99 * 99 = 9801, plus 44 + 43 + ... + 12 = 924, gives 10725.
    np.testing.assert_almost_equal(result.asnumpy(), np.float32(10725.0))
855
856
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_isolate_call():
    """Side effects inside a method call whose result is discarded must still run."""

    class Net(Cell):
        def __init__(self):
            super().__init__()
            self.para1 = Parameter(Tensor(1, dtype=ms.int32), name='para1')
            self.para2 = Parameter(Tensor(2, dtype=ms.int32), name='para2')

        def construct(self, x, y):
            # The return value of setpara is ignored; only its writes matter.
            self.setpara(x, y)
            return self.para1 + self.para2

        def setpara(self, x, y):
            self.para1 = x
            self.setpara2(y)
            return x

        def setpara2(self, y):
            self.para2 = y
            return y

    net = Net()
    result = net(Tensor(4, dtype=ms.int32), Tensor(5, dtype=ms.int32))
    np.testing.assert_array_equal(result.asnumpy(), np.int32(9))
887
888
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_assign_return_true():
    """Branch condition computed by a helper that also writes a Parameter.

    mycheck always returns True but first sets para = x + y = 5, so the taken
    branch gives out = x + y = 5 and the result is para + out = 10.
    """
    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.para = Parameter(Tensor(1, dtype=ms.int32), name='para')

        def construct(self, x, y):
            if self.mycheck(x, y):
                out = x + y
            else:
                out = x - y
            # Must observe the write performed inside mycheck -> setpara.
            out = self.para + out
            return out

        def mycheck(self, x, y):
            self.setpara(x, y)
            return True

        def setpara(self, x, y):
            self.para = x + y
            return True

    x = Tensor(2, dtype=ms.int32)
    y = Tensor(3, dtype=ms.int32)
    expect = Tensor(10, dtype=ms.int32)
    net = Net()
    out = net(x, y)
    np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy())
921
922
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_unpack_call():
    """Call a sub-Cell through a helper that forwards ``*inputs``.

    SetPara is constructed with MyNet's own Parameter object, so its write
    para = x + y = 5 is visible to MyNet, which returns para + 1 = 6.
    """
    class SetPara(Cell):
        def __init__(self, para):
            super(SetPara, self).__init__()
            self.para = para

        def construct(self, x, y):
            self.para = x + y
            return True

    class MyNet(Cell):
        def __init__(self):
            super(MyNet, self).__init__()
            self.para = Parameter(Tensor(1, dtype=ms.int32), name='para')
            # Shares the same Parameter object with the sub-Cell.
            self.set_para = SetPara(self.para)

        def construct(self, *inputs):
            # The helper's return value is discarded; only the write matters.
            self.call_func(self.set_para, *inputs)
            out = self.para + 1
            return out

        def call_func(self, func, *inputs):
            func(*inputs)
            return True

    x = Tensor(2, dtype=ms.int32)
    y = Tensor(3, dtype=ms.int32)
    expect = Tensor(6, dtype=ms.int32)
    net = MyNet()
    out = net(x, y)
    np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy())
958
959
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_tuple_of_tuple():
    """Call a Cell fetched from a nested tuple inside construct.

    t2[0][0] is set_para and (t2[1], t1[1]) is (y, x), so the call writes
    para = y + x = 5 into the Parameter shared with MyNet; the result is 6.
    """
    class SetPara(Cell):
        def __init__(self, para):
            super(SetPara, self).__init__()
            self.para = para

        def construct(self, x, y):
            self.para = x + y
            return True

    class MyNet(Cell):
        def __init__(self):
            super(MyNet, self).__init__()
            self.para = Parameter(Tensor(1, dtype=ms.int32), name='para')
            # Shares the same Parameter object with the sub-Cell.
            self.set_para = SetPara(self.para)

        def construct(self, x, y):
            t1 = (self.set_para, x)
            t2 = (t1, y)
            t2[0][0](t2[1], t1[1])
            out = self.para + 1
            return out

    x = Tensor(2, dtype=ms.int32)
    y = Tensor(3, dtype=ms.int32)
    expect = Tensor(6, dtype=ms.int32)
    net = MyNet()
    out = net(x, y)
    np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy())
997
998
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_write_read_write():
    """Read Parameters between two rounds of writes.

    a must see the first writes (3 + 4 = 7) and the return must see the
    second writes (5 + 6), giving 7 + 11 = 18.
    """
    class MyNet(Cell):
        def __init__(self):
            super(MyNet, self).__init__()
            self.para1 = Parameter(Tensor(1, dtype=ms.int32), name='para1')
            self.para2 = Parameter(Tensor(2, dtype=ms.int32), name='para2')

        def construct(self, x, y, x1, y1):
            self.para1 = x
            self.para2 = y
            # Read between the two write rounds.
            a = self.para1 + self.para2
            self.para1 = x1
            self.para2 = y1
            return a + self.para1 + self.para2

    x = Tensor(3, dtype=ms.int32)
    y = Tensor(4, dtype=ms.int32)
    x1 = Tensor(5, dtype=ms.int32)
    y1 = Tensor(6, dtype=ms.int32)
    expect = Tensor(18, dtype=ms.int32)
    net = MyNet()
    out = net(x, y, x1, y1)
    np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy())
1026
1027
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_variable_from_outer_graph():
    """Use a value computed before an if (b) after the branches merge.

    b = 1 + 2 = 3 and a = 1 + 3 = 4; self.cond is False, so a = 4 + y = 7
    and the result is a + b = 10.
    """
    class MyNet(Cell):
        def __init__(self):
            super(MyNet, self).__init__()
            self.cond = False
            self.add = P.Add()
            self.para = Parameter(Tensor(1, dtype=ms.int32), name='para')

        def construct(self, x, y):
            b = self.para + x
            a = self.para + b
            if self.cond:
                a = self.add(a, x)
            else:
                a = self.add(a, y)
            # b is defined before the if and consumed after it.
            return a + b

    x = Tensor(2, dtype=ms.int32)
    y = Tensor(3, dtype=ms.int32)
    expect = Tensor(10, dtype=ms.int32)
    net = MyNet()
    out = net(x, y)
    np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy())
1055
1056
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ctrl_while_by_while_and_if_in_first_while():
    """Two sequential while loops over Parameters; the first contains an if.

    With a = 5 and c = 7, the first loop runs twice (a: 5 -> 7) and sets
    out = relu(x). The second loop runs twice (c: 7 -> 5) and doubles out
    each time, so the network must return 4 * relu(x).
    """
    class Net(Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            # sigmoid, tanh and b are not used by construct.
            self.sigmoid = P.Sigmoid()
            self.tanh = P.Tanh()
            self.add = P.Add()
            a = np.full((1,), 5, dtype=np.float32)
            self.a = Parameter(Tensor(a), name="a")
            b = np.full((1,), 4, dtype=np.float32)
            self.b = Parameter(Tensor(b), name="b")
            c = np.full((1,), 7, dtype=np.float32)
            self.c = Parameter(Tensor(c), name="c")

        def construct(self, x):
            out = x
            while self.a < 7:
                if self.a < self.c:
                    out = self.relu(x)
                self.a += 1
            while self.c > 5:
                out = self.add(out, out)
                self.c -= 1
            return out

    context.set_context(mode=context.GRAPH_MODE)
    input_np_a = np.random.randn(2, 3, 4, 5).astype(np.float32)
    input_me_a = Tensor(input_np_a)
    net = Net()
    out = net(input_me_a)
    # Doubling is exact in float32, so a tight tolerance is safe.
    expect = 4 * np.maximum(input_np_a, 0)
    np.testing.assert_allclose(out.asnumpy(), expect, rtol=1e-6)
1092
1093
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ctrl_if_by_while_and_while_in_first_if():
    """An if containing a while, followed by a second while over Parameters.

    With a = 5 and c = 7, the if is taken (out = relu(x)) and its inner loop
    moves a from 5 to 7. The second loop runs twice (c: 7 -> 5) and doubles
    out each time, so the network must return 4 * relu(x).
    """
    class Net(Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            # sigmoid, tanh and b are not used by construct.
            self.sigmoid = P.Sigmoid()
            self.tanh = P.Tanh()
            self.add = P.Add()
            a = np.full((1,), 5, dtype=np.float32)
            self.a = Parameter(Tensor(a), name="a")
            b = np.full((1,), 4, dtype=np.float32)
            self.b = Parameter(Tensor(b), name="b")
            c = np.full((1,), 7, dtype=np.float32)
            self.c = Parameter(Tensor(c), name="c")

        def construct(self, x):
            out = x
            if self.a < self.c:
                out = self.relu(x)
                while self.a < 7:
                    self.a += 1

            while self.c > 5:
                out = self.add(out, out)
                self.c -= 1
            return out

    context.set_context(mode=context.GRAPH_MODE)
    input_np_a = np.random.randn(2, 3, 4, 5).astype(np.float32)
    input_me_a = Tensor(input_np_a)
    net = Net()
    out = net(input_me_a)
    # Doubling is exact in float32, so a tight tolerance is safe.
    expect = 4 * np.maximum(input_np_a, 0)
    np.testing.assert_allclose(out.asnumpy(), expect, rtol=1e-6)
1130
1131
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ctrl_while_by_while_and_while_in_first_while():
    """A while nested in a while, followed by a second while over Parameters.

    With a = 5, b = 4 and c = 7, the outer loop runs twice (a: 5 -> 7), sets
    out = relu(x) and drains b to 1 in its first pass. The second loop runs
    twice (c: 7 -> 5) and doubles out each time, so the network must return
    4 * relu(x).
    """
    class Net(Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            # sigmoid and tanh are not used by construct.
            self.sigmoid = P.Sigmoid()
            self.tanh = P.Tanh()
            self.add = P.Add()
            a = np.full((1,), 5, dtype=np.float32)
            self.a = Parameter(Tensor(a), name="a")
            b = np.full((1,), 4, dtype=np.float32)
            self.b = Parameter(Tensor(b), name="b")
            c = np.full((1,), 7, dtype=np.float32)
            self.c = Parameter(Tensor(c), name="c")

        def construct(self, x):
            out = x
            while self.a < self.c:
                out = self.relu(x)
                while self.b > 1:
                    self.b -= 1
                self.a += 1

            while self.c > 5:
                out = self.add(out, out)
                self.c -= 1
            return out

    context.set_context(mode=context.GRAPH_MODE)
    input_np_a = np.random.randn(2, 3, 4, 5).astype(np.float32)
    input_me_a = Tensor(input_np_a)
    net = Net()
    out = net(input_me_a)
    # Doubling is exact in float32, so a tight tolerance is safe.
    expect = 4 * np.maximum(input_np_a, 0)
    np.testing.assert_allclose(out.asnumpy(), expect, rtol=1e-6)
1169
1170
def clear_json_info(folder_path='./kernel_meta'):
    """Delete every ``*.json`` and ``*.info`` entry directly under *folder_path*.

    This mirrors the former ``rm -rf <folder>/*.json`` / ``*.info`` shell calls
    without spawning a shell. Like the shell glob, ``glob.glob`` skips
    dot-files. Matching directories are removed recursively, and entries that
    vanish meanwhile are ignored, so a missing folder is a no-op.

    Args:
        folder_path: directory to clean. Defaults to ``./kernel_meta``, the
            location the original shell commands targeted.
    """
    import glob
    import shutil

    base = glob.escape(folder_path)
    for pattern in ('*.json', '*.info'):
        for path in glob.glob(os.path.join(base, pattern)):
            # rm -rf deletes a symlink itself, never the tree it points to.
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path, ignore_errors=True)
            else:
                try:
                    os.remove(path)
                except FileNotFoundError:
                    pass
1174
1175
def find_json_info(file):
    """List ``./kernel_meta/<file>`` with ``ls -al`` and return ``os.system``'s status.

    ``file`` is substituted unquoted into a shell command, so shell globs work
    (and only trusted names should be passed). A zero status means ``ls``
    succeeded, i.e. a matching entry exists. The listing goes to stdout.
    """
    ls_command = "ls -al ./kernel_meta/%s" % (file)
    status = os.system(ls_command)
    return status
1179
1180
class MultiOutReluBywaySqrt(Cell):
    """Two-output Cell whose intermediate ``x1`` feeds both a ReLU and a Sqrt.

    construct returns ``(relu(x1), sqrt(x1))`` where
    ``x1 = relu(relu(relu(x)))``.
    """
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
        self.sqrt = P.Sqrt()

    def construct(self, x):
        x = self.relu(x)
        x = self.relu(x)
        x1 = self.relu(x)
        # x1 has two consumers: the ReLU below and the Sqrt "byway" output.
        x = self.relu(x1)
        y = self.sqrt(x1)
        return x, y
1194
1195
class MultiOutReluSqrtBywaySqrt(Cell):
    """Two-output Cell whose intermediate ``x1`` feeds both a Sin and a Sqrt.

    construct returns ``(sin(x1), sqrt(x1))`` where
    ``x1 = relu(sqrt(relu(x)))``.
    """
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
        self.sqrt = P.Sqrt()
        self.sin = P.Sin()

    def construct(self, x):
        x = self.relu(x)
        x = self.sqrt(x)
        x1 = self.relu(x)
        # x1 has two consumers: the Sin below and the Sqrt "byway" output.
        x = self.sin(x1)
        y = self.sqrt(x1)
        return x, y
1210
1211
def clean_all_ir_files(folder_path):
    """Remove graph-dump artifacts directly under *folder_path*.

    Deletes entries ending in ``.ir``, ``.dot``, ``.dat`` or ``.pb`` and those
    whose name starts with ``trace_code_graph``. Does nothing when the folder
    does not exist.
    """
    if not os.path.exists(folder_path):
        return
    dump_suffixes = ('.ir', '.dot', '.dat', '.pb')
    for entry in os.listdir(folder_path):
        is_dump = entry.endswith(dump_suffixes) or entry.startswith('trace_code_graph')
        if is_dump:
            os.remove(os.path.join(folder_path, entry))
1219
1220
def find_newest_validateir_file(folder_path):
    """Return the path of the newest ``<N>_validate_<M>.ir`` file in *folder_path*.

    File names must match the pattern exactly. The dot is escaped and the
    whole name is matched, so e.g. ``1_validate_2.ir.bak`` is not picked up.
    "Newest" means the largest ``os.path.getctime``.

    Args:
        folder_path: directory to search (not recursive).

    Returns:
        ``os.path.join(folder_path, name)`` of the selected file.

    Raises:
        ValueError: if no matching file exists.
    """
    pattern = re.compile(r'\d+_validate_\d+\.ir')
    ir_files = [os.path.join(folder_path, name)
                for name in os.listdir(folder_path)
                if pattern.fullmatch(name)]
    if not ir_files:
        raise ValueError("no '<N>_validate_<M>.ir' file found in %s" % folder_path)
    return max(ir_files, key=os.path.getctime)
1226
1227
def read_file(folder_path='./'):
    """Return the text of the newest ``<N>_validate_<M>.ir`` file in *folder_path*.

    Args:
        folder_path: directory passed to find_newest_validateir_file. Defaults
            to the current directory, as before.

    Returns:
        The whole file content as a string, decoded as UTF-8.
    """
    filename = find_newest_validateir_file(folder_path)
    # An explicit encoding keeps the read independent of the platform locale.
    with open(filename, 'r', encoding='utf-8') as f:
        return f.read()
1233
1234
def check_keep_batchnorm_fp32_false(kwargs, level):
    """Tell whether the AMP kwargs leave BatchNorm un-kept in fp32 for an fp16 model on GPU.

    The result can only be True when ``device_target`` is ``"GPU"``:

    * level ``"O2"``: ``keep_batchnorm_fp32`` is given and falsy, and
      ``cast_model_type`` is either absent or ``ms.float16``;
    * any other level: ``cast_model_type`` is given and equals ``ms.float16``,
      and ``keep_batchnorm_fp32`` is either absent or falsy.

    Args:
        kwargs: keyword arguments destined for ``ms.amp.build_train_network``.
        level: AMP level string.

    Returns:
        bool
    """
    if ms.context.get_context("device_target") != "GPU":
        return False
    if level == "O2":
        return bool("keep_batchnorm_fp32" in kwargs and not kwargs["keep_batchnorm_fp32"]
                    and ("cast_model_type" not in kwargs
                         or kwargs["cast_model_type"] == ms.float16))
    return bool("cast_model_type" in kwargs and kwargs["cast_model_type"] == ms.float16
                and ("keep_batchnorm_fp32" not in kwargs
                     or not kwargs["keep_batchnorm_fp32"]))
1246
1247
def use_build_train_network_check_cast_num(network, level, inputs, label, cast_num, loss_flag=True, **kwargs):
    """Run one AMP training step of *network* and, in graph mode, check the Cast count.

    The step uses ``ms.amp.build_train_network`` with a Momentum optimizer. In
    graph mode, the newest ``<N>_validate_<M>.ir`` file in the current
    directory must contain exactly ``max(cast_num - diff_cast, 0)``
    occurrences of ``Cast``. diff_cast is 8 when
    check_keep_batchnorm_fp32_false(kwargs, level) holds and 0 otherwise. In
    any other mode the count is not checked.

    Args:
        network: Cell to train.
        level: AMP level forwarded to build_train_network.
        inputs: input batch.
        label: label batch.
        cast_num: expected Cast count before the keep_batchnorm_fp32 adjustment.
        loss_flag: when False, no loss function is passed to build_train_network.
        **kwargs: forwarded to build_train_network; also inspected for
            keep_batchnorm_fp32 / cast_model_type.

    Returns:
        The output of the training step.
    """
    diff_cast = 0
    if check_keep_batchnorm_fp32_false(kwargs, level):
        # NOTE(review): presumably the Casts that keeping BatchNorm in fp32 would add; confirm.
        diff_cast += 8
    opt = Momentum(learning_rate=0.0001, momentum=0.009,
                   params=network.trainable_params())
    loss = None
    if loss_flag:
        loss = SoftmaxCrossEntropyWithLogits(sparse=False, reduction='mean')

    train_network = ms.amp.build_train_network(
        network, opt, loss, level=level, **kwargs)
    out_me = train_network(inputs, label)
    # Compare against the named mode constant rather than the magic number 0.
    if context.get_context("mode") == context.GRAPH_MODE:
        expected = max(cast_num - diff_cast, 0)
        actual = read_file().count('Cast')
        assert actual == expected, "expected %d Cast, found %d" % (expected, actual)
    return out_me
1266
1267
class AssignNet(Cell):
    """Small network that subtracts its input from a Parameter, then ReLU and mean.

    ``input_data`` starts as ones of shape [1, 3, 2, 2]. The ReduceMean over
    axes (2, 3) with keep_dims=False reduces that to shape [1, 3].
    """
    def __init__(self):
        super().__init__()
        self.relu = ReLU()
        self.mean = P.ReduceMean(keep_dims=False)
        self.assign_sub = P.AssignSub()
        self.input_data = Parameter(initializer(
            1, [1, 3, 2, 2], ms.float32), name='value')

    def construct(self, x):
        # The value returned by AssignSub is used as the activation input.
        x = self.assign_sub(self.input_data, x)
        x = self.relu(x)
        x = self.mean(x, (2, 3))
        return x
1282
@security_off_wrap
def test_auto_mixed_precision_train_1(pynative_save_graphs):
    """One AMP training step of AssignNet at level O0 with an expected Cast count of 0.

    NOTE(review): the pynative_save_graphs fixture switches to PYNATIVE_MODE,
    and the helper only counts Casts in graph mode. The count is therefore
    not checked here; confirm this is intended.
    """
    network = AssignNet()
    features = Tensor(np.ones([1, 3, 2, 2]).astype(np.float32))
    labels = Tensor(np.zeros([1, 3]).astype(np.float32))
    use_build_train_network_check_cast_num(network, "O0", features, labels, cast_num=0)
1289
@security_off_wrap
def test_auto_mixed_precision_train_2(pynative_save_graphs):
    """One AMP training step of AssignNet at level O2 with an expected Cast count of 2.

    NOTE(review): the pynative_save_graphs fixture switches to PYNATIVE_MODE,
    and the helper only counts Casts in graph mode. The count is therefore
    not checked here; confirm this is intended.
    """
    network = AssignNet()
    features = Tensor(np.ones([1, 3, 2, 2]).astype(np.float32))
    labels = Tensor(np.zeros([1, 3]).astype(np.float32))
    use_build_train_network_check_cast_num(network, "O2", features, labels, cast_num=2)
1296
1297
class MixControlNet(Cell):
    """Network mixing nested while/if control flow with Parameter updates.

    Args:
        in_channel: channel count of the 1x1 Conv2d and the BatchNorm2d.
        x: initial value of both loop counters used in construct.

    NOTE(review): bias, bias2 and value are hard-coded to length 3, so
    in_channel presumably has to be 3 for BiasAdd to accept them; confirm.
    """
    def __init__(self, in_channel, x):
        super().__init__()
        self.biasadd = P.BiasAdd()
        # equal, assign and relu are not used by construct.
        self.equal = P.Equal()
        self.addn = P.AddN()
        self.conv = Conv2d(in_channels=in_channel, out_channels=in_channel,
                           kernel_size=1, stride=1, has_bias=False,
                           weight_init='ones', pad_mode='same')
        self.bn = BatchNorm2d(num_features=in_channel)
        self.assignadd = P.AssignAdd()
        self.assign = P.Assign()
        self.relu = ReLU()
        self.mean = P.ReduceMean(keep_dims=False)
        self.bias = Parameter(
            Tensor(np.random.randint(2, size=(3,)).astype((np.float32))),
            name="bias")
        self.bias2 = Parameter(Tensor(np.ones([3]).astype(np.float32)),
                               name="bias2")
        self.parameterupdate = ParameterUpdate(self.bias)
        self.value = Tensor(np.random.randn(*(3,)), ms.float32)
        self.x = x

    def construct(self, input_x):
        x = self.x
        z = self.x
        out = self.biasadd(input_x, self.bias)
        while x < 20:
            # ParameterUpdate was built around self.bias; its output is used as a bias here.
            update = self.parameterupdate(self.bias2)
            out = self.biasadd(out, update)
            if x < 10:
                out = self.addn((input_x, out))
                # z is never reset, so this inner loop only runs on the first outer pass.
                while z < 20:
                    out = self.conv(out)
                    z = z + 1
            # Always true inside `while x < 20`; kept as written.
            if x < 20:
                out = self.biasadd(out, self.bias)
                if x % 2 == 0:
                    out = self.biasadd(out, self.bias)
                    self.assignadd(self.bias, self.value)
                    out = self.bn(out)
                else:
                    out = self.conv(out)
            x = x + 1
        out = self.addn((out, out))
        out = self.mean(out, (2, 3))
        return out
1345
1346
def use_build_train_network_controlflow_check_cast_num(network, level, input_x,
                                                       label, cast_num,
                                                       sparse=False,
                                                       loss_flag=True,
                                                       **kwargs):
    """Run one AMP training step of a control-flow *network* and check the Cast count.

    The step uses ``ms.amp.build_train_network`` with a Momentum optimizer. In
    graph mode, the newest ``<N>_validate_<M>.ir`` file in the current
    directory must contain exactly *cast_num* occurrences of ``Cast``. In any
    other mode the count is not checked.

    Args:
        network: Cell to train.
        level: AMP level forwarded to build_train_network.
        input_x: input batch.
        label: label batch.
        cast_num: expected number of ``Cast`` occurrences in the IR.
        sparse: forwarded to SoftmaxCrossEntropyWithLogits.
        loss_flag: when False, no loss function is passed to build_train_network.
        **kwargs: forwarded to build_train_network.

    Returns:
        The output of the training step.
    """
    opt = Momentum(learning_rate=0.0001, momentum=0.009,
                   params=network.trainable_params())
    loss = None
    if loss_flag:
        loss = SoftmaxCrossEntropyWithLogits(sparse=sparse, reduction='mean')

    train_network = ms.amp.build_train_network(network, opt, loss, level=level,
                                               **kwargs)
    out_me = train_network(input_x, label)
    # Compare against the named mode constant rather than the magic number 0.
    if context.get_context("mode") == context.GRAPH_MODE:
        actual = read_file().count('Cast')
        assert actual == cast_num, "expected %d Cast, found %d" % (cast_num, actual)
    return out_me
1366
@security_off_wrap
def test_auto_mixed_precision_controlflow_auto(pynative_save_graphs):
    """Run MixControlNet through build_train_network at AMP level "auto".

    The expected Cast count is recorded per backend. Backends without a
    recorded count are skipped instead of failing on an unbound local.

    NOTE(review): the pynative_save_graphs fixture switches to PYNATIVE_MODE,
    and the helper only counts Casts in graph mode. The count is therefore
    not checked here; confirm this is intended.
    """
    expected_cast_num = {"Ascend": 77, "GPU": 73}
    device_target = ms.context.get_context("device_target")
    if device_target not in expected_cast_num:
        pytest.skip("no expected Cast count recorded for device_target=%s" % device_target)
    net = MixControlNet(3, 5)
    input_x = Tensor(
        np.random.randint(2, size=(1, 3, 2, 2)).astype((np.float32)))
    label = Tensor(np.zeros([1, 3]).astype(np.float32))
    use_build_train_network_controlflow_check_cast_num(net, "auto", input_x,
                                                       label, expected_cast_num[device_target])
1379
1380
# The Cast (op_cast) must be placed in order_list after abstract_specialize.
# Besides Ascend, this case can also run on CPU.
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_cast():
    """Cast a Parameter, reassign it, then branch on a Python-level condition.

    The expected value 3 (= 1.0 + 2) shows that z_local must hold the value of
    z from before the assignment ``self.z = beta2``.
    """
    class Net(nn.Cell):
        def __init__(self, cond1):
            super().__init__()
            self.cond1 = cond1
            self.op_cast = P.Cast()
            self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')

        def construct(self, beta1, beta2):
            # Must read z before the write on the next line.
            z_local = self.op_cast(self.z, ms.float16)
            self.z = beta2
            if self.cond1:
                out = z_local + beta1
            else:
                out = z_local - beta1

            return out

    net = Net(True)
    beta1 = Tensor(np.array([2]).astype(np.float32))
    beta2 = Tensor(np.array([10]).astype(np.float32))
    r1 = net(beta1, beta2)
    expect = Tensor(np.array([3]).astype(np.float32))
    np.testing.assert_array_equal(r1.asnumpy(), expect.asnumpy())
1411
1412
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_while_forward():
    """Slice assignment inside a while loop driven by tensor loop bounds.

    x = arange(8) reshaped to (2, 2, 2). For idx = 0 and 1, the slice
    x[idx, :, 0:2] covers all of x[idx] (last dim is 2) and is set to its
    maximum: 3 for idx 0, 7 for idx 1.
    """
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()

        def construct(self, idx, end, x):
            while idx < end:
                part = x[idx, :, :]
                max_num = self.max(part)
                x[idx, :, 0:2] = max_num
                idx = idx + 1
            return x

    net = MyWhileNet()
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
    output = net(idx, end, x)
    # expect is int32 while output is float32; allclose compares numerically.
    expect = np.array([[[3, 3], [3, 3]], [[7, 7], [7, 7]]], dtype=np.int32)
    assert np.allclose(output.asnumpy(), expect, 0.0001, 0.0001)
1438
1439
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_multi_add_assign():
    """Reads of a Parameter before and after an Assign in the same graph.

    res1 uses p before the Assign, res2 uses p after it (p = d * e). The
    final Parameter value is also checked.
    """
    class Net(Cell):
        def __init__(self, i1):
            super(Net, self).__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.assign = P.Assign()
            self.p = Parameter(i1, name='para')

        def construct(self, a, d, e):
            # Reads the original p.
            res1 = self.add(self.add(self.add(self.p, a), a), a)
            mul = self.mul(d, e)
            self.assign(self.p, mul)
            # Reads p after the Assign.
            res2 = self.sub(self.p, e)
            return res2, res1

    def numpy_out(p, a, d, e):
        """NumPy reference: (res2, res1, expected final value of p)."""
        res1 = p + a + a + a
        res_as = d * e
        res2 = d * e - e
        return res2, res1, res_as

    # Strictly positive random inputs (|N(0, 1)| + 1).
    p = (np.abs(np.random.normal(0, 1, [3])) + 1).astype(np.float32)
    i0 = (np.abs(np.random.normal(0, 1, [3])) + 1).astype(np.float32)
    i1 = (np.abs(np.random.normal(0, 1, [3])) + 1).astype(np.float32)
    i2 = (np.abs(np.random.normal(0, 1, [3])) + 1).astype(np.float32)

    net = Net(Tensor(p))
    r2, r1 = net(Tensor(i0), Tensor(i1), Tensor(i2))

    outputs = [r2.asnumpy(), r1.asnumpy(), net.p.data.asnumpy()]
    expects = numpy_out(p, i0, i1, i2)
    np.testing.assert_array_equal(outputs, expects)
1478
1479
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_multi_abs_add_assign():
    """Like test_multi_add_assign, with a chain of Abs ops and an extra output.

    res1 uses p before the Assign, res2 uses p after it (p = d * e). tmp and
    the final Parameter value are also checked.
    """
    class Net(Cell):
        def __init__(self, para):
            super(Net, self).__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.abs = P.Abs()
            self.assign = P.Assign()
            self.p = Parameter(para, name='para')

        def construct(self, a, d, e):
            tmp = self.abs(self.add(self.abs(a), self.abs(self.mul(a, a))))
            # Reads the original p.
            res1 = self.add(self.p, tmp)
            mul = self.mul(d, e)
            self.assign(self.p, mul)
            # Reads p after the Assign.
            res2 = self.sub(self.p, e)
            return res2, res1, tmp

    def numpy_out(p, a, d, e):
        """NumPy reference: (res2, res1, expected final value of p, tmp)."""
        tmp = np.abs(np.abs(a) + np.abs(a * a))
        res1 = p + tmp
        res_as = d * e
        res2 = d * e - e
        return res2, res1, res_as, tmp

    # Strictly negative random inputs (-(|N(0, 1)| + 1)).
    p = -(np.abs(np.random.normal(0, 1, [3])) + 1).astype(np.float32)
    i0 = -(np.abs(np.random.normal(0, 1, [3])) + 1).astype(np.float32)
    i1 = -(np.abs(np.random.normal(0, 1, [3])) + 1).astype(np.float32)
    i2 = -(np.abs(np.random.normal(0, 1, [3])) + 1).astype(np.float32)

    net = Net(Tensor(p))
    r2, r1, tmp = net(Tensor(i0), Tensor(i1), Tensor(i2))

    outputs = [r2.asnumpy(), r1.asnumpy(), net.p.data.asnumpy(), tmp.asnumpy()]
    expects = numpy_out(p, i0, i1, i2)
    np.testing.assert_array_equal(outputs, expects)
1521