• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16import pytest
17import numpy as np
18from mindspore import context, nn, Tensor
19from mindspore import dtype as mstype
20from mindspore.ops import composite as C
21from mindspore.ops import operations as P
22
# Every test in this module runs with graph (static) compilation enabled.
context.set_context(mode=context.GRAPH_MODE)

# Single-argument MultitypeFuncGraph; HyperMapNet maps it over each tensor of
# the one input sequence it is given.
single_element_fg = C.MultitypeFuncGraph("single_element_fg")


# Tensor overload: element-wise square of `x`.
@single_element_fg.register("Tensor")
def single_element_fg_for_tensor(x):
    square_op = P.Square()
    return square_op(x)
29
# Two-argument MultitypeFuncGraph; the overload is selected by the type of the
# second argument (tuple vs. list).
double_elements_fg = C.MultitypeFuncGraph("double_elements_fg")


# (Tensor, Tuple) overload: the tuple is used as the Tile multiples for `x`.
@double_elements_fg.register("Tensor", "Tuple")
def double_elements_fg_for_tensor_tuple(x, y):
    tile_op = P.Tile()
    return tile_op(x, y)


# (Tensor, List) overload: adds the list's first element to `x`; any further
# list elements are ignored.
@double_elements_fg.register("Tensor", "List")
def double_elements_fg_for_tensor_list(x, y):
    first = y[0]
    return x + first
38
39
class HyperMapNet(nn.Cell):
    """Cell that maps a MultitypeFuncGraph over sequences with HyperMap.

    Args:
        fg: The MultitypeFuncGraph invoked on each group of corresponding
            elements drawn from the input sequences.
    """

    def __init__(self, fg):
        super().__init__()
        self.common_map = C.HyperMap()
        self.fg = fg

    def construct(self, nest_tensor_list):
        # Every item of `nest_tensor_list` becomes one positional argument of
        # HyperMap, i.e. one argument slot of `fg`.
        return self.common_map(self.fg, *nest_tensor_list)
49
50
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_single_element_hypermap_with_tensor_input():
    """
    Feature: HyperMap
    Description: Test that HyperMap runs successfully when given a single tuple of tensors,
                 applying the Tensor overload (Square) to each element.
    Expectation: success.
    """
    tensors = (Tensor(np.array([1, 2, 3]), mstype.float32),
               Tensor(np.array([4, 5, 6]), mstype.float32))
    net = HyperMapNet(single_element_fg)
    result = net((tensors,))

    # Square is applied to each tensor independently.
    expected = (np.array([1.0, 4.0, 9.0]), np.array([16.0, 25.0, 36.0]))
    assert isinstance(result, tuple)
    assert len(result) == len(expected)
    for actual, want in zip(result, expected):
        assert isinstance(actual, Tensor)
        assert np.allclose(actual.asnumpy(), want)
72
73
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_double_elements_hypermap_tensor_tuple_inputs():
    """
    Feature: HyperMap
    Description: Test whether the HyperMap with tensor and tuple inputs can run successfully.
    Expectation: success, and each output matches numpy.tile in both values and shape.
    """
    x = (Tensor(np.array([1, 2, 3]), mstype.float32), Tensor(np.array([4, 5, 6]), mstype.float32))
    y = ((1, 2), (2, 1))
    common_map = HyperMapNet(double_elements_fg)
    output = common_map((x, y))
    # Tile follows numpy.tile semantics: when len(multiples) exceeds the input rank,
    # size-1 axes are prepended, so tiling a (3,) tensor by (1, 2) gives shape (1, 6).
    expect_output_1 = np.tile(np.array([1.0, 2.0, 3.0]), y[0])
    expect_output_2 = np.tile(np.array([4.0, 5.0, 6.0]), y[1])
    assert isinstance(output, tuple)
    assert len(output) == 2
    assert isinstance(output[0], Tensor)
    assert isinstance(output[1], Tensor)
    output_1 = output[0].asnumpy()
    output_2 = output[1].asnumpy()
    # np.allclose broadcasts its operands, so e.g. (1, 6) vs. (6,) would still pass;
    # compare shapes explicitly.
    assert output_1.shape == expect_output_1.shape
    assert output_2.shape == expect_output_2.shape
    assert np.allclose(output_1, expect_output_1)
    assert np.allclose(output_2, expect_output_2)
96
97
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_double_elements_hypermap_tensor_list_inputs():
    """
    Feature: HyperMap
    Description: Test that HyperMap runs successfully when every second argument is a list,
                 so the (Tensor, List) overload adds the list's first element to the tensor.
    Expectation: success.
    """
    tensors = (Tensor(np.array([1, 2, 3]), mstype.float32),
               Tensor(np.array([4, 5, 6]), mstype.float32))
    offsets = ([1, 2], [2, 1])
    net = HyperMapNet(double_elements_fg)
    result = net((tensors, offsets))

    # Only offsets[i][0] is used: [1, 2, 3] + 1 and [4, 5, 6] + 2.
    expected = (np.array([2.0, 3.0, 4.0]), np.array([6.0, 7.0, 8.0]))
    assert isinstance(result, tuple)
    assert len(result) == len(expected)
    for actual, want in zip(result, expected):
        assert isinstance(actual, Tensor)
        assert np.allclose(actual.asnumpy(), want)
120
121
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_doubel_elements_hypermap_correct_mix_inputs():
    """
    Feature: HyperMap
    Description: Test whether the HyperMap with mixed valid inputs (Tensor + Tuple and Tensor + List)
                 can run successfully.
    Expectation: success, with outputs matching in both values and shape.
    """
    # NOTE(review): "doubel" in the function name is a typo; kept so existing test IDs
    # and selectors keep working.
    x = (Tensor(np.array([1, 2, 3]), mstype.float32), Tensor(np.array([4, 5, 6]), mstype.float32))
    y = ((1, 2), [2, 1])
    common_map = HyperMapNet(double_elements_fg)
    output = common_map((x, y))
    # First pair hits the Tuple overload (Tile, numpy.tile semantics -> shape (1, 6));
    # second pair hits the List overload ([4, 5, 6] + 2 -> shape (3,)).
    expect_output_1 = np.tile(np.array([1.0, 2.0, 3.0]), y[0])
    expect_output_2 = np.array([6.0, 7.0, 8.0])
    assert isinstance(output, tuple)
    assert len(output) == 2
    assert isinstance(output[0], Tensor)
    assert isinstance(output[1], Tensor)
    output_1 = output[0].asnumpy()
    output_2 = output[1].asnumpy()
    # np.allclose broadcasts its operands, so check shapes explicitly.
    assert output_1.shape == expect_output_1.shape
    assert output_2.shape == expect_output_2.shape
    assert np.allclose(output_1, expect_output_1)
    assert np.allclose(output_2, expect_output_2)
145
146
147
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_double_elements_hypermap_inputs_length_mismatch():
    """
    Feature: HyperMap
    Description: When the sequences passed to HyperMap have different lengths, an error is raised.
    Expectation: error.
    """
    tensors = (Tensor(np.array([1, 2, 3]), mstype.float32),
               Tensor(np.array([4, 5, 6]), mstype.float32))
    # Three multiples tuples paired against only two tensors.
    multiples = ((1, 2), (2, 1), (5, 6))
    net = HyperMapNet(double_elements_fg)
    expected_msg = "The length of tuples in HyperMap must be the same"
    with pytest.raises(Exception, match=expected_msg):
        net((tensors, multiples))
163
164
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_double_elements_hypermap_inconsistent_inputs():
    """
    Feature: HyperMap
    Description: When the sequences passed to HyperMap are of inconsistent types
                 (a tuple and a list), an error is raised.
    Expectation: error.
    """
    tensors = (Tensor(np.array([1, 2, 3]), mstype.float32),
               Tensor(np.array([4, 5, 6]), mstype.float32))
    # A list of tuples, whereas `tensors` is a tuple.
    multiples = [(1, 2), (2, 1)]
    net = HyperMapNet(double_elements_fg)
    expected_msg = "the types of arguments in HyperMap must be consistent"
    with pytest.raises(Exception, match=expected_msg):
        net((tensors, multiples))
180