"""Graph-mode control-flow test: two sequential ``if`` statements ("if by if").

The first ``if`` calls a sub-Cell that itself contains an ``if`` and a
``while`` loop updating a Parameter; the second ``if`` calls a helper
containing a ``for`` loop. Requires MindSpore with an Ascend device
(see ``context.set_context`` below).
"""
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.common.parameter import Parameter
from mindspore.nn import Cell
import mindspore.ops.operations as P

# Module-level side effect: everything in this file runs in GRAPH_MODE with
# device_target="Ascend".
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


# pytest markers (test level, target platforms, environment) used for test selection.
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_if_by_if_basic():
    """Check that ``Net`` computes ``4 * x`` through two sequential ``if``s.

    With the Parameter values set in the ``__init__`` methods (a=5, b=4, c=7):

    * the first ``if`` in ``Net.construct`` is true (5 > 4), so ``SubNet``
      runs: it multiplies by 1, then its ``while`` loop doubles ``x`` while
      its own ``b`` goes 4 -> 5 -> 6, giving ``4 * x``;
    * the second ``if`` is true (5 < 7), so ``func`` adds 0 twice, leaving
      the value unchanged.

    The expected result is therefore ``input_np * 4``.
    """
    class SubNet(Cell):
        """Sub-cell with an ``if`` guarding a ``while`` loop that updates Parameter ``b``."""

        def __init__(self):
            super().__init__()
            self.mul = P.Mul()
            self.add = P.Add()
            # Shape-(1,) float32 Parameters, used only in the control-flow
            # conditions of construct (b is also incremented there).
            a = np.full((1,), 5, dtype=np.float32)
            self.a = Parameter(Tensor(a), name='a')
            b = np.full((1,), 4, dtype=np.float32)
            self.b = Parameter(Tensor(b), name='b')

        def construct(self, x):
            # Taken with the initial values (5 > 4); x * 1 keeps the values unchanged.
            if self.a > self.b:
                x = self.mul(x, 1)
                # Iterates for b = 4 and b = 5, doubling x each time (x -> 4 * x).
                # `self.b += 1` writes back to the Parameter, so a second call on
                # the same instance would presumably start at b == 6 and skip the
                # loop. NOTE(review): relies on graph-mode Parameter update
                # semantics; this test calls the net only once.
                while self.b < 6:
                    x = self.add(x, x)
                    self.b += 1
            return x

    class Net(Cell):
        """Top-level cell: two sequential ``if`` statements, the first calling ``SubNet``."""

        def __init__(self):
            super().__init__()
            self.subnet = SubNet()
            self.relu = P.ReLU()
            self.add = P.Add()
            # Net's own Parameters (separate from SubNet's a/b); they are only
            # read in construct, never updated.
            a = np.full((1,), 5, dtype=np.float32)
            self.a = Parameter(Tensor(a), name='a')
            b = np.full((1,), 4, dtype=np.float32)
            self.b = Parameter(Tensor(b), name='b')
            c = np.full((1,), 7, dtype=np.float32)
            self.c = Parameter(Tensor(c), name='c')

        def func(self, x):
            # Adds 0 on each of the two iterations of range(0, 2), so the
            # returned tensor has the same values as x.
            for _ in range(0, 2):
                x = self.add(x, 0)
            return x

        def construct(self, x):
            # First `if`: 5 > 4, so SubNet runs (-> 4 * x); the ReLU branch is
            # not taken in this test.
            if self.a > self.b:
                x = self.subnet(x)
            else:
                x = self.relu(x)
            # Second `if`, sequential with the first: 5 < 7, so func runs and
            # leaves the values unchanged; the `+ 2` branch is not taken.
            if self.a < self.c:
                x = self.func(x)
            else:
                x = self.add(x, 2)
            return x

    input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
    net = Net()
    out_ms = net(Tensor(input_np))
    # x * 1, x + x and x + 0 are exact in float32, so the comparison against
    # input_np * 4 holds for any (unseeded) random input.
    out_np = input_np * 4
    assert np.allclose(out_ms.asnumpy(), out_np)