# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for declaring symbolic (dynamic) input shapes with ``mindspore.Symbol``.

Each test describes a dynamic-shape input with ``Symbol`` objects, checks that
concrete inputs violating a symbol constraint are rejected with ``ValueError``,
and checks that a conforming input runs and yields the expected output shape.
"""

import numpy as np
import mindspore as ms
from mindspore import ops, nn, Tensor, Symbol
from mindspore.ops import functional as F
from mindspore.common.api import jit
import pytest


@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_symbol_graphmode_setinputs():
    """
    Feature: Symbol
    Description: graphmode, set symbolic info with cell.set_inputs
    Expectation: success
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = ops.Add()

        def construct(self, x, y):
            return self.add(x, y)

    ms.set_context(mode=ms.GRAPH_MODE, device_target="CPU")

    s1 = Symbol(max=16, divisor=8)  # the value can be 8, 16
    s2 = Symbol(unique=True)  # every occurrence of s2 must take the same value
    x_dyn = Tensor(shape=[s1, s2, s2], dtype=ms.float32)
    y_dyn = Tensor(shape=[s1, s1, s2], dtype=ms.float32)
    net = Net()
    net.set_inputs(x_dyn, y_dyn)

    # Inputs are built outside pytest.raises so that only the network call
    # is expected to raise.
    x = Tensor(np.ones((32, 32, 32), np.float32))
    with pytest.raises(ValueError):
        net(x, x)  # s1 == 32 > max (16)

    x = Tensor(np.ones((16, 8, 8), np.float32))
    y = Tensor(np.ones((16, 8, 1), np.float32))
    with pytest.raises(ValueError):
        net(x, y)  # s2 is unique, but y.shape[2] != x.shape[2]

    x = Tensor(np.ones((10, 8, 8), np.float32))
    with pytest.raises(ValueError):
        net(x, x)  # s1.divisor = 8, but x.shape[0] == 10

    # s1 is not unique, so its two occurrences in y_dyn may differ
    # (16 and 8 here); s2 is 8 everywhere.
    x = Tensor(np.ones((16, 8, 8), np.float32))
    assert net(x, x).shape == (16, 8, 8)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_symbol_pynativemode_setinputs():
    """
    Feature: Symbol
    Description: pynativemode, set symbolic info with cell.set_inputs
    Expectation: success
    """
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.add = ops.Add()

        @jit
        def construct(self, x, y):
            return self.add(x, y)

    ms.set_context(mode=ms.PYNATIVE_MODE, device_target="CPU")

    s1 = Symbol(max=16, divisor=8)  # the value can be 8, 16
    s2 = Symbol(min=4, unique=True)
    x_dyn = Tensor(shape=[s1, s2], dtype=ms.float32)
    net = Net()
    # The same description is used for both inputs.
    net.set_inputs(x_dyn, x_dyn)

    x = Tensor(np.ones((8, 1), np.float32))
    with pytest.raises(ValueError):
        net(x, x)  # s2.min = 4, but x.shape[1] == 1

    x = Tensor(np.ones((16, 8), np.float32))
    assert net(x, x).shape == (16, 8)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_symbol_pynativemode_signature():
    """
    Feature: Symbol
    Description: pynativemode, set symbolic info with input_signature
    Expectation: success
    """
    ms.set_context(mode=ms.PYNATIVE_MODE, device_target="CPU")

    s1 = Symbol(max=16, unique=True)
    s2 = Symbol(min=4, unique=True)
    x_dyn = Tensor(shape=[s1, s1], dtype=ms.float32)
    y_dyn = Tensor(shape=[s2, s2], dtype=ms.float32)

    @jit(input_signature=(x_dyn, y_dyn))
    def add_func(x, y):
        return F.tensor_add(x, y)

    x = Tensor(np.ones((1, 1), np.float32))
    y = Tensor(np.ones((4, 8), np.float32))
    with pytest.raises(ValueError):
        add_func(x, y)  # s2 is unique, but y.shape[0] != y.shape[1]

    # x (1, 1) broadcasts against y (4, 4), per the asserted output shape.
    x = Tensor(np.ones((1, 1), np.float32))
    y = Tensor(np.ones((4, 4), np.float32))
    assert add_func(x, y).shape == (4, 4)