• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.common.parameter import Parameter
from mindspore.nn import Cell
import mindspore.ops.operations as P
8
# Module-level configuration: select graph mode with Ascend as the device
# target before any network in this file is built or executed.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
10
11
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_by_if_basic():
    """Run a network whose ``construct`` has two consecutive ``if``/``else``.

    The first ``if`` calls a sub-cell that itself contains an ``if`` with a
    nested ``while``; the second ``if`` calls a helper method containing a
    ``for`` loop.  With the parameter values set below, the only arithmetic
    that changes the input is two doublings inside the ``while`` loop, so
    the output is compared against ``input * 4``.
    """

    class SubNet(Cell):
        """Sub-cell holding an ``if`` with a nested ``while`` loop."""

        def __init__(self):
            super().__init__()
            self.mul = P.Mul()
            self.add = P.Add()
            # Comparison operands for the conditions in construct: a = 5, b = 4.
            a = np.full((1,), 5, dtype=np.float32)
            self.a = Parameter(Tensor(a), name='a')
            b = np.full((1,), 4, dtype=np.float32)
            self.b = Parameter(Tensor(b), name='b')

        def construct(self, x):
            # Initially a (5) > b (4), so this branch is taken.
            if self.a > self.b:
                x = self.mul(x, 1)
                # b goes 4 -> 5 -> 6: two iterations, each doubling x, so x
                # ends up multiplied by 4.
                # NOTE: the loop updates self.b, so the starting value of b
                # is only 4 on the first call of a given instance.
                while self.b < 6:
                    x = self.add(x, x)
                    self.b += 1
            return x

    class Net(Cell):
        """Top-level cell with two back-to-back ``if``/``else`` statements."""

        def __init__(self):
            super().__init__()
            self.subnet = SubNet()
            self.relu = P.ReLU()
            self.add = P.Add()
            # Comparison operands for the conditions in construct:
            # a = 5, b = 4, c = 7.
            a = np.full((1,), 5, dtype=np.float32)
            self.a = Parameter(Tensor(a), name='a')
            b = np.full((1,), 4, dtype=np.float32)
            self.b = Parameter(Tensor(b), name='b')
            c = np.full((1,), 7, dtype=np.float32)
            self.c = Parameter(Tensor(c), name='c')

        def func(self, x):
            # Adds 0 twice; leaves the value of x unchanged.
            for _ in range(0, 2):
                x = self.add(x, 0)
            return x

        def construct(self, x):
            # First if: a (5) > b (4), so the sub-cell branch is taken.
            if self.a > self.b:
                x = self.subnet(x)
            else:
                x = self.relu(x)
            # Second if: a (5) < c (7), so the helper-method branch is taken.
            if self.a < self.c:
                x = self.func(x)
            else:
                x = self.add(x, 2)
            return x

    input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
    net = Net()
    out_ms = net(Tensor(input_np))
    # Expected result: the two doublings in SubNet's while loop.
    out_np = input_np * 4
    assert np.allclose(out_ms.asnumpy(), out_np)
69