• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2024 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16import numpy as np
17import mindspore as ms
18from mindspore import ops, nn, Tensor, Symbol
19from mindspore.ops import functional as F
20from mindspore.common.api import jit
21import pytest
22
23
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_symbol_graphmode_setinputs():
    """
    Feature: Symbol
    Description: graphmode, set symbolic info with cell.set_inputs
    Expectation: success
    """
    class Net(nn.Cell):
        """Cell that forwards its two inputs to ops.Add."""

        def __init__(self):
            super(Net, self).__init__()
            self.add = ops.Add()

        def construct(self, x, y):
            return self.add(x, y)

    ms.set_context(mode=ms.GRAPH_MODE, device_target="CPU")

    s1 = Symbol(max=16, divisor=8)  # the value can be 8, 16
    s2 = Symbol(unique=True)
    x_dyn = Tensor(shape=[s1, s2, s2], dtype=ms.float32)
    y_dyn = Tensor(shape=[s1, s1, s2], dtype=ms.float32)
    net = Net()
    net.set_inputs(x_dyn, y_dyn)

    # Inputs are built outside the pytest.raises blocks so that only the
    # network call itself can satisfy the expected ValueError; an error while
    # constructing a tensor must fail the test instead of passing it.
    x = Tensor(np.ones((32, 32, 32), np.float32))
    with pytest.raises(ValueError):
        net(x, x)  # s1 > max

    x = Tensor(np.ones((16, 8, 8), np.float32))
    y = Tensor(np.ones((16, 8, 1), np.float32))
    with pytest.raises(ValueError):
        net(x, y)  # s2 is unique, but y.shape[2] != x.shape[2]

    x = Tensor(np.ones((10, 8, 8), np.float32))
    with pytest.raises(ValueError):
        net(x, x)  # s1.divisor = 8, but x.shape[0] == 10

    # Shapes that satisfy every symbolic constraint must run normally.
    x = Tensor(np.ones((16, 8, 8), np.float32))
    assert net(x, x).shape == (16, 8, 8)
64
65
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_symbol_pynativemode_setinputs():
    """
    Feature: Symbol
    Description: pynativemode, set symbolic info with cell.set_inputs
    Expectation: success
    """
    class Net(nn.Cell):
        """Cell whose jit-decorated construct forwards both inputs to ops.Add."""

        def __init__(self):
            super(Net, self).__init__()
            self.add = ops.Add()

        @jit
        def construct(self, x, y):
            return self.add(x, y)

    ms.set_context(mode=ms.PYNATIVE_MODE, device_target="CPU")

    s1 = Symbol(max=16, divisor=8)  # the value can be 8, 16
    s2 = Symbol(min=4, unique=True)
    x_dyn = Tensor(shape=[s1, s2], dtype=ms.float32)
    net = Net()
    net.set_inputs(x_dyn, x_dyn)

    # The input is built outside the pytest.raises block so that only the
    # network call itself can satisfy the expected ValueError.
    x = Tensor(np.ones((8, 1), np.float32))
    with pytest.raises(ValueError):
        net(x, x)  # s2.min = 4, but x.shape[1] == 1

    # A shape that satisfies every symbolic constraint must run normally.
    x = Tensor(np.ones((16, 8), np.float32))
    assert net(x, x).shape == (16, 8)
98
99
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_symbol_pynativemode_signature():
    """
    Feature: Symbol
    Description: pynativemode, set symbolic info with input_signature
    Expectation: success
    """
    s1 = Symbol(max=16, unique=True)
    s2 = Symbol(min=4, unique=True)
    x_dyn = Tensor(shape=[s1, s1], dtype=ms.float32)
    y_dyn = Tensor(shape=[s2, s2], dtype=ms.float32)

    @jit(input_signature=(x_dyn, y_dyn))
    def add_func(x, y):
        return F.tensor_add(x, y)

    ms.set_context(mode=ms.PYNATIVE_MODE, device_target="CPU")

    # Inputs are built outside the pytest.raises block so that only the
    # jit-compiled call itself can satisfy the expected ValueError.
    x = Tensor(np.ones((1, 1), np.float32))
    y = Tensor(np.ones((4, 8), np.float32))
    with pytest.raises(ValueError):
        add_func(x, y)  # s2 is unique, but y.shape[0] != y.shape[1]

    # Same x, now paired with a y whose two dimensions agree.
    y = Tensor(np.ones((4, 4), np.float32))
    assert add_func(x, y).shape == (4, 4)
127