# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Adapter-Tensor-aware wrappers around MindSpore's parser operator overrides.

Each ``adapter_*`` function delegates to the implementation registered in
``convert_object_map`` for the corresponding ``trope`` operator and wraps the
result in an adapter ``Tensor`` when adapter Tensors are involved. Presumably
these wrappers are installed in place of the defaults elsewhere in the test
package -- that registration is not visible in this module.
"""

from mindspore._extends.parse import trope as T
from mindspore._extends.parse.resources import convert_object_map
from tests.st.ms_adapter._register.ms_adapter_api import Tensor as adapter_Tensor
from tests.st.ms_adapter._register.utils import convert_to_ms_tensor, convert_to_adapter_tensor


# Implementations registered in convert_object_map for each trope operator,
# captured once at import time so the wrappers below can delegate to them.
matmul_fn = convert_object_map.get(T.matmul)
invert_fn = convert_object_map.get(T.invert)
abs_fn = convert_object_map.get(T.abs)
round_fn = convert_object_map.get(T.round)
max_fn = convert_object_map.get(T.max)
min_fn = convert_object_map.get(T.min)
sum_fn = convert_object_map.get(T.sum)


def adapter_matmul(x, y):
    """Apply the registered ``matmul`` implementation, preserving adapter Tensors.

    Only when BOTH ``x`` and ``y`` are adapter Tensors are they converted to
    MindSpore tensors before the call and the result converted back to an
    adapter Tensor. Any other combination of operands is forwarded to
    ``matmul_fn`` unchanged and its result is returned as-is.
    """
    if isinstance(x, adapter_Tensor) and isinstance(y, adapter_Tensor):
        x = convert_to_ms_tensor(x)
        y = convert_to_ms_tensor(y)
        out = matmul_fn(x, y)
        out = convert_to_adapter_tensor(out)
    else:
        out = matmul_fn(x, y)
    return out


def adapter_invert(x):
    """Apply the registered ``invert`` implementation, preserving adapter Tensors.

    An adapter Tensor input is converted to a MindSpore tensor, inverted, and
    the result converted back to an adapter Tensor; other inputs are forwarded
    to ``invert_fn`` unchanged.
    """
    if isinstance(x, adapter_Tensor):
        x = convert_to_ms_tensor(x)
        out = invert_fn(x)
        out = convert_to_adapter_tensor(out)
    else:
        out = invert_fn(x)
    return out


def adapter_abs(x):
    """Apply the registered ``abs`` implementation, preserving adapter Tensors.

    An adapter Tensor input is converted to a MindSpore tensor, passed to
    ``abs_fn``, and the result converted back to an adapter Tensor; other
    inputs are forwarded to ``abs_fn`` unchanged.
    """
    if isinstance(x, adapter_Tensor):
        x = convert_to_ms_tensor(x)
        out = abs_fn(x)
        out = convert_to_adapter_tensor(out)
    else:
        out = abs_fn(x)
    return out


def adapter_round(*data):
    """Apply the registered ``round`` implementation, preserving adapter Tensors.

    ``round(x)`` and ``round(x, None)`` with an adapter Tensor ``x`` convert
    ``x`` to a MindSpore tensor, call ``round_fn(x)`` and convert the result
    back to an adapter Tensor. Every other argument combination (including an
    adapter Tensor with a non-None ``ndigits``) is forwarded to ``round_fn``
    unchanged and its result returned as-is.
    """
    # `data[1] is None` identifies round(x, None), which Python defines as
    # equivalent to round(x). The previous `isinstance(data[1], None)` raised
    # TypeError because None is not a type.
    if (len(data) == 1 and isinstance(data[0], adapter_Tensor)) or \
            (len(data) == 2 and isinstance(data[0], adapter_Tensor) and data[1] is None):
        x = data[0]
        x = convert_to_ms_tensor(x)
        out = round_fn(x)
        out = convert_to_adapter_tensor(out)
    else:
        out = round_fn(*data)
    return out


def _has_adapter_tensor(*data):
    """Return True if any positional argument is an adapter Tensor, else False."""
    for elem in data:
        if isinstance(elem, adapter_Tensor):
            return True
    return False


def adapter_max(*data):
    """Apply the registered ``max`` implementation, preserving adapter Tensors.

    Arguments are forwarded to ``max_fn`` as-is (unlike ``adapter_matmul``
    they are NOT converted to MindSpore tensors first); when any argument is
    an adapter Tensor, the result is converted to an adapter Tensor.
    """
    if _has_adapter_tensor(*data):
        out = max_fn(*data)
        out = convert_to_adapter_tensor(out)
    else:
        out = max_fn(*data)
    return out


def adapter_min(*data):
    """Apply the registered ``min`` implementation, preserving adapter Tensors.

    Arguments are forwarded to ``min_fn`` as-is; when any argument is an
    adapter Tensor, the result is converted to an adapter Tensor.
    """
    if _has_adapter_tensor(*data):
        out = min_fn(*data)
        out = convert_to_adapter_tensor(out)
    else:
        out = min_fn(*data)
    return out


def adapter_sum(*data):
    """Apply the registered ``sum`` implementation, preserving adapter Tensors.

    Arguments are forwarded to ``sum_fn`` as-is; when any argument is an
    adapter Tensor, the result is converted to an adapter Tensor.
    """
    if _has_adapter_tensor(*data):
        out = sum_fn(*data)
        out = convert_to_adapter_tensor(out)
    else:
        out = sum_fn(*data)
    return out


def create_adapter_tensor(*data, dtype=None, inner=False, cast_tensor=False):
    """Construct an adapter ``Tensor`` from ``data`` and the given keyword options.

    The trailing ``@jit.typing`` comment on the return line is a typing hint
    consumed by MindSpore's JIT and must stay attached to that line.
    """
    return adapter_Tensor(*data, dtype=dtype, inner=inner, cast_tensor=cast_tensor)  # @jit.typing: () -> tensor_type[{dtype}]