# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
from mindspore import context
from mindspore import Tensor, nn
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype

# Gradient operation that returns the gradients w.r.t. all network inputs.
grad_all = C.GradOperation(get_all=True)


class SingleIfNet(nn.Cell):
    """Network with a single if/else branch directly inside ``construct``.

    ``x`` is incremented by 1. Then ``y`` is increased by ``x`` when
    ``x < y`` and decreased by ``x`` otherwise. Finally 5 is added to ``y``,
    which is returned.
    """

    def construct(self, x, y):
        x += 1
        if x < y:
            y += x
        else:
            y -= x
        y += 5
        return y


class SingleIfNet1(nn.Cell):
    """Variant of ``SingleIfNet`` whose branch lives in a helper method.

    ``construct`` increments ``x``, delegates the if/else to ``func`` and
    doubles the result. This exercises control flow reached through a
    method call.
    """

    def construct(self, x, y):
        x += 1
        out = self.func(x, y)
        out *= 2
        return out

    def func(self, x, y):
        # Same branch logic as SingleIfNet.construct, without the x increment.
        if x < y:
            y += x
        else:
            y -= x
        y += 5
        return y


class GradNet(nn.Cell):
    """Wrap ``net`` so that calling this cell returns grads for all inputs."""

    def __init__(self, net):
        super().__init__()
        self.net = net

    def construct(self, *inputs):
        return grad_all(self.net)(*inputs)


def control_flow_single_if(input_net, x, y, expect1, expect2):
    """Run ``input_net`` forward and backward in graph mode and check results.

    Args:
        input_net: An ``nn.Cell`` subclass (not an instance) whose
            ``construct`` takes ``(x, y)``.
        x: First input tensor.
        y: Second input tensor.
        expect1: Expected forward output.
        expect2: Expected tuple of gradients w.r.t. ``(x, y)``.
    """
    # graph mode
    context.set_context(mode=context.GRAPH_MODE)
    net = input_net()
    grad_net = GradNet(net)

    # A separate cell instance is used for the forward pass.
    forward_net = input_net()
    graph_forward_res = forward_net(x, y)
    graph_backward_res = grad_net(x, y)

    assert graph_forward_res == expect1
    assert graph_backward_res == expect2


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_single_if():
    """Single if/else directly inside ``construct`` (``SingleIfNet``)."""
    x = Tensor(2, mstype.int32)
    y = Tensor(5, mstype.int32)
    # x -> 3; 3 < 5 so y -> 5 + 3 = 8; then y -> 8 + 5 = 13.
    expect1 = Tensor(13, mstype.int32)
    # On this branch out = y + (x + 1) + 5, so d/dx = d/dy = 1.
    expect2 = (Tensor(1, mstype.int32), Tensor(1, mstype.int32))
    control_flow_single_if(SingleIfNet, x, y, expect1, expect2)


@pytest.mark.level1
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_single_if_01():
    """Single if/else inside a helper method (``SingleIfNet1``)."""
    x = Tensor(2, mstype.int32)
    y = Tensor(5, mstype.int32)
    # x -> 3; func gives 5 + 3 + 5 = 13; doubled -> 26.
    expect1 = Tensor(26, mstype.int32)
    # On this branch out = 2 * (y + (x + 1) + 5), so d/dx = d/dy = 2.
    expect2 = (Tensor(2, mstype.int32), Tensor(2, mstype.int32))
    control_flow_single_if(SingleIfNet1, x, y, expect1, expect2)