• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15import pytest
16from mindspore.common import dtype as mstype
17from mindspore import nn
18from mindspore import Tensor
19from mindspore.ops import composite as C
20from mindspore import context
21
# Select MindSpore GRAPH_MODE for everything in this module. This runs once at
# import time, before any test below executes. Presumably this is what routes
# ForwardNet's data-dependent while loop through graph compilation — confirm
# against the MindSpore version in use.
context.set_context(mode=context.GRAPH_MODE)
23
24
class ForwardNet(nn.Cell):
    """Network whose output depends on a value-dependent ``while`` loop.

    ``construct(x, y)`` first shifts ``y`` by 10. While ``x < y``, it sets
    ``x = (x + 2) * (y - 9)`` and advances ``y`` by 2. On exit it returns
    ``x + 5``. For ``x = 0, y = 0`` the loop body runs three times
    (x becomes 2, 12, 70) and the result is 75.

    The loop is kept deliberately in this exact form: it is the control-flow
    pattern this test file exercises.
    """

    def construct(self, x, y):
        # Offset y so the loop condition x < y holds on entry for (0, 0).
        y = y + 10
        # Iteration count depends on tensor values, not on a static bound.
        while x < y:
            # x is rebound from both loop-carried values; y grows by 2 each pass.
            x = (x + 2) * (y - 9)
            y = y + 2
        # Post-loop step applied to the final x.
        x = x + 5
        return x
33
34
class BackwardNet(nn.Cell):
    """Wrap a forward cell and return its gradient for the given inputs.

    The gradient operator is ``C.GradOperation()`` with default arguments.
    Presumably (default ``get_all=False``) that differentiates with respect to
    the first input only — confirm against the MindSpore API docs.

    Args:
        forward_net: The cell to differentiate. It is stored as
            ``self.forward_net``.
    """

    def __init__(self, forward_net):
        super().__init__()
        self.forward_net = forward_net
        self.grad = C.GradOperation()

    def construct(self, *inputs):
        # Build the gradient function for the wrapped cell, then apply it to
        # the same positional inputs the forward cell would receive.
        return self.grad(self.forward_net)(*inputs)
44
45
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_forward():
    """Run ForwardNet on (0, 0) and check the int32 result is 75.

    Trace: y becomes 10. The loop runs three times (x: 2, 12, 70; y: 12, 14,
    16), then 5 is added to the final x.
    """
    net = ForwardNet()
    x = Tensor([0], mstype.int32)
    y = Tensor([0], mstype.int32)
    expected = Tensor([75], mstype.int32)
    result = net(x, y)
    assert expected == result
58
59
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_backward():
    """Check the gradient of ForwardNet through its while loop.

    Previously this test duplicated ``test_forward`` and never exercised
    ``BackwardNet``. It now differentiates ForwardNet with respect to its
    first input.

    For x = 0, y = 0 the loop body runs three times. The (y - 9) factors are
    1, 3 and 5, so d(output)/dx = 1 * 3 * 5 = 15. The trailing ``+ 5`` and
    the loop condition do not contribute to the gradient.
    """
    # float32 inputs: the test differentiates through the loop, and integer
    # tensors are presumably not a meaningful gradient target — confirm
    # against the MindSpore version in use.
    x = Tensor([0], mstype.float32)
    y = Tensor([0], mstype.float32)
    expect = Tensor([15], mstype.float32)
    backward_net = BackwardNet(ForwardNet())
    grads = backward_net(x, y)
    assert expect == grads
72