# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import pytest
from mindspore import context
from mindspore import Tensor, nn
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype

grad_all = C.GradOperation(get_all=True)
context.set_context(device_target="Ascend")


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_single_for_01():
    """Compare a single-``for``-loop network across GRAPH and PYNATIVE modes.

    The same network is run forward and through ``grad_all`` (gradients
    w.r.t. all inputs) in GRAPH_MODE and then in PYNATIVE_MODE. The test
    asserts that both modes agree.

    ``ENV_FOR_TO_WHILE_LOOP`` is set to ``'1'`` only while the networks run.
    Its previous state is restored in a ``finally`` clause, so a failure
    inside a network call cannot leak the setting into later tests.
    """
    class SingleForNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = P.Add()
            self.mul = P.Mul()

        def construct(self, x, y, z):
            x = self.add(x, y)
            for _ in range(0, 3):
                z = self.add(z, x)
            # NOTE(review): source indentation was lost; the multiply is
            # assumed to run once after the loop -- confirm against upstream.
            y = self.mul(z, y)
            return y

    class GradNet(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.net = net

        def construct(self, *inputs):
            return grad_all(self.net)(*inputs)

    x = Tensor([2], mstype.int32)
    y = Tensor([5], mstype.int32)
    z = Tensor([4], mstype.int32)

    env_key = 'ENV_FOR_TO_WHILE_LOOP'
    previous_value = os.environ.get(env_key)
    os.environ[env_key] = '1'
    try:
        # graph mode
        context.set_context(mode=context.GRAPH_MODE)
        for_net = SingleForNet()
        net = GradNet(for_net)
        graph_forward_res = for_net(x, y, z)
        graph_backward_res = net(x, y, z)

        # pynative mode
        context.set_context(mode=context.PYNATIVE_MODE)
        for_net = SingleForNet()
        net = GradNet(for_net)
        pynative_forward_res = for_net(x, y, z)
        pynative_backward_res = net(x, y, z)
    finally:
        # Put the environment back exactly as we found it (unset or its
        # prior value) instead of leaving an empty string behind.
        if previous_value is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous_value

    assert graph_forward_res == pynative_forward_res
    assert graph_backward_res == pynative_backward_res