• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test control ops """
16import os
17import numpy as np
18import pytest
19
20import mindspore as ms
21from mindspore import Tensor
22from mindspore import context
23from mindspore import nn
24from mindspore.common import dtype as mstype
25from mindspore.ops import composite as C
26from mindspore.ops import functional as F
27from mindspore.ops import operations as P
28from mindspore.common.parameter import Parameter, ParameterTuple
29from mindspore.common import ms_function
30
# Every test in this module compiles its cells as static graphs.
context.set_context(mode=context.GRAPH_MODE)

# Shared gradient transforms used by the tests below:
# - grad_by_list: gradients w.r.t. an explicit ParameterTuple of weights.
grad_by_list = C.GradOperation(get_by_list=True)
# - grad_all: gradients w.r.t. every positional input.
grad_all = C.GradOperation(get_all=True)
# - grad_all_with_sens: like grad_all, but the caller passes the output sensitivity as the last argument.
grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)
36
37
def cond_data_test(x_init, y_init):
    """Build a GeSwitch/Merge data-flow net and run it on two float32 scalars.

    Args:
        x_init: Value used to build the first float32 Tensor input.
        y_init: Value used to build the second float32 Tensor input.

    Returns:
        The first element of the Merge output produced by ``construct``.
    """
    class Net(nn.Cell):
        def __init__(self):
            """Create the primitives and the constant Tensor(3) fed to the second branch."""
            super(Net, self).__init__()
            self.square = P.Square()
            self.add = P.Add()
            self.value = Tensor(3, dtype=ms.float32)
            self.switch = P.GeSwitch()
            self.merge = P.Merge()
            self.less = P.Less()

        def construct(self, x, y):
            # Route x, y and the constant through GeSwitch keyed on (x < y),
            # then merge the "add" path with the "square" path.
            pred = self.less(x, y)
            x_first, _ = self.switch(x, pred)
            y_first, _ = self.switch(y, pred)
            summed = self.add(x_first, y_first)
            _, const_second = self.switch(self.value, pred)
            squared = self.square(const_second)
            merged = self.merge((summed, squared))
            return merged[0]

    lhs = Tensor(x_init, dtype=ms.float32)
    rhs = Tensor(y_init, dtype=ms.float32)
    return Net()(lhs, rhs)
65
66
def test_cond_data_true():
    """Run cond_data_test with x < y (3 < 8) and print the merged result."""
    merged = cond_data_test(3, 8)
    print("test_cond_data_true:", merged)
70
71
def test_cond_data_false():
    """Run cond_data_test with x > y (8 > 3) and print the merged result."""
    merged = cond_data_test(8, 3)
    print("test_cond_data_false:", merged)
75
76
def if_compile_test(x_init, y_init):
    """Compile and run a net whose ``construct`` branches on a Tensor comparison.

    When ``x < y`` the net returns ``y + (x + 3)``; otherwise it returns ``3`` squared.

    Args:
        x_init: Value used to build the first float32 Tensor input.
        y_init: Value used to build the second float32 Tensor input.

    Returns:
        The Tensor produced by the net.
    """
    class Net(nn.Cell):
        def __init__(self):
            """Create the primitives and the constant Tensor(3)."""
            super(Net, self).__init__()
            self.square = P.Square()
            self.add = P.Add()
            self.value = Tensor(3, dtype=ms.float32)
            self.switch = P.GeSwitch()
            self.merge = P.Merge()
            self.less = P.Less()

        def construct(self, x, y):
            is_less = self.less(x, y)
            acc = self.value
            if is_less:
                acc = self.add(x, acc)
                acc = self.add(y, acc)
            else:
                acc = self.square(self.value)
            return acc

    lhs = Tensor(x_init, dtype=ms.float32)
    rhs = Tensor(y_init, dtype=ms.float32)
    return Net()(lhs, rhs)
104
105
def test_if_none():
    """A ``None`` attribute used directly as an ``if`` condition selects the ``else`` branch."""
    class Net(nn.Cell):
        def __init__(self, z: None):
            """Keep ``z`` as a constant attribute read by ``construct``."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            # Branch on the truthiness of the stored constant.
            if self.z:
                picked = x
            else:
                picked = y
            return picked

    ones = Tensor(np.ones([6, 8, 10], np.int32))
    zeros = Tensor(np.zeros([3, 4, 5], np.int32))
    output = Net(None)(ones, zeros)
    assert np.all(output.asnumpy() == zeros.asnumpy())
125
126
def test_if_str_is_not_none_right():
    """``self.z is None`` is false for a non-empty string, so the ``else`` branch is taken."""
    class Net(nn.Cell):
        def __init__(self, z: str):
            """Keep ``z`` as a constant attribute read by ``construct``."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z is None:
                picked = x
            else:
                picked = y
            return picked

    ones = Tensor(np.ones([6, 8, 10], np.int32))
    zeros = Tensor(np.zeros([3, 4, 5], np.int32))
    output = Net("ok")(ones, zeros)
    assert np.all(output.asnumpy() == zeros.asnumpy())
146
147
def test_if_str_is_not_none_left():
    """``self.z is None`` is false for a non-empty string, so the ``else`` branch is taken."""
    # NOTE(review): this body is identical to test_if_str_is_not_none_right. The name
    # suggests a variant with None on the left-hand side, which is not exercised here;
    # confirm intent before relying on this test for extra coverage.
    class Net(nn.Cell):
        def __init__(self, z: str):
            """Keep ``z`` as a constant attribute read by ``construct``."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z is None:
                picked = x
            else:
                picked = y
            return picked

    ones = Tensor(np.ones([6, 8, 10], np.int32))
    zeros = Tensor(np.zeros([3, 4, 5], np.int32))
    output = Net("ok")(ones, zeros)
    assert np.all(output.asnumpy() == zeros.asnumpy())
167
168
def test_if_none_equal_none():
    """``self.z is None`` holds when the attribute is ``None``, so the ``if`` branch is taken."""
    class Net(nn.Cell):
        def __init__(self, z: None):
            """Keep ``z`` as a constant attribute read by ``construct``."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z is None:
                picked = x
            else:
                picked = y
            return picked

    ones = Tensor(np.ones([6, 8, 10], np.int32))
    zeros = Tensor(np.zeros([3, 4, 5], np.int32))
    output = Net(None)(ones, zeros)
    assert np.all(output.asnumpy() == ones.asnumpy())
188
189
def test_if_str_is_null():
    """An empty-string attribute is falsy as an ``if`` condition, so ``y`` is returned."""
    class Net(nn.Cell):
        def __init__(self, z: str):
            """Keep ``z`` as a constant attribute read by ``construct``."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z:
                picked = x
            else:
                picked = y
            return picked

    ones = Tensor(np.ones([6, 8, 10], np.int32))
    zeros = Tensor(np.zeros([3, 4, 5], np.int32))
    output = Net("")(ones, zeros)
    assert np.all(output.asnumpy() == zeros.asnumpy())
209
210
def test_if_str_is_true():
    """A non-empty string attribute is truthy as an ``if`` condition, so ``x`` is returned."""
    class Net(nn.Cell):
        def __init__(self, z: str):
            """Keep ``z`` as a constant attribute read by ``construct``."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z:
                picked = x
            else:
                picked = y
            return picked

    ones = Tensor(np.ones([6, 9, 10], np.int32))
    zeros = Tensor(np.zeros([3, 4, 5], np.int32))
    output = Net("ok")(ones, zeros)
    assert np.all(output.asnumpy() == ones.asnumpy())
230
231
def test_if_str_equal():
    """String equality against a literal selects the ``if`` branch when the strings match."""
    class Net(nn.Cell):
        def __init__(self, z: str):
            """Keep ``z`` as a constant attribute read by ``construct``."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z == "ok":
                picked = x
            else:
                picked = y
            return picked

    ones = Tensor(np.ones([6, 8, 10], np.int32))
    zeros = Tensor(np.zeros([3, 4, 5], np.int32))
    output = Net("ok")(ones, zeros)
    assert np.all(output.asnumpy() == ones.asnumpy())
251
252
def test_if_tuple_is_null():
    """An empty-tuple attribute is falsy as an ``if`` condition, so ``y`` is returned."""
    class Net(nn.Cell):
        def __init__(self, z: tuple):
            """Keep ``z`` as a constant attribute read by ``construct``."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z:
                picked = x
            else:
                picked = y
            return picked

    ones = Tensor(np.ones([6, 8, 10], np.int32))
    zeros = Tensor(np.zeros([3, 4, 5], np.int32))
    output = Net(())(ones, zeros)
    assert np.all(output.asnumpy() == zeros.asnumpy())
272
273
def test_if_tuple_is_not_null():
    """A non-empty tuple attribute is truthy as an ``if`` condition, so ``x`` is returned."""
    class Net(nn.Cell):
        def __init__(self, z: tuple):
            """Keep ``z`` as a constant attribute read by ``construct``."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z:
                picked = x
            else:
                picked = y
            return picked

    ones = Tensor(np.ones([6, 8, 10], np.int32))
    zeros = Tensor(np.zeros([3, 4, 5], np.int32))
    output = Net((1, 2, 3))(ones, zeros)
    assert np.all(output.asnumpy() == ones.asnumpy())
293
294
def test_if_dict_is_null():
    """An empty-dict attribute is falsy as an ``if`` condition, so ``y`` is returned."""
    class Net(nn.Cell):
        def __init__(self, z: dict):
            """Keep ``z`` as a constant attribute read by ``construct``."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z:
                picked = x
            else:
                picked = y
            return picked

    ones = Tensor(np.ones([6, 8, 10], np.int32))
    zeros = Tensor(np.zeros([3, 4, 5], np.int32))
    output = Net({})(ones, zeros)
    assert np.all(output.asnumpy() == zeros.asnumpy())
314
315
def test_if_dict_is_not_null():
    """A non-empty dict attribute is truthy as an ``if`` condition, so ``x`` is returned."""
    class Net(nn.Cell):
        def __init__(self, z: dict):
            """Keep ``z`` as a constant attribute read by ``construct``."""
            super(Net, self).__init__()
            self.z = z

        def construct(self, x, y):
            if self.z:
                picked = x
            else:
                picked = y
            return picked

    ones = Tensor(np.ones([6, 8, 10], np.int32))
    zeros = Tensor(np.zeros([3, 4, 5], np.int32))
    output = Net({"one": 1, "two": 2})(ones, zeros)
    assert np.all(output.asnumpy() == ones.asnumpy())
335
336
def test_if_else_assign():
    """Conditional expressions and list comparisons on constant list attributes compile."""
    class Net(nn.Cell):
        def __init__(self, m: list):
            """Store the caller's list ``m`` and a fixed alternative list ``n``."""
            super(Net, self).__init__()
            self.m = m
            self.n = [4, 5, 6]

        def construct(self, x, y):
            first_pick = self.m if self.m else self.n
            second_pick = self.m if first_pick == self.n else self.n
            if second_pick == self.m:
                if self.m:
                    picked = x
                else:
                    picked = y
            else:
                if self.m:
                    picked = x
                else:
                    picked = y
            return picked

    ones = Tensor(np.ones([6, 8, 10], np.int32))
    zeros = Tensor(np.zeros([3, 4, 5], np.int32))
    output = Net([1, 2])(ones, zeros)
    assert np.all(output.asnumpy() == ones.asnumpy())
365
366
def test_if_compile_true():
    """Run if_compile_test with x < y (3 < 8) and print the result."""
    result = if_compile_test(3, 8)
    print("test_if_compile_true:", result)
370
371
def test_if_compile_false():
    """Run if_compile_test with x > y (8 > 3) and print the result."""
    result = if_compile_test(8, 3)
    print("test_if_compile_false:", result)
375
376
def test_switch_layer():
    """Forward, parameter-gradient and input-gradient passes through ``F.switch_layer``."""
    def make_input():
        # Fresh (128, 96) float32 tensor filled with 0.6.
        return Tensor(np.full([128, 96], 0.6, dtype=np.float32))

    class Layer1(nn.Cell):
        def __init__(self):
            super(Layer1, self).__init__()
            self.z1 = Parameter(make_input(), name='z1')

        def construct(self, x):
            return x * self.z1

    class Layer2(nn.Cell):
        def __init__(self):
            super(Layer2, self).__init__()
            self.z2 = Parameter(make_input(), name='z2')

        def construct(self, x):
            return x * self.z2

    class SwitchLayerCell(nn.Cell):
        def __init__(self):
            super(SwitchLayerCell, self).__init__()
            self.layers = (Layer1(), Layer2())
            self.z3 = Parameter(make_input(), name='z3')

        def construct(self, index, x):
            # Pick a layer by the Tensor index, then scale by z3.
            selected = F.switch_layer(index, self.layers)(x)
            return selected * self.z3

    index = Tensor(0, dtype=mstype.int32)
    net = SwitchLayerCell()
    net(index, make_input())
    weights = ParameterTuple(net.trainable_params())
    grad_by_list(net, weights)(index, make_input())
    grad_all(net)(index, make_input())
413
414
def test_index_to_switch_layer():
    """Indexing a tuple of cells with a Tensor (``self.layers[index]``) supports forward and grad."""
    def make_input():
        # Fresh (128, 96) float32 tensor filled with 0.6.
        return Tensor(np.full([128, 96], 0.6, dtype=np.float32))

    class Layer1(nn.Cell):
        def __init__(self):
            super(Layer1, self).__init__()
            self.z1 = Parameter(make_input(), name='z1')

        def construct(self, x):
            return x * self.z1

    class Layer2(nn.Cell):
        def __init__(self):
            super(Layer2, self).__init__()
            self.z2 = Parameter(make_input(), name='z2')

        def construct(self, x):
            return x * self.z2

    class SwitchLayerCell(nn.Cell):
        def __init__(self):
            super(SwitchLayerCell, self).__init__()
            self.layers = (Layer1(), Layer2())
            self.z3 = Parameter(make_input(), name='z3')

        def construct(self, index, x):
            selected = self.layers[index](x)
            return selected * self.z3

    index = Tensor(0, dtype=mstype.int32)
    net = SwitchLayerCell()
    net(index, make_input())
    weights = ParameterTuple(net.trainable_params())
    grad_by_list(net, weights)(index, make_input())
    grad_all(net)(index, make_input())
451
452
def test_parser_switch_layer_switch_in_bprop():
    """A custom ``bprop`` that dispatches through a tuple of cells indexed by a Tensor."""
    class OneInputBprop(nn.Cell):
        def __init__(self, funcs):
            super(OneInputBprop, self).__init__()
            self.op = P.ReLU()
            self.funcs = funcs

        def construct(self, i, x):
            return self.op(x)

        def bprop(self, i, x, out, dout):
            # Gradient for x is produced by the cell selected with index i.
            return i, self.funcs[i](x, dout)

    class Add(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = P.Add()

        def construct(self, x, y):
            return self.op(x, y)

    class Mul(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = P.Mul()

        def construct(self, x, y):
            return self.op(x, y)

    net = OneInputBprop((Add(), Mul()))
    data = Tensor(np.ones([2, 2]).astype(np.float32))
    sens = Tensor(np.random.randn(2, 2).astype(np.float32))
    selector = Tensor(1, mstype.int32)
    grad_all_with_sens(net)(selector, data, sens)
491
492
def test_parser_switch_layer_inputs_tuple():
    """Switch-layer dispatch where the selected cell receives a tuple of inputs, with grad."""
    class TwoInputTupleFinalNet(nn.Cell):
        def __init__(self, funcs):
            super().__init__()
            self.funcs = funcs

        def construct(self, i, inputa, inputb):
            packed = (inputa, inputb)
            return self.funcs[i](packed)

    class Add(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = P.Add()

        def construct(self, x):
            partial = self.op(x[0], x[1])
            return self.op(x[0], partial)

    class Mul(nn.Cell):
        def __init__(self):
            super().__init__()
            self.op = P.Mul()

        def construct(self, x):
            partial = self.op(x[0], x[1])
            return self.op(x[0], partial)

    net = TwoInputTupleFinalNet((Add(), Mul()))

    first = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
    second = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
    selector = Tensor(1, mstype.int32)
    sens = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
    grad_all_with_sens(net)(selector, first, second, sens)
534
535
def test_switch_layer_with_single_prim():
    """Switch-layer dispatch over a tuple of nn.ReLU cells, with forward and grad passes."""
    def make_input():
        # Fresh (128, 96) float32 tensor filled with 0.6.
        return Tensor(np.full([128, 96], 0.6, dtype=np.float32))

    class SwitchLayerCell(nn.Cell):
        def __init__(self):
            super(SwitchLayerCell, self).__init__()
            self.layers = (nn.ReLU(), nn.ReLU())
            self.z3 = Parameter(make_input(), name='z3')

        def construct(self, index, x):
            selected = self.layers[index](x)
            return selected * self.z3

    index = Tensor(0, dtype=mstype.int32)
    net = SwitchLayerCell()
    net(index, make_input())
    weights = ParameterTuple(net.trainable_params())
    grad_by_list(net, weights)(index, make_input())
    grad_all(net)(index, make_input())
554
555
def test_switch_layer_env_eliminate():
    """Weight gradients through a Tensor-indexed choice between two Conv2d cells."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.conv = nn.Conv2d(1, 1, 3, pad_mode='same')
            self.conv2 = nn.Conv2d(1, 1, 5, pad_mode='same')
            self.funs = (self.conv, self.conv2)

        def construct(self, x, index):
            return self.funs[index](x)

    class NetGrad(nn.Cell):
        def __init__(self, net):
            super(NetGrad, self).__init__()
            self.grad_op = C.GradOperation(get_by_list=True, sens_param=False)
            self.net = net
            self.weights = ParameterTuple(self.net.trainable_params())

        def construct(self, x, index):
            weights = self.weights
            return self.grad_op(self.net, weights)(x, index)

    grad_net = NetGrad(Net())
    image = Tensor(np.ones((3, 1, 12, 12)), ms.float32)
    selector = Tensor(1, ms.int32)
    grad_net(image, selector)
585
586
def test_switch_layer_single_layer():
    """Weight gradients through a Tensor-indexed tuple that holds a single Conv2d cell."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.conv = nn.Conv2d(1, 1, 3, pad_mode='same')
            self.funs = (self.conv,)

        def construct(self, x, index):
            return self.funs[index](x)

    class NetGrad(nn.Cell):
        def __init__(self, net):
            super(NetGrad, self).__init__()
            self.grad_op = C.GradOperation(get_by_list=True, sens_param=False)
            self.net = net
            self.weights = ParameterTuple(self.net.trainable_params())

        def construct(self, x, index):
            weights = self.weights
            return self.grad_op(self.net, weights)(x, index)

    grad_net = NetGrad(Net())
    image = Tensor(np.ones((3, 1, 12, 12)), ms.float32)
    # NOTE(review): index 1 is passed although self.funs has only one element;
    # confirm this out-of-range index is intentional.
    selector = Tensor(1, ms.int32)
    grad_net(image, selector)
615
616
def test_if_nested_compile():
    """An ``if``/``else`` whose ``else`` branch contains a nested ``if`` compiles and runs."""
    class Net(nn.Cell):
        def __init__(self, auto_prefix=True):
            super().__init__(auto_prefix=auto_prefix)
            self.square = P.Square()
            self.value = Tensor(3, dtype=ms.float32)

        def construct(self, x, y):
            acc = self.value
            if x <= y:
                acc = x + acc
                acc = y + acc
            else:
                if x == y:
                    acc = self.square(self.value * y)
                else:
                    acc = self.square(self.value)
            return acc

    lhs = Tensor(1.0, dtype=ms.float32)
    rhs = Tensor(2.0, dtype=ms.float32)
    Net()(lhs, rhs)
640
641
def test_if_inside_for():
    """An ``if`` comparing a loop index with a Tensor input, inside a ``for range`` loop."""
    class Net(nn.Cell):
        def __init__(self, auto_prefix=True):
            super().__init__(auto_prefix=auto_prefix)
            self.square = P.Square()
            self.value = Tensor(3, dtype=ms.float32)
            self.count = 4

        def construct(self, x, y):
            acc = 0
            for step in range(self.count):
                if step == x:
                    acc = acc + x
                else:
                    acc = acc - y
            return acc

    first = Tensor(1, dtype=ms.int32)
    second = Tensor(1, dtype=ms.int32)
    Net()(first, second)
663
664
def test_while_in_while():
    """Nested ``while`` loops inside an ``ms_function`` compile and run."""
    c1, c2, c3, c4 = (Tensor(v, dtype=ms.int32) for v in (1, 2, 3, 4))

    @ms_function
    def while_in_while(x, y, z, u):
        out = c4
        while x < y:
            z = c4 + c4
            while z < y:
                z = z + 1
                out = out + 1
            x = x + 1

        out = out + 3
        return out

    while_in_while(c1, c2, c3, c4)
685
686
def test_tensor_cond():
    """Compile and run a net whose ``if`` conditions are a 0-d and a 1-element bool Tensor."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            # np.bool_ rather than the np.bool alias, which was deprecated in
            # NumPy 1.20 and removed in 1.24 (AttributeError).
            self.t = Tensor(np.array(0, np.bool_))
            self.t1 = Tensor(np.array([True], np.bool_))

        def construct(self, x, y):
            t = 0
            if self.t:
                t = t - x * y
            else:
                t = t - x / y
            if self.t1:
                t = t + x / y
            else:
                t = t + x * y
            return t

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.ones([6, 8, 10], np.int32))
    net = Net()
    net(x, y)
710
711
def test_tensor_cond_exception():
    """Using a multi-element bool Tensor as an ``if`` condition raises ValueError."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            # np.bool_ rather than the removed np.bool alias; with np.bool an
            # AttributeError would be raised here, outside pytest.raises below.
            self.t = Tensor(np.array([True, False], np.bool_))

        def construct(self, x, y):
            t = 0
            if self.t:
                t = t - x * y
            else:
                t = t - x / y
            return t

    x = Tensor(np.ones([6, 8, 10], np.int32))
    y = Tensor(np.ones([6, 8, 10], np.int32))
    net = Net()
    with pytest.raises(ValueError):
        net(x, y)
731
732
def test_while_scalar():
    """A ``while`` loop on a Python int counter accumulates ``x + y`` ten times."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.x = 10

        def construct(self, x, y):
            step = 0
            acc = 0
            while step < 10:
                acc = acc + x + y
                step = step + 1
            return acc

    net = Net()
    lhs = Tensor(np.ones([6, 8, 10], np.int32))
    rhs = Tensor(np.ones([6, 8, 10], np.int32))
    net(lhs, rhs)
751
752
def test_while_with_weight_in_condition():
    """Gradient of a ``while`` loop whose condition reads and updates a Parameter."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.loop = Parameter(Tensor(1, dtype=ms.float32), name="loop")

        def construct(self, x):
            while self.loop < 5:
                self.loop += 1
                x += 1
            return x

    grad_fn = grad_all(Net())
    grad_fn(Tensor(-1, dtype=ms.float32))
768
769
def test_mixed_precision_cast():
    """``F.mixed_precision_cast`` converts a float32 Tensor to float16."""
    source = Tensor(np.ones([2, 3], dtype=np.float32))
    casted = F.mixed_precision_cast(mstype.float16, source)
    assert casted.dtype == mstype.float16
774
775
def test_while_add():
    """A Tensor-bounded ``while`` loop that sums the first two slices of the input."""
    class Net(nn.Cell):
        def __init__(self, data):
            super(Net, self).__init__()
            self.start = Tensor(0, dtype=mstype.int32)
            self.end = Tensor(2, dtype=mstype.int32)
            self.out = Tensor(np.zeros([2, 3], dtype=np.float32))
            self.add = P.Add()

        def construct(self, inputs):
            pos = self.start
            stop = self.end
            total = self.out
            while pos < stop:
                piece = inputs[pos, :, :]
                total = self.add(total, piece)
                pos = pos + 1
            return total

    stacked = Tensor(np.arange(10 * 2 * 3).reshape(10, 2, 3).astype(np.float32))
    Net(stacked)(stacked)
798
799
def test_tensor_all_construct_lack_branch():
    """A ``construct`` with a path that falls off the end (implicit None) must raise."""
    class NetConditionLackBranch(nn.Cell):
        def __init__(self):
            super(NetConditionLackBranch, self).__init__()
            self.logicaland = P.LogicalAnd()
            self.logicalor = P.LogicalOr()

        def construct(self, input1, input2):
            if input1.all():
                return self.logicaland(input1, input2)
            while input1.any():
                return self.logicalor(input1, input2)
            # NOTICE: here missing return statement, default return None

    all_true = Tensor(np.random.choice([True], size=(2, 3, 4, 5)))
    mixed = Tensor(np.random.choice([True, False], size=(2, 3, 4, 5)))
    net = NetConditionLackBranch()
    with pytest.raises(Exception):
        net(all_true, mixed)
821
822
def test_parser_switch_layer_func_primitive():
    """Indexing a tuple of bare primitives (not Cells) with a Tensor raises ValueError."""
    class FinalNet(nn.Cell):
        def __init__(self, funcs):
            super().__init__()
            self.funcs = funcs

        def construct(self, i, input1):
            return self.funcs[i](input1)

    net = FinalNet((P.ReLU(), P.Softmax()))

    data = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
    selector = Tensor(1, mstype.int32)

    with pytest.raises(ValueError):
        net(selector, data)
843
844
def test_switch_layer_shape_join_failed():
    """Switch-layer branches with incompatible output shapes raise ValueError."""
    class AddFuncNet(nn.Cell):
        def __init__(self, funcs, new_func):
            super(AddFuncNet, self).__init__()
            self.funcs = funcs
            self.new_func = new_func

        def construct(self, i, inputs):
            final_funcs = self.funcs + (self.new_func,)
            return final_funcs[i](inputs)

    class ReLUTuple(nn.Cell):
        def __init__(self):
            super(ReLUTuple, self).__init__()
            self.op = nn.ReLU()

        def construct(self, x):
            # Only the first element is used, so this branch's output shape differs.
            return self.op(x[0])

    net = AddFuncNet((nn.Softmax(), nn.ReLU()), ReLUTuple())

    data = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
    selector = Tensor(1, mstype.int32)
    with pytest.raises(ValueError):
        net(selector, data)
877
878
def test_switch_layer_dtype_join_failed():
    """Switch-layer branches with incompatible output dtypes raise TypeError."""
    class Cast(nn.Cell):
        def __init__(self, dtype):
            super(Cast, self).__init__()
            self.op = P.Cast()
            self.dtype = dtype

        def construct(self, x):
            casted = self.op(x, self.dtype)
            return casted + casted

    class SwitchNegNet(nn.Cell):
        def __init__(self, funcs):
            super(SwitchNegNet, self).__init__()
            self.funcs = funcs
            self.op = P.Neg()

        def construct(self, i, inputs):
            selected = self.funcs[i](inputs)
            return self.op(selected)

    # Branch 0 keeps float32 while branch 1 casts to int32.
    net = SwitchNegNet((nn.ReLU(), Cast(mstype.int32)))

    data = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
    selector = Tensor(0, mstype.int32)

    with pytest.raises(TypeError):
        net(selector, data)
911
912
def test_large_for_loop():
    """An unrolled 1899-iteration ``for`` loop exceeds a call-depth limit of 60.

    The test changes global state (the ``ENV_RECURSIVE_EVAL`` environment variable and
    the ``max_call_depth`` context option). It restores both in a ``finally`` block so a
    failure here cannot leak the modified settings into later tests.
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.flatten = P.ReLU()  # nn.Flatten()

        def construct(self, x):
            for elem in range(1, 1900):
                x = self.flatten(x + elem)
            return x

    t = Tensor(np.ones([2, 3], dtype=np.float32))
    net = Net()
    os.environ['ENV_RECURSIVE_EVAL'] = '1'
    old_max_call_depth = context.get_context('max_call_depth')
    context.set_context(max_call_depth=60)
    try:
        with pytest.raises(RuntimeError) as err:
            net(t)
    finally:
        context.set_context(max_call_depth=old_max_call_depth)
        os.environ['ENV_RECURSIVE_EVAL'] = '0'
    assert 'Exceed function call depth limit 60' in str(err.value)
934
935
def test_large_for_loop_case2():
    """A 1500-iteration ``for`` loop of ExpandDims/Squeeze/Flatten exceeds a depth limit of 80.

    ``ENV_RECURSIVE_EVAL`` and ``max_call_depth`` are restored in a ``finally`` block so a
    failure here cannot leak the modified global settings into later tests.
    """
    class Menet(nn.Cell):
        def __init__(self, axis, flag_boottom, flag_top):
            super(Menet, self).__init__()
            self.squeeze = P.Squeeze(axis)
            self.expanddims = P.ExpandDims()
            self.flatten = nn.Flatten()
            self.neg = P.Neg()
            self.axis = axis
            self.flag_boottom = flag_boottom
            self.flag_top = flag_top

        def construct(self, x):
            if self.flag_boottom:
                x = self.neg(x)
            for i in range(0, 1500):
                x = self.expanddims(x, self.axis)
                x = self.squeeze(x)
                x = self.flatten(x)
            if self.flag_top:
                x = self.neg(x)
            return x

    x = Tensor(np.ones([2, 3], dtype=np.float32))
    net = Menet(axis=0, flag_boottom=True, flag_top=True)
    os.environ['ENV_RECURSIVE_EVAL'] = '1'
    old_max_call_depth = context.get_context('max_call_depth')
    context.set_context(max_call_depth=80)
    try:
        with pytest.raises(RuntimeError) as err:
            net(x)
    finally:
        os.environ['ENV_RECURSIVE_EVAL'] = '0'
        context.set_context(max_call_depth=old_max_call_depth)
    assert 'Exceed function call depth limit 80' in str(err.value)
969
970
def test_large_for_loop_with_continue_break():
    """A 200-iteration ``for`` loop with ``continue``/``break`` compiles under a depth limit of 2000.

    ``ENV_RECURSIVE_EVAL`` and ``max_call_depth`` are restored in a ``finally`` block so an
    unexpected compile/run error cannot leak the modified global settings into later tests.
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.flatten = P.ReLU()  # nn.Flatten()

        def construct(self, x):
            idx = 0
            for elem1 in range(200):
                idx = idx + 1
                if idx < 10:
                    x = x + 0.5
                    continue
                if idx > 500:
                    break
                x = self.flatten(x + elem1)
            return x

    os.environ['ENV_RECURSIVE_EVAL'] = '1'
    old_max_call_depth = context.get_context('max_call_depth')
    context.set_context(max_call_depth=2000)
    try:
        t = Tensor(np.ones([2, 3], dtype=np.float32))
        net = Net()
        net(t)
    finally:
        os.environ['ENV_RECURSIVE_EVAL'] = '0'
        context.set_context(max_call_depth=old_max_call_depth)
997
998
def test_recursive_call():
    """Mutually recursive cells (Net builds Net2, which holds a Net) raise RuntimeError.

    ``ENV_RECURSIVE_EVAL`` and ``max_call_depth`` are restored in a ``finally`` block so a
    failure here cannot leak the modified global settings into later tests.
    """
    class Net(nn.Cell):
        """ Net definition """

        def __init__(self):
            super(Net, self).__init__()
            self.fc = nn.Dense(10, 10)
            # Net2 is created inside construct rather than here: creating it in
            # __init__ would recurse forever in Python, since Net2.__init__ builds a Net.

        def construct(self, x):
            net2 = Net2()
            x = net2(x)
            out = self.fc(x)
            return out

    class Net2(nn.Cell):
        def __init__(self):
            super(Net2, self).__init__()
            self.net = Net()
            self.fc = nn.Dense(10, 10)

        def construct(self, x):
            x = self.net(x)
            out = self.fc(x)
            return out

    context.set_context(mode=context.GRAPH_MODE)
    os.environ['ENV_RECURSIVE_EVAL'] = '1'
    old_max_call_depth = context.get_context('max_call_depth')
    context.set_context(max_call_depth=80)
    try:
        input_data = Tensor(np.identity(10).astype(np.float32))
        net = Net2()
        with pytest.raises(RuntimeError):
            net(input_data)
    finally:
        os.environ['ENV_RECURSIVE_EVAL'] = '0'
        context.set_context(max_call_depth=old_max_call_depth)
1035
1036
# grad for Tensor(Bool) input and eliminate AddN(MakeTuple(Xs, zeros_like(Bool)))
def test_grad_tensor_bool():
    """Take gradients of a net whose ``while`` condition is a bool Tensor input."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x, y, z):
            out = z
            while x:
                out = out + z
                x = y
            return out

    # np.bool_ rather than the np.bool alias, which was removed in NumPy 1.24.
    x = Tensor(np.array(False).astype(np.bool_))
    y = Tensor(np.array(False).astype(np.bool_))
    z = Tensor(np.ones([2, 3], dtype=np.float32))
    net = grad_all(Net())
    net(x, y, z)
1055