• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16from mindspore._extends.parse import trope as T
17from mindspore._extends.parse.resources import convert_object_map
18from tests.st.ms_adapter._register.ms_adapter_api import Tensor as adapter_Tensor
19from tests.st.ms_adapter._register.utils import convert_to_ms_tensor, convert_to_adapter_tensor
20
21
# Implementations that ``convert_object_map`` associates with each ``trope``
# operator, captured once at import time and reused by the wrappers below.
# ``.get`` is used, so presumably an operator absent from the map yields
# ``None`` here rather than raising -- confirm against ``convert_object_map``.
(
    matmul_fn,
    invert_fn,
    abs_fn,
    round_fn,
    max_fn,
    min_fn,
    sum_fn,
) = (
    convert_object_map.get(op)
    for op in (T.matmul, T.invert, T.abs, T.round, T.max, T.min, T.sum)
)
29
30
def adapter_matmul(x, y):
    """Matrix-multiply ``x`` and ``y`` with ``matmul_fn``, keeping adapter Tensor type.

    Only when BOTH operands are adapter Tensors are they converted with
    ``convert_to_ms_tensor`` before the call, with the product converted back
    through ``convert_to_adapter_tensor``. Any other operand combination is
    forwarded to ``matmul_fn`` untouched and its result returned as-is.
    """
    if not (isinstance(x, adapter_Tensor) and isinstance(y, adapter_Tensor)):
        return matmul_fn(x, y)
    product = matmul_fn(convert_to_ms_tensor(x), convert_to_ms_tensor(y))
    return convert_to_adapter_tensor(product)
40
41
def adapter_invert(x):
    """Apply ``invert_fn`` to ``x``, keeping adapter Tensor type on the result.

    An adapter Tensor input is converted with ``convert_to_ms_tensor`` first
    and the result converted back with ``convert_to_adapter_tensor``; any other
    input is handed to ``invert_fn`` directly.
    """
    if not isinstance(x, adapter_Tensor):
        return invert_fn(x)
    ms_input = convert_to_ms_tensor(x)
    return convert_to_adapter_tensor(invert_fn(ms_input))
50
51
def adapter_abs(x):
    """Apply ``abs_fn`` to ``x``, keeping adapter Tensor type on the result.

    An adapter Tensor input is converted with ``convert_to_ms_tensor`` first
    and the result converted back with ``convert_to_adapter_tensor``; any other
    input is handed to ``abs_fn`` directly.
    """
    if not isinstance(x, adapter_Tensor):
        return abs_fn(x)
    ms_input = convert_to_ms_tensor(x)
    return convert_to_adapter_tensor(abs_fn(ms_input))
60
61
def adapter_round(*data):
    """Round ``data[0]`` with ``round_fn``, keeping adapter Tensor type when applicable.

    ``data`` holds the positional arguments of a ``round`` call: ``(x,)`` or
    ``(x, ndigits)``, as for the builtin ``round``. If ``x`` is an adapter
    Tensor and ``ndigits`` is absent or ``None``, then:

    - ``x`` is converted with ``convert_to_ms_tensor``;
    - it is rounded by ``round_fn``;
    - the result is converted back with ``convert_to_adapter_tensor``.

    Every other call is forwarded to ``round_fn(*data)`` unchanged.

    Args:
        *data: Positional arguments of the ``round`` call.

    Returns:
        An adapter Tensor on the conversion path, otherwise whatever
        ``round_fn`` returns.
    """
    # ``ndigits`` must be compared with ``is None``. ``isinstance(obj, None)``
    # raises TypeError because ``None`` is not a type, which made every
    # two-argument call with an adapter Tensor crash.
    if (len(data) == 1 and isinstance(data[0], adapter_Tensor)) or \
            (len(data) == 2 and isinstance(data[0], adapter_Tensor) and data[1] is None):
        x = convert_to_ms_tensor(data[0])
        out = round_fn(x)
        out = convert_to_adapter_tensor(out)
    else:
        out = round_fn(*data)
    return out
72
73
74def _has_adapter_tensor(*data):
75    if len(data) == 1 and isinstance(data[0], adapter_Tensor):
76        return True
77    for elem in data:
78        if isinstance(elem, adapter_Tensor):
79            return True
80    return False
81
82
def adapter_max(*data):
    """Call ``max_fn(*data)``; wrap the result if any argument is an adapter Tensor.

    The arguments are passed to ``max_fn`` unconverted. Only the result goes
    through ``convert_to_adapter_tensor``, and only when
    ``_has_adapter_tensor(*data)`` is true.
    """
    wrap_result = _has_adapter_tensor(*data)
    result = max_fn(*data)
    if wrap_result:
        result = convert_to_adapter_tensor(result)
    return result
90
91
def adapter_min(*data):
    """Call ``min_fn(*data)``; wrap the result if any argument is an adapter Tensor.

    The arguments are passed to ``min_fn`` unconverted. Only the result goes
    through ``convert_to_adapter_tensor``, and only when
    ``_has_adapter_tensor(*data)`` is true.
    """
    wrap_result = _has_adapter_tensor(*data)
    result = min_fn(*data)
    if wrap_result:
        result = convert_to_adapter_tensor(result)
    return result
99
100
def adapter_sum(*data):
    """Call ``sum_fn(*data)``; wrap the result if any argument is an adapter Tensor.

    The arguments are passed to ``sum_fn`` unconverted. Only the result goes
    through ``convert_to_adapter_tensor``, and only when
    ``_has_adapter_tensor(*data)`` is true. Adapter Tensors nested inside an
    iterable argument are not detected.
    """
    wrap_result = _has_adapter_tensor(*data)
    result = sum_fn(*data)
    if wrap_result:
        result = convert_to_adapter_tensor(result)
    return result
108
109
def create_adapter_tensor(*data, dtype=None, inner=False, cast_tensor=False):
    """Construct an adapter Tensor from ``data`` and the given keyword options.

    All positional arguments and the ``dtype``, ``inner`` and ``cast_tensor``
    keywords are forwarded unchanged to the ``adapter_Tensor`` constructor.
    """
    # NOTE(review): the trailing ``# @jit.typing`` comment appears to be a
    # return-type annotation read by MindSpore's JIT (a tensor of ``dtype``).
    # Keep it on the same line as the ``return``; confirm against MindSpore docs.
    return adapter_Tensor(*data, dtype=dtype, inner=inner, cast_tensor=cast_tensor) # @jit.typing: () -> tensor_type[{dtype}]
112