# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import pytest
import numpy as np
from mindspore import context, nn, Tensor
from mindspore import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE)

# Overloaded function that HyperMap applies when mapping over a single
# sequence: each Tensor element is squared element-wise.
single_element_fg = C.MultitypeFuncGraph("single_element_fg")


@single_element_fg.register("Tensor")
def single_element_fg_for_tensor(x):
    """Return the element-wise square of tensor `x`."""
    return P.Square()(x)


# Overloaded function that HyperMap applies when mapping over two sequences
# in lockstep; the overload is chosen by the type of the second argument
# (Tuple -> Tile, List -> add its first element).
double_elements_fg = C.MultitypeFuncGraph("double_elements_fg")


@double_elements_fg.register("Tensor", "Tuple")
def double_elements_fg_for_tensor_tuple(x, y):
    """Tile tensor `x` using tuple `y` as the multiples."""
    return P.Tile()(x, y)


@double_elements_fg.register("Tensor", "List")
def double_elements_fg_for_tensor_list(x, y):
    """Add the first element of list `y` to tensor `x`."""
    return x + y[0]


class HyperMapNet(nn.Cell):
    """Cell that applies `fg` with HyperMap over the given sequences.

    Args:
        fg: the function (a MultitypeFuncGraph in these tests) that HyperMap
            applies element by element.
    """

    def __init__(self, fg):
        super(HyperMapNet, self).__init__()
        self.common_map = C.HyperMap()
        self.fg = fg

    def construct(self, nest_tensor_list):
        # Every entry of nest_tensor_list is passed to HyperMap as its own
        # sequence argument, so fg is called with one element from each.
        output = self.common_map(self.fg, *nest_tensor_list)
        return output


def _make_input_tensors():
    """Build the pair of float32 tensors used as the first HyperMap input."""
    return (Tensor(np.array([1, 2, 3]), mstype.float32),
            Tensor(np.array([4, 5, 6]), mstype.float32))


def _assert_two_tensor_outputs(output, expect_output_1, expect_output_2):
    """Assert `output` is a 2-tuple of Tensors equal to the expected arrays.

    Shapes are compared explicitly because np.allclose broadcasts its
    arguments and would otherwise accept, e.g., a (1, 6) result for a (6,)
    expectation.
    """
    assert isinstance(output, tuple)
    assert len(output) == 2
    for actual, expected in zip(output, (expect_output_1, expect_output_2)):
        assert isinstance(actual, Tensor)
        actual_np = actual.asnumpy()
        assert actual_np.shape == expected.shape
        assert np.allclose(actual_np, expected)


@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_single_element_hypermap_with_tensor_input():
    """
    Feature: HyperMap
    Description: Test whether HyperMap over a single tuple of tensors can run successfully.
    Expectation: success.
    """
    x = _make_input_tensors()
    common_map = HyperMapNet(single_element_fg)
    output = common_map((x,))
    expect_output_1 = np.array([1.0, 4.0, 9.0])
    expect_output_2 = np.array([16.0, 25.0, 36.0])
    _assert_two_tensor_outputs(output, expect_output_1, expect_output_2)


@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_double_elements_hypermap_tensor_tuple_inputs():
    """
    Feature: HyperMap
    Description: Test whether HyperMap with tensor and tuple inputs can run successfully.
    Expectation: success.
    """
    x = _make_input_tensors()
    y = ((1, 2), (2, 1))
    common_map = HyperMapNet(double_elements_fg)
    output = common_map((x, y))
    # Tile with 2-element multiples on a rank-1 tensor treats it as (1, 3).
    expect_output_1 = np.array([[1.0, 2.0, 3.0, 1.0, 2.0, 3.0]])
    expect_output_2 = np.array([[4.0, 5.0, 6.0], [4.0, 5.0, 6.0]])
    _assert_two_tensor_outputs(output, expect_output_1, expect_output_2)


@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_double_elements_hypermap_tensor_list_inputs():
    """
    Feature: HyperMap
    Description: Test whether HyperMap with tensor and list inputs can run successfully.
    Expectation: success.
    """
    x = _make_input_tensors()
    y = ([1, 2], [2, 1])
    common_map = HyperMapNet(double_elements_fg)
    output = common_map((x, y))
    expect_output_1 = np.array([2.0, 3.0, 4.0])
    expect_output_2 = np.array([6.0, 7.0, 8.0])
    _assert_two_tensor_outputs(output, expect_output_1, expect_output_2)


# NOTE(review): "doubel" is a typo of "double"; the name is kept unchanged so
# the test ID stays stable.
@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_doubel_elements_hypermap_correct_mix_inputs():
    """
    Feature: HyperMap
    Description: Test whether HyperMap with mixed valid inputs (Tensor + Tuple and Tensor + List)
        can run successfully.
    Expectation: success.
    """
    x = _make_input_tensors()
    y = ((1, 2), [2, 1])
    common_map = HyperMapNet(double_elements_fg)
    output = common_map((x, y))
    # First pair dispatches to Tile (tuple), second to addition (list).
    expect_output_1 = np.array([[1.0, 2.0, 3.0, 1.0, 2.0, 3.0]])
    expect_output_2 = np.array([6.0, 7.0, 8.0])
    _assert_two_tensor_outputs(output, expect_output_1, expect_output_2)


@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_double_elements_hypermap_inputs_length_mismatch():
    """
    Feature: HyperMap
    Description: When the inputs to HyperMap have different lengths, an error is raised.
    Expectation: error.
    """
    x = _make_input_tensors()
    y = ((1, 2), (2, 1), (5, 6))
    common_map = HyperMapNet(double_elements_fg)
    with pytest.raises(Exception, match="The length of tuples in HyperMap must be the same"):
        common_map((x, y))


@pytest.mark.level1
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_onecard
def test_double_elements_hypermap_inconsistent_inputs():
    """
    Feature: HyperMap
    Description: When the inputs to HyperMap have inconsistent sequence types (a tuple mixed
        with a list), an error is raised.
    Expectation: error.
    """
    x = _make_input_tensors()
    y = [(1, 2), (2, 1)]
    common_map = HyperMapNet(double_elements_fg)
    with pytest.raises(Exception, match="the types of arguments in HyperMap must be consistent"):
        common_map((x, y))