• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15""" test control ops """
16import pytest
17import numpy as np
18from mindspore import dtype as ms
19from mindspore import Tensor
20from mindspore import context
21from mindspore import nn
22from mindspore import ms_function
23from mindspore.common.parameter import Parameter, ParameterTuple
24from mindspore.ops import composite as C
25from mindspore.ops import operations as P
26# from tests.vm_impl.math_ops_vm_impl import *
27# from tests.vm_impl.vm_interface import *
28# from tests.vm_impl import *
29
# Shared gradient transforms used by the gradient tests below.
# grad_by_list is called as grad_by_list(net, weights)(...) with
# weights = ParameterTuple(net.trainable_params()), i.e. it returns the
# gradients w.r.t. those parameters (get_by_list=True).
grad_by_list = C.GradOperation(get_by_list=True)
# grad_all is called as grad_all(net)(*inputs), i.e. it returns the gradients
# w.r.t. all of net's positional inputs (get_all=True).
grad_all = C.GradOperation(get_all=True)
32
33
@pytest.fixture(scope="module", autouse=True)
def setup_teardown():
    """Run every test in this module in PyNative mode with ``precompile_only=True``.

    ``precompile_only`` presumably makes MindSpore compile graphs without
    executing them (TODO confirm against the MindSpore context docs).

    The execution mode that was active before this module started is captured
    and restored on teardown, rather than being forced to ``GRAPH_MODE``. This
    avoids leaking a different mode into modules that run afterwards.
    ``precompile_only`` is reset to ``False`` on teardown, as before.
    """
    previous_mode = context.get_context("mode")
    context.set_context(mode=context.PYNATIVE_MODE, precompile_only=True)
    yield
    context.set_context(mode=previous_mode, precompile_only=False)
39
40
def test_while_with_param_forward_with_const_branch():
    """Forward pass of an ``ms_function`` construct containing a ``while`` loop
    over tensor bounds whose body has a constant-condition ``if 2 > 1`` branch.

    There is no assertion: the test passes if ``net(idx, end, x)`` raises no
    exception.
    """
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # max and reduce are not referenced by construct.
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()

        @ms_function
        def construct(self, idx, end, x):
            # x is accepted but unused; the else branch is dead because 2 > 1.
            out = self.zero
            while idx < end:
                if 2 > 1:
                    out = out + self.param
                else:
                    out = out + idx + self.param
                idx = idx + 1
            return out

    while_net = MyWhileNet()
    net = while_net
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
67
68
def test_while_opt_endless():
    """Gradient w.r.t. all inputs through a ``while`` loop that accumulates
    ``AddN`` results.

    Labelled an "endless during optimization case": presumably a regression
    test for graph optimization failing to terminate on this pattern. There is
    no assertion; the test passes if the call raises no exception.
    """
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # Only addn is referenced by construct; the other attributes are unused.
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()
            self.addn = P.AddN()

        def construct(self, idx, end, x):
            addn1 = self.addn((x, x, x))
            out = addn1
            while idx < end:
                out = self.addn((out, addn1))
                idx = idx + 1
            out = self.addn((out, x))
            return out

    class GradNet(nn.Cell):
        """Returns the gradients of ``net`` w.r.t. all of its inputs."""

        def __init__(self, net):
            super().__init__()
            self.net = net

        @ms_function
        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.ones([2, 2, 2]).astype(np.float32) * 3, dtype=ms.float32)
    net(idx, end, x)
104
105
def test_no_while_call():
    """Forward pass of an ``ms_function`` construct with a constant-condition
    ``if 2 > 1`` branch and no loop (despite the class name ``MyWhileNet``).

    There is no assertion: the test passes if the call raises no exception.
    """
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # max and reduce are not referenced by construct.
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()

        @ms_function
        def construct(self, idx, end, x):
            # end and x are accepted but unused; the else branch is dead because 2 > 1.
            out = self.zero
            if 2 > 1:
                out = out + self.param
            else:
                out = out + idx + self.param
            return out

    while_net = MyWhileNet()
    net = while_net
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
130
131
def test_while_with_param_grad_with_const_branch():
    """Gradient w.r.t. the trainable parameter through a ``while`` loop whose
    body has a constant-condition ``if 2 > 1`` branch.

    There is no assertion: the test passes if the call raises no exception.
    """
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # max and reduce are not referenced by construct.
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()

        def construct(self, idx, end, x):
            # x is accepted but unused; the else branch is dead because 2 > 1.
            out = self.zero
            while idx < end:
                if 2 > 1:
                    out = out + self.param
                else:
                    out = out + idx + self.param
                idx = idx + 1
            return out

    class GradNet(nn.Cell):
        """Returns the gradients of ``net`` w.r.t. its trainable parameters."""

        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
167
168
def test_for_while_with_param_grad_with_const_branch():
    """Parameter gradient through an outer ``for`` loop (2 iterations) wrapping
    an inner ``while`` loop whose body has a constant-condition branch.

    There is no assertion: the test passes if the call raises no exception.
    """
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # max and reduce are not referenced by construct.
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            # idx is reset to self.start on every outer iteration, so the
            # caller's idx value is never read; x is unused.
            out = self.zero
            for _ in range(0, 2):
                idx = self.start
                while idx < end:
                    if 2 > 1:
                        out = out + self.param
                    else:
                        out = out + idx + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        """Returns the gradients of ``net`` w.r.t. its trainable parameters."""

        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
207
208
def test_for_while_with_param_grad_basic():
    """Parameter gradient through an outer ``for`` loop (2 iterations) wrapping
    an inner ``while`` loop that adds the parameter on each step.

    There is no assertion: the test passes if the call raises no exception.
    """
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # max and reduce are not referenced by construct.
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            # idx is reset to self.start on every outer iteration; x is unused.
            out = self.zero
            for _ in range(0, 2):
                idx = self.start
                while idx < end:
                    out = out + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        """Returns the gradients of ``net`` w.r.t. its trainable parameters."""

        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
244
245
def test_for_while_with_param_grad_normal():
    """Parameter gradient through a ``for``/``while`` nest whose accumulator
    starts from the input ``x`` rather than a constant tensor.

    There is no assertion: the test passes if the call raises no exception.
    """
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # max, zero and reduce are not referenced by construct.
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.reduce = P.ReduceSum()
            self.start = Tensor(np.array(0), dtype=ms.int32)

        def construct(self, idx, end, x):
            # idx is reset to self.start on every outer iteration.
            out = x
            for _ in range(0, 2):
                idx = self.start
                while idx < end:
                    out = out + self.param
                    idx = idx + 1
            return out

    class GradNet(nn.Cell):
        """Returns the gradients of ``net`` w.r.t. its trainable parameters."""

        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
281
282
def test_while_with_param_basic_grad():
    """Parameter gradient through a ``while`` loop (3 iterations here) that adds
    the parameter each step, followed by one more use of the parameter.

    There is no assertion: the test passes if the call raises no exception.
    """
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # max and t2 are not referenced by construct.
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            # x is accepted but unused.
            out = self.zero
            while idx < end:
                out = out + self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        """Returns the gradients of ``net`` w.r.t. its trainable parameters."""

        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
315
316
def test_while_with_param_basic_grad_mul():
    """Parameter gradient through a ``while`` loop (3 iterations here) that
    multiplies the accumulator by the parameter on each step.

    There is no assertion: the test passes if the call raises no exception.
    """
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # max and t2 are not referenced by construct.
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            # Initial accumulator of ones (hence the name); presumably chosen
            # so the repeated multiplication does not collapse to zero.
            self.one = Tensor(np.ones(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            # x is accepted but unused.
            out = self.one
            while idx < end:
                out = out * self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        """Returns the gradients of ``net`` w.r.t. its trainable parameters."""

        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
349
350
def test_while_with_param_basic_grad_two():
    """Gradient w.r.t. two trainable parameters that are both added inside a
    ``while`` loop (3 iterations here).

    There is no assertion: the test passes if the call raises no exception.
    """
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # max and t2 are not referenced by construct.
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            # Attribute ``weight`` is registered under the parameter name "loss".
            self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            # x is accepted but unused.
            out = self.zero
            while idx < end:
                out = out + self.param + self.weight
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        """Returns the gradients of ``net`` w.r.t. its trainable parameters."""

        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
384
385
def test_while_with_param_basic_grad_three():
    """Gradient w.r.t. three trainable parameters that are all added inside a
    ``while`` loop (3 iterations here).

    There is no assertion: the test passes if the call raises no exception.
    """
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # max and t2 are not referenced by construct.
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            # Attribute ``weight`` is registered under the parameter name "loss".
            self.weight = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="loss")
            self.key = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="key")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            # x is accepted but unused.
            out = self.zero
            while idx < end:
                out = out + self.param + self.weight + self.key
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        """Returns the gradients of ``net`` w.r.t. its trainable parameters."""

        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
420
421
def test_while_if_with_param_grad():
    """Parameter gradient through a ``while`` loop whose body branches on a
    data-dependent tensor condition (``max(out) < max(x)``).

    There is no assertion: the test passes if the call raises no exception.
    """
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)
            # t2 is not referenced by construct.
            self.t2 = Tensor(np.array(2), dtype=ms.float32)

        def construct(self, idx, end, x):
            out = self.zero
            while idx < end:
                # Branch condition depends on loop-carried state, not a constant.
                if self.max(out) < self.max(x):
                    out = out + self.param * 2
                else:
                    out = out + self.param
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        """Returns the gradients of ``net`` w.r.t. its trainable parameters."""

        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(3), dtype=ms.int32)
    x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
457
458
def test_while_with_param_grad_not_enter_while():
    """Parameter gradient when the ``while`` loop body never executes.

    The inputs are idx=3 and end=0, so ``idx < end`` is false on entry. Only
    the final ``out + self.param`` contributes to the output. There is no
    assertion: the test passes if the call raises no exception.
    """
    class MyWhileNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # max is not referenced by construct.
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, idx, end, x):
            # x is accepted but unused.
            out = self.zero
            while idx < end:
                out = out + self.param * 3
                idx = idx + 1
            return out + self.param

    class GradNet(nn.Cell):
        """Returns the gradients of ``net`` w.r.t. its trainable parameters."""

        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, a, b, c):
            return grad_by_list(self.net, self.weights)(a, b, c)

    while_net = MyWhileNet()
    net = GradNet(while_net)
    idx = Tensor(np.array(3), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
490
491
def test_with_param_if_by_if_forward():
    """Forward pass of an ``ms_function`` construct with two sequential
    tensor-conditioned ``if``/``else`` blocks that both touch a parameter.

    With a=0 and b=4, the first ``if`` is taken and the second falls to
    ``else``. There is no assertion: the test passes if the call raises no
    exception.
    """
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # max is not referenced by construct.
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        @ms_function
        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param
            else:
                out = out + x
            if a == b:
                out = out + x*3 + self.param
            else:
                out = out + x*2
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(4), dtype=ms.int32)
    x = Tensor(np.ones([2, 2, 2]).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
519
520
def test_with_param_if_by_if_grad_inputs():
    """Gradient w.r.t. all inputs of a net with two sequential ``if`` blocks
    (no ``else``).

    With a=0 and b=0, the first ``if`` is skipped and the second is taken.
    There is no assertion: the test passes if the call raises no exception.
    """
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # max is not referenced by construct.
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 4
            if a == b:
                out = out + x*3 + self.param * 3
            return out

    class GradNet(nn.Cell):
        """Returns the gradients of ``net`` w.r.t. all of its inputs."""

        def __init__(self, net):
            super().__init__()
            self.net = net

        @ms_function
        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
552
553
def test_with_param_if_by_if_grad_parameter():
    """Gradient w.r.t. the trainable parameter of a net with two sequential
    ``if`` blocks (no ``else``).

    With a=0 and b=2, the first ``if`` is taken and the second is skipped.
    There is no assertion: the test passes if the call raises no exception.
    """
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # max is not referenced by construct.
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 2
            if a == b:
                out = out + x*3 + self.param
            return out

    class GradNet(nn.Cell):
        """Returns the gradients of ``net`` w.r.t. its trainable parameters."""

        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    idx = Tensor(np.array(0), dtype=ms.int32)
    end = Tensor(np.array(2), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
586
587
def test_with_param_if_by_if_grad_param_excute_null():
    """Parameter gradient when the only branch that uses the parameter does
    not execute.

    With a=4 and b=0, ``a < b`` is false, so the output is just ``self.zero``.
    There is no assertion: the test passes if the call raises no exception.
    ("excute" in the name is a typo for "execute"; it is kept so the pytest
    node ID does not change.)
    """
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # max is not referenced by construct.
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                out = out + x + self.param * 2
            return out

    class GradNet(nn.Cell):
        """Returns the gradients of ``net`` w.r.t. its trainable parameters."""

        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    idx = Tensor(np.array(4), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
618
619
def test_if_by_if_return_inside_grad():
    """Parameter gradient of a construct with early ``return`` statements
    inside sequential ``if`` blocks.

    With a=1 and b=0, neither ``if`` is taken and the final return
    ``out + self.param * 3`` is used. There is no assertion: the test passes
    if the call raises no exception.
    """
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # max is not referenced by construct.
            self.max = P.ReduceMax()
            self.param = Parameter(Tensor(np.arange(2 * 2 * 2).reshape((2, 2, 2)), ms.float32), name="weight")
            self.zero = Tensor(np.zeros(([2, 2, 2])), ms.float32)

        def construct(self, a, b, x):
            out = self.zero
            if a < b:
                return out + x + self.param
            if a == b:
                return out + self.param * 2
            return out + self.param * 3

    class GradNet(nn.Cell):
        """Returns the gradients of ``net`` w.r.t. its trainable parameters."""

        def __init__(self, net):
            super().__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        @ms_function
        def construct(self, *inputs):
            return grad_by_list(self.net, self.weights)(*inputs)

    if_net = MyIfByIfNet()
    net = GradNet(if_net)
    idx = Tensor(np.array(1), dtype=ms.int32)
    end = Tensor(np.array(0), dtype=ms.int32)
    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
    net(idx, end, x)
652
653
def test_if_by_if_forward():
    """Forward pass of an ``ms_function`` construct with three sequential
    tensor-conditioned ``if``/``else`` blocks on float32 scalars.

    Each branch reassigns ``a`` or ``b`` via Add/Sub/Mul/RealDiv primitives
    stored on the cell. There is no assertion: the test passes if the call
    raises no exception.
    """
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        @ms_function
        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            # This condition depends on the a produced by the previous if/else.
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            a = a * b
            out = a + b + x
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(4), dtype=ms.float32)
    net(idx, end, x)
687
688
def test_if_by_if_forward_control_tuple_switch():
    """Forward pass where an if-by-if chain is split across nested Cells and
    the innermost cell returns a tuple that the outer construct unpacks.

    Author's note: tuple_get from a switch op will generate a new switch
    inside to eliminate the tuple_get. There is no assertion: the test passes
    if the call raises no exception.
    """
    class Branch3Net(nn.Cell):
        """Innermost branch: conditionally updates b and returns (a, b, x)."""

        def __init__(self):
            super().__init__()
            # Only add is referenced by construct.
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            return a, b, x

    class Branch2Net(nn.Cell):
        """Middle branch: conditionally updates a, then delegates to Branch3Net."""

        def __init__(self):
            super().__init__()
            # Only mul and div (plus net) are referenced by construct.
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch3Net()

        def construct(self, a, b, x):
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            return self.net(a, b, x)

    class MyIfByIfNet(nn.Cell):
        """Outer ms_function cell: first if/else, then the nested branches."""

        def __init__(self):
            super().__init__()
            # Only add and sub (plus net) are referenced by construct.
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch2Net()

        @ms_function
        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            # Unpacks the tuple returned by Branch3Net via Branch2Net.
            a, b, x = self.net(a, b, x)
            a = a * b
            out = a + b + x
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)
748
749
def test_if_by_if_forward_control_inside_net():
    """Forward pass where an if-by-if chain is split across nested Cells and
    the innermost cell computes and returns the final output tensor.

    There is no assertion: the test passes if the call raises no exception.
    """
    class Branch3Net(nn.Cell):
        """Innermost branch: conditionally updates b and computes the output."""

        def __init__(self):
            super().__init__()
            # Only add is referenced by construct.
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        def construct(self, a, b, x):
            if b == x:
                b = self.add(a, b)
            else:
                b = self.add(a, x)
            a = a * b
            out = a + b + x
            return out

    class Branch2Net(nn.Cell):
        """Middle branch: conditionally updates a, then delegates to Branch3Net."""

        def __init__(self):
            super().__init__()
            # Only mul and div (plus net) are referenced by construct.
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch3Net()

        def construct(self, a, b, x):
            if a == x:
                a = self.mul(a, b)
            else:
                a = self.div(a, b)
            return self.net(a, b, x)

    class MyIfByIfNet(nn.Cell):
        """Outer ms_function cell: first if/else, then the nested branches."""

        def __init__(self):
            super().__init__()
            # Only add and sub (plus net) are referenced by construct.
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()
            self.net = Branch2Net()

        @ms_function
        def construct(self, a, b, x):
            if a < b:
                a = self.add(a, b)
            else:
                a = self.sub(a, b)
            out = self.net(a, b, x)
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)
808
809
def test_if_by_if_forward_use_namespace():
    """Same if-by-if chain as ``test_if_by_if_forward``, but each primitive is
    instantiated inline through the ``P`` namespace inside construct
    (e.g. ``P.Add()(a, b)``).

    There is no assertion: the test passes if the call raises no exception.
    """
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # None of these attributes are referenced by construct.
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        @ms_function
        def construct(self, a, b, x):
            if a < b:
                a = P.Add()(a, b)
            else:
                a = P.Sub()(a, b)
            if a == x:
                a = P.Mul()(a, b)
            else:
                a = P.RealDiv()(a, b)
            if b == x:
                b = P.Add()(a, b)
            else:
                b = P.Add()(a, x)
            a = a * b
            out = a + b + x
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)
843
844
def test_if_by_if_forward_use_global_op():
    """Same if-by-if chain as ``test_if_by_if_forward``, but the primitives are
    created once at the top of construct and bound to local names.

    Despite "global" in the test name, the ops are local variables of
    construct. There is no assertion: the test passes if the call raises no
    exception.
    """
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # None of these attributes are referenced by construct.
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        @ms_function
        def construct(self, a, b, x):
            add = P.Add()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if a < b:
                a = add(a, b)
            else:
                a = sub(a, b)
            if a == x:
                a = mul(a, b)
            else:
                a = div(a, b)
            if b == x:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)
882
883
def test_for_with_if_by_if_forward():
    """Forward pass of a ``for`` loop (4 iterations) whose body has a
    tensor-conditioned ``if``/``else`` that updates either ``a`` or ``b``.

    There is no assertion: the test passes if the call raises no exception.
    """
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.sub = P.Sub()

        @ms_function
        def construct(self, a, b, x):
            for _ in range(0, 4):
                if a < b:
                    a = self.add(a, b)
                else:
                    b = self.sub(b, x)
            a = a * b
            out = a + b + x
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)
908
909
def test_for_with_if_by_if_forward_namespace():
    """Forward pass of a ``for`` loop (6 iterations) with an ``if``/``else``
    body whose primitives are instantiated inline via the ``P`` namespace.

    There is no assertion: the test passes if the call raises no exception.
    """
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # None of these attributes are referenced by construct.
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        @ms_function
        def construct(self, a, b, x):
            for _ in range(0, 6):
                if a < b:
                    a = P.Add()(a, b)
                else:
                    b = P.Sub()(b, x)
            a = a * b
            out = a + b + x
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)
936
937
def test_if_by_if_forward_const_branch_inner():
    """If-by-if chain where the middle ``if`` has a constant condition
    (``2 > 1``) between two tensor-conditioned ``if``/``else`` blocks.

    There is no assertion: the test passes if the call raises no exception.
    """
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # None of these attributes are referenced by construct.
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        @ms_function
        def construct(self, a, b, x):
            add = P.Add()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            if a < b:
                a = add(a, b)
            else:
                a = sub(a, b)
            # Constant condition: the mul branch is always taken.
            if 2 > 1:
                a = mul(a, b)
            else:
                a = div(a, b)
            if b == x:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)
975
976
def test_if_by_if_forward_all_const_branch():
    """If-by-if chain where every ``if`` has a constant condition, so each
    branch choice is fixed at compile time.

    The add, mul, and ``b = add(a, x)`` branches are the ones taken. There is
    no assertion: the test passes if the call raises no exception.
    """
    class MyIfByIfNet(nn.Cell):
        def __init__(self):
            super().__init__()
            # None of these attributes are referenced by construct.
            self.add = P.Add()
            self.sub = P.Sub()
            self.mul = P.Mul()
            self.div = P.RealDiv()

        @ms_function
        def construct(self, a, b, x):
            add = P.Add()
            sub = P.Sub()
            mul = P.Mul()
            div = P.RealDiv()
            # Always true.
            if 2 < 12:
                a = add(a, b)
            else:
                a = sub(a, b)
            # Always true.
            if 2 > 1:
                a = mul(a, b)
            else:
                a = div(a, b)
            # Always false: the else branch is taken.
            if 2 == 1:
                b = add(a, b)
            else:
                b = add(a, x)
            a = a * b
            out = a + b + x
            return out

    if_net = MyIfByIfNet()
    net = if_net
    idx = Tensor(np.array(2), dtype=ms.float32)
    end = Tensor(np.array(3), dtype=ms.float32)
    x = Tensor(np.array(0), dtype=ms.float32)
    net(idx, end, x)
1014