# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import mindspore as ms
from mindspore import dtype as mstype
from mindspore._c_expression import typing
from mindspore.ops.operations import _inner_ops as inner
from tests.st.ms_adapter._register.ms_adapter_api import Tensor as adapter_Tensor


def convert_to_ms_tensor(x):
    """Forward `x` to `inner.convert_to_ms_tensor` and return its result."""
    return inner.convert_to_ms_tensor(x)


def convert_to_adapter_tensor(x):
    """Forward `x` to `inner.convert_to_adapter_tensor` and return its result."""
    return inner.convert_to_adapter_tensor(x)


def get_registed_fn(ops, *type_names):
    """
    Look up the function already registered in `ops` for the given type names.

    Each name in `type_names` is turned into a type with `str_to_type`; the first
    entry of `ops.entries` whose signature has the same length and whose every
    element accepts the corresponding type (per `typing.is_subclass`) wins.

    Raises:
        ValueError: If no entry of `ops.entries` matches.
    """
    types = tuple(mstype.typing.str_to_type(name) for name in type_names)
    for signature, registered in ops.entries:
        if len(signature) != len(types):
            continue
        if all(typing.is_subclass(wanted, sig) for sig, wanted in zip(signature, types)):
            return registered
    raise ValueError(f"For 'MultitypeFuncGraph', cannot find fn match given types: {types}.")


def convert_output(out):
    """Return `out` converted to an adapter tensor when it is an `ms.Tensor`, else unchanged."""
    if not isinstance(out, ms.Tensor):
        return out
    return convert_to_adapter_tensor(out)


def update_multitype_ops_tensor(ops):
    """Re-register the ("Tensor",) overload of `ops` with adapter-tensor handling."""
    func = get_registed_fn(ops, "Tensor")

    @ops.register("Tensor")
    def _tensor(x):
        # Adapter input: run the previous overload on the ms tensor, convert the result back.
        if isinstance(x, adapter_Tensor):
            return convert_output(func(convert_to_ms_tensor(x)))
        return func(x)


def update_multitype_ops_tensor_tensor(ops):
    """Re-register the ("Tensor", "Tensor") overload of `ops` with adapter-tensor handling."""
    func = get_registed_fn(ops, "Tensor", "Tensor")

    @ops.register("Tensor", "Tensor")
    def _tensor_and_tensor(x, y):
        # Conversion only happens when BOTH operands are adapter tensors.
        if isinstance(x, adapter_Tensor) and isinstance(y, adapter_Tensor):
            return convert_output(func(convert_to_ms_tensor(x), convert_to_ms_tensor(y)))
        return func(x, y)


def update_multitype_ops_number_tensor(ops):
    """Re-register the ("Number", "Tensor") overload of `ops` with adapter-tensor handling."""
    func = get_registed_fn(ops, "Number", "Tensor")

    @ops.register("Number", "Tensor")
    def _number_and_tensor(x, y):
        # Only the tensor operand (second argument) is inspected and converted.
        if isinstance(y, adapter_Tensor):
            return convert_output(func(x, convert_to_ms_tensor(y)))
        return func(x, y)


def update_multitype_ops_tensor_number(ops):
    """Re-register the ("Tensor", "Number") overload of `ops` with adapter-tensor handling."""
    func = get_registed_fn(ops, "Tensor", "Number")

    @ops.register("Tensor", "Number")
    def _tensor_and_number(x, y):
        # Only the tensor operand (first argument) is inspected and converted.
        if isinstance(x, adapter_Tensor):
            return convert_output(func(convert_to_ms_tensor(x), y))
        return func(x, y)


def update_multitype_ops_tuple_tensor(ops):
    """Re-register the ("Tuple", "Tensor") overload of `ops` with adapter-tensor handling."""
    func = get_registed_fn(ops, "Tuple", "Tensor")

    @ops.register("Tuple", "Tensor")
    def _tuple_and_tensor(x, y):
        # Only the tensor operand (second argument) is inspected and converted.
        if isinstance(y, adapter_Tensor):
            return convert_output(func(x, convert_to_ms_tensor(y)))
        return func(x, y)


def update_multitype_ops_tensor_tuple(ops):
    """Re-register the ("Tensor", "Tuple") overload of `ops` with adapter-tensor handling."""
    func = get_registed_fn(ops, "Tensor", "Tuple")

    @ops.register("Tensor", "Tuple")
    def _tensor_and_tuple(x, y):
        # Only the tensor operand (first argument) is inspected and converted.
        if isinstance(x, adapter_Tensor):
            return convert_output(func(convert_to_ms_tensor(x), y))
        return func(x, y)


def update_multitype_ops_list_tensor(ops):
    """Re-register the ("List", "Tensor") overload of `ops` with adapter-tensor handling."""
    func = get_registed_fn(ops, "List", "Tensor")

    @ops.register("List", "Tensor")
    def _list_and_tensor(x, y):
        # Only the tensor operand (second argument) is inspected and converted.
        if isinstance(y, adapter_Tensor):
            return convert_output(func(x, convert_to_ms_tensor(y)))
        return func(x, y)


def update_multitype_ops_tensor_list(ops):
    """Re-register the ("Tensor", "List") overload of `ops` with adapter-tensor handling."""
    func = get_registed_fn(ops, "Tensor", "List")

    @ops.register("Tensor", "List")
    def _tensor_and_list(x, y):
        # Only the tensor operand (first argument) is inspected and converted.
        if isinstance(x, adapter_Tensor):
            return convert_output(func(convert_to_ms_tensor(x), y))
        return func(x, y)


def update_multitype_ops_tensor_none(ops):
    """Re-register the ("Tensor", "None") overload of `ops` with adapter-tensor handling."""
    func = get_registed_fn(ops, "Tensor", "None")

    @ops.register("Tensor", "None")
    def _tensor_and_none(x, y):
        # Only the tensor operand (first argument) is inspected and converted.
        if isinstance(x, adapter_Tensor):
            return convert_output(func(convert_to_ms_tensor(x), y))
        return func(x, y)


def update_multitype_ops_tensor_slice(ops):
    """Re-register the ("Tensor", "Slice") overload of `ops` with adapter-tensor handling."""
    func = get_registed_fn(ops, "Tensor", "Slice")

    @ops.register("Tensor", "Slice")
    def _tensor_and_slice(x, y):
        # Only the tensor operand (first argument) is inspected and converted.
        if isinstance(x, adapter_Tensor):
            return convert_output(func(convert_to_ms_tensor(x), y))
        return func(x, y)


def update_multitype_ops_setitem_tensor(ops):
    """
    Re-register every overload of `ops` whose first signature element is a tensor type.

    Each such overload is replaced by a three-argument wrapper that converts an
    adapter-tensor `data` to an ms tensor, calls the previous overload, and
    converts the result with `convert_to_adapter_tensor`.
    """
    def register_for_setitem(sigs, fn):
        # A separate scope per overload so each wrapper closes over its own `fn`.
        @ops.register(*sigs)
        def _tensor_setitem(data, index, value):
            if isinstance(data, adapter_Tensor):
                return convert_to_adapter_tensor(fn(convert_to_ms_tensor(data), index, value))
            return fn(data, index, value)

    # Iterate over a snapshot: registering while walking `ops.entries` itself
    # would touch the collection being iterated.
    for sigs, fn in ops.entries.copy():
        if typing.is_subclass(sigs[0], mstype.tensor_type):
            register_for_setitem(sigs, fn)