# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_hypermap """
import numpy as np

from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.ops import Primitive
from mindspore.ops import _constants
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from ...ut_filter import non_graph_engine

# pylint: disable=W0613
# W0613: unused-argument


tensor_add = P.Add()
scala_add = Primitive(_constants.kScalarAdd)

# Overloaded binary "add"; the implementation is chosen by the argument types
# registered below ("Number", "Number") and ("Tensor", "Tensor").
add = C.MultitypeFuncGraph('add')


# Scalar overload: adds two numbers with the ScalarAdd primitive.
# (Graph-parsed function: documented with comments rather than a docstring.)
@add.register("Number", "Number")
def add_scala(x, y):
    return scala_add(x, y)


# Tensor overload: adds two tensors with the Add operator.
@add.register("Tensor", "Tensor")
def add_tensor(x, y):
    return tensor_add(x, y)


# HyperMap with `add` bound as its leaf function; the tests below call it with
# plain leaves as well as with tuple arguments.
hyper_add = C.HyperMap(add)


# Graph-compiled entry point applying `hyper_add` to `x` and `y`.
@ms_function
def mainf(x, y):
    return hyper_add(x, y)


@non_graph_engine
def test_hypermap_tensor():
    """hyper_add over two tensors yields their element-wise sum."""
    tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
    tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
    out = mainf(tensor1, tensor2)
    print("test_hypermap_tensor:", out)
    assert np.allclose(out.asnumpy(), tensor1.asnumpy() + tensor2.asnumpy())


def test_hypermap_scalar():
    """hyper_add over two scalars compiles and runs."""
    print("test_hypermap_scalar", mainf(1, 2))


def test_hypermap_tuple():
    """hyper_add over two tuples of scalars compiles and runs."""
    print("test_hypermap_tuple", mainf((1, 1), (2, 2)))


@non_graph_engine
def test_hypermap_tuple_tensor():
    """hyper_add maps pairwise over tuples of tensors."""
    tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
    tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
    out = mainf((tensor1, tensor1), (tensor2, tensor2))
    print("test_hypermap_tuple_tensor", out)
    expected = tensor1.asnumpy() + tensor2.asnumpy()
    for item in out:
        assert np.allclose(item.asnumpy(), expected)


@non_graph_engine
def test_hypermap_tuple_mix():
    """hyper_add over tuples mixing a tensor slot and a scalar slot."""
    tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
    tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
    out = mainf((tensor1, 1), (tensor2, 2))
    print("test_hypermap_tuple_mix", out)
    # Only the tensor slot is checked; the scalar slot's return type is not
    # pinned down here.
    assert np.allclose(out[0].asnumpy(), tensor1.asnumpy() + tensor2.asnumpy())


# HyperMap without a bound function: the function to apply is passed as the
# first argument at call time (see `main_noleaf`).
hyper_map = C.HyperMap()


# Graph-compiled entry point applying `add` through `hyper_map`.
@ms_function
def main_noleaf(x, y):
    return hyper_map(add, x, y)


def test_hypermap_noleaf_scalar():
    """hyper_map(add, ...) over two scalars compiles and runs."""
    main_noleaf(1, 2)


@non_graph_engine
def test_hypermap_noleaf_tensor():
    """hyper_map(add, ...) over two tensors yields their element-wise sum."""
    tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
    tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
    out = main_noleaf(tensor1, tensor2)
    assert np.allclose(out.asnumpy(), tensor1.asnumpy() + tensor2.asnumpy())


def test_hypermap_noleaf_tuple():
    """hyper_map(add, ...) over two tuples of scalars compiles and runs."""
    main_noleaf((1, 1), (2, 2))


@non_graph_engine
def test_hypermap_noleaf_tuple_tensor():
    """hyper_map(add, ...) pairs tuple elements even when their shapes differ."""
    tensor1 = Tensor(np.array([[1.1, 2.1], [2.1, 3.1]]).astype('float32'))
    tensor2 = Tensor(np.array([[1.2, 2.2], [2.2, 3.2]]).astype('float32'))
    tensor3 = Tensor(np.array([[2.2], [3.2]]).astype('float32'))
    tensor4 = Tensor(np.array([[2.2], [3.2]]).astype('float32'))
    out = main_noleaf((tensor1, tensor3), (tensor2, tensor4))
    assert np.allclose(out[0].asnumpy(), tensor1.asnumpy() + tensor2.asnumpy())
    assert np.allclose(out[1].asnumpy(), tensor3.asnumpy() + tensor4.asnumpy())


def test_hypermap_noleaf_tuple_mix():
    """hyper_map(add, ...) over tuples mixing a tensor slot and a scalar slot."""
    tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
    tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
    out = main_noleaf((tensor1, 1), (tensor2, 2))
    assert np.allclose(out[0].asnumpy(), tensor1.asnumpy() + tensor2.asnumpy())


# Plain three-scalar sum: (x + y) + z via two ScalarAdd calls.
def add3_scalar(x, y, z):
    return scala_add(scala_add(x, y), z)


# Binds x=1 with F.partial, then calls the resulting two-argument function.
@ms_function
def main_add3_easy(x, y):
    add2 = F.partial(add3_scalar, 1)
    return add2(x, y)


def test_hypermap_add3_easy():
    """F.partial of a plain scalar function compiles and runs."""
    main_add3_easy(1, 2)


# Overloaded three-argument function; named 'add3' (not 'add') so it is not
# confused with the binary `add` graph declared above.
add3 = C.MultitypeFuncGraph('add3')
partial = P.Partial()


# Scalar overload of `add3`: (x + y) + z via two ScalarAdd calls.
# (Graph-parsed function: documented with comments rather than a docstring.)
@add3.register("Number", "Number", "Number")
def add3_scala(x, y, z):
    return scala_add(scala_add(x, y), z)


# Mixed overload of `add3`: returns y + z; `x` is accepted to fit the
# three-argument signature but is not used.
@add3.register("Number", "Tensor", "Tensor")
def add3_tensor(x, y, z):
    return tensor_add(y, z)


# Maps `add3_scala` with its first argument bound to 1 over `x` and `y`.
# NOTE(review): no test in this module calls this entry point.
@ms_function
def main_add3_scala(x, y):
    add2 = partial(add3_scala, 1)
    return hyper_map(add2, x, y)


# Maps the overloaded `add3` with its first argument bound to 1 over `x`, `y`.
@ms_function
def main_add3(x, y):
    add2 = partial(add3, 1)
    return hyper_map(add2, x, y)


@non_graph_engine
def test_hypermap_add3_tensor():
    """partial(add3, 1) over two tensors dispatches to `add3_tensor` (y + z)."""
    tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
    tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
    out = main_add3(tensor1, tensor2)
    assert np.allclose(out.asnumpy(), tensor1.asnumpy() + tensor2.asnumpy())


def test_hypermap_add3_tuple():
    """partial(add3, 1) over tuples mixing a tensor slot and a scalar slot."""
    tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
    tensor2 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))

    out = main_add3((tensor1, 1), (tensor2, 1))
    # The tensor slot goes through `add3_tensor`, which ignores the bound 1.
    assert np.allclose(out[0].asnumpy(), tensor1.asnumpy() + tensor2.asnumpy())