• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15import pytest
16from mindspore import context
17from mindspore import Tensor, nn
18from mindspore.ops import composite as C
19from mindspore.common import dtype as mstype
20
# Gradient operator built with get_all=True; GradNet applies it to a wrapped net,
# and the tests below expect a tuple with one gradient per input (x and y).
grad_all = C.GradOperation(get_all=True)
22
23
class SingleIfNet(nn.Cell):
    """Net whose ``construct`` contains a single if/else written inline.

    The statement shape (augmented assignments around one if/else) is kept
    deliberately simple; ``control_flow_single_if`` runs it in GRAPH_MODE.
    """

    def construct(self, x, y):
        # x is incremented before the comparison, so the branch tests x + 1 < y.
        x += 1
        if x < y:
            y += x
        else:
            y -= x
        # Applied after the branches join, regardless of which one ran.
        y += 5
        return y
33
34
class SingleIfNet1(nn.Cell):
    """Single if/else placed in a helper method instead of inline.

    ``construct`` increments ``x``, delegates the branch to ``func`` and
    doubles the result, so the if/else sits in a method called from
    ``construct`` rather than in ``construct`` itself.
    """

    def construct(self, x, y):
        x += 1
        # The if/else lives in self.func; its result is scaled by 2 here.
        out = self.func(x, y)
        out *= 2
        return out

    def func(self, x, y):
        # Branch on x < y (x has already been incremented by construct).
        if x < y:
            y += x
        else:
            y -= x
        # Applied after the branches join, regardless of which one ran.
        y += 5
        return y
49
50
class GradNet(nn.Cell):
    """Wrap a cell so that calling the wrapper returns gradients of the cell.

    Gradients come from the module-level ``grad_all`` operator
    (``GradOperation(get_all=True)``) applied to the wrapped cell.
    """

    def __init__(self, net):
        super().__init__()
        self.net = net

    def construct(self, *inputs):
        # Build the gradient function for the wrapped cell, then evaluate it
        # on the same positional inputs the cell itself would receive.
        gradient_fn = grad_all(self.net)
        return gradient_fn(*inputs)
58
59
def control_flow_single_if(input_net, x, y, expect1, expect2):
    """Run ``input_net`` forward and backward in graph mode and check both.

    Args:
        input_net: Cell class (not an instance). Two fresh instances are
            built: one wrapped in ``GradNet`` and one for the forward pass.
        x: First input tensor, passed to both passes.
        y: Second input tensor, passed to both passes.
        expect1: Expected forward output.
        expect2: Expected gradient tuple returned by ``GradNet`` (the callers
            in this file pass one gradient per input).
    """
    context.set_context(mode=context.GRAPH_MODE)

    # Separate instances: one wrapped for gradients, one for the plain forward pass.
    grad_net = GradNet(input_net())
    forward_net = input_net()

    forward_out = forward_net(x, y)
    grads = grad_net(x, y)

    assert forward_out == expect1
    assert grads == expect2
72
73
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_single_if():
    """Forward/backward check for SingleIfNet (if/else inline in construct)."""
    x = Tensor(2, mstype.int32)
    y = Tensor(5, mstype.int32)
    # x + 1 = 3 < 5, so the true branch runs: y = 5 + 3, then + 5 -> 13.
    expect1 = Tensor(13, mstype.int32)
    # out = y + (x + 1) + 5, so d(out)/dx = d(out)/dy = 1.
    expect2 = (Tensor(1, mstype.int32), Tensor(1, mstype.int32))
    control_flow_single_if(SingleIfNet, x, y, expect1, expect2)
85
86
@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_single_if_01():
    """Forward/backward check for SingleIfNet1 (if/else inside a helper method)."""
    input_x, input_y = Tensor(2, mstype.int32), Tensor(5, mstype.int32)
    # (5 + 3 + 5) * 2 = 26; the final doubling makes both gradients 2.
    expected_out = Tensor(26, mstype.int32)
    expected_grads = (Tensor(2, mstype.int32), Tensor(2, mstype.int32))
    control_flow_single_if(SingleIfNet1, input_x, input_y, expected_out, expected_grads)
98