• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16import mindspore as ms
17from mindspore import dtype as mstype
18from mindspore._c_expression import typing
19from mindspore.ops.operations import _inner_ops as inner
20from tests.st.ms_adapter._register.ms_adapter_api import Tensor as adapter_Tensor
21
22
def convert_to_ms_tensor(x):
    """Convert ``x`` to a plain MindSpore Tensor.

    Thin wrapper over ``_inner_ops.convert_to_ms_tensor``; returns whatever that op produces.
    """
    ms_tensor = inner.convert_to_ms_tensor(x)
    return ms_tensor
25
26
def convert_to_adapter_tensor(x):
    """Convert ``x`` to an ms_adapter Tensor.

    Thin wrapper over ``_inner_ops.convert_to_adapter_tensor``; returns whatever that op produces.
    """
    adapter_tensor = inner.convert_to_adapter_tensor(x)
    return adapter_tensor
29
30
def get_registed_fn(ops, *type_names):
    """Look up the function registered on ``ops`` for a given type signature.

    Args:
        ops: Object exposing ``entries`` as ``(signature, fn)`` pairs (a MultitypeFuncGraph,
            judging by the error message below).
        *type_names: Type names such as ``"Tensor"`` or ``"Number"``; each is converted with
            ``mstype.typing.str_to_type``.

    Returns:
        The first ``fn`` whose signature has the same number of types as ``type_names`` and for
        which every requested type is a subclass of the corresponding signature type.

    Raises:
        ValueError: If no registered entry matches.
    """
    wanted = tuple(mstype.typing.str_to_type(name) for name in type_names)
    for signature, candidate in ops.entries:
        same_arity = len(signature) == len(wanted)
        if same_arity and all(typing.is_subclass(want, sig) for sig, want in zip(signature, wanted)):
            return candidate
    raise ValueError(f"For 'MultitypeFuncGraph', cannot find fn match given types: {wanted}.")
40
41
def convert_output(out):
    """Return ``out`` as an adapter Tensor if it is a MindSpore Tensor, else return it unchanged."""
    if not isinstance(out, ms.Tensor):
        return out
    return convert_to_adapter_tensor(out)
46
47
def update_multitype_ops_tensor(ops):
    """Register an adapter-aware wrapper for the ``("Tensor",)`` overload of ``ops``.

    The previously registered implementation is looked up first. The wrapper unwraps an adapter
    Tensor argument, calls that implementation, and re-wraps a Tensor result. Non-adapter inputs
    are forwarded unchanged.
    """
    origin_fn = get_registed_fn(ops, "Tensor")

    @ops.register("Tensor")
    def _tensor(x):
        # Plain MindSpore inputs go straight to the original implementation.
        if not isinstance(x, adapter_Tensor):
            return origin_fn(x)
        ms_x = convert_to_ms_tensor(x)
        return convert_output(origin_fn(ms_x))
60
61
def update_multitype_ops_tensor_tensor(ops):
    """Register an adapter-aware wrapper for the ``("Tensor", "Tensor")`` overload of ``ops``.

    Conversion only happens when BOTH operands are adapter Tensors. Then both are unwrapped,
    the original implementation is called, and a Tensor result is re-wrapped. Any other
    combination is forwarded to the original implementation unchanged.
    """
    origin_fn = get_registed_fn(ops, "Tensor", "Tensor")

    @ops.register("Tensor", "Tensor")
    def _tensor_and_tensor(x, y):
        # Mixed or plain operands are not converted.
        if not (isinstance(x, adapter_Tensor) and isinstance(y, adapter_Tensor)):
            return origin_fn(x, y)
        ms_x = convert_to_ms_tensor(x)
        ms_y = convert_to_ms_tensor(y)
        return convert_output(origin_fn(ms_x, ms_y))
75
76
def update_multitype_ops_number_tensor(ops):
    """Register an adapter-aware wrapper for the ``("Number", "Tensor")`` overload of ``ops``.

    When the Tensor operand ``y`` is an adapter Tensor, it is unwrapped before calling the
    original implementation, and a Tensor result is re-wrapped. Otherwise the call is forwarded
    unchanged.
    """
    origin_fn = get_registed_fn(ops, "Number", "Tensor")

    @ops.register("Number", "Tensor")
    def _number_and_tensor(x, y):
        if not isinstance(y, adapter_Tensor):
            return origin_fn(x, y)
        ms_y = convert_to_ms_tensor(y)
        return convert_output(origin_fn(x, ms_y))
89
90
def update_multitype_ops_tensor_number(ops):
    """Register an adapter-aware wrapper for the ``("Tensor", "Number")`` overload of ``ops``.

    When the Tensor operand ``x`` is an adapter Tensor, it is unwrapped before calling the
    original implementation, and a Tensor result is re-wrapped. Otherwise the call is forwarded
    unchanged.
    """
    origin_fn = get_registed_fn(ops, "Tensor", "Number")

    @ops.register("Tensor", "Number")
    def _tensor_and_number(x, y):
        if not isinstance(x, adapter_Tensor):
            return origin_fn(x, y)
        ms_x = convert_to_ms_tensor(x)
        return convert_output(origin_fn(ms_x, y))
103
104
def update_multitype_ops_tuple_tensor(ops):
    """Register an adapter-aware wrapper for the ``("Tuple", "Tensor")`` overload of ``ops``.

    When the Tensor operand ``y`` is an adapter Tensor, it is unwrapped before calling the
    original implementation, and a Tensor result is re-wrapped. Otherwise the call is forwarded
    unchanged.
    """
    origin_fn = get_registed_fn(ops, "Tuple", "Tensor")

    @ops.register("Tuple", "Tensor")
    def _tuple_and_tensor(x, y):
        if not isinstance(y, adapter_Tensor):
            return origin_fn(x, y)
        ms_y = convert_to_ms_tensor(y)
        return convert_output(origin_fn(x, ms_y))
117
118
def update_multitype_ops_tensor_tuple(ops):
    """Register an adapter-aware wrapper for the ``("Tensor", "Tuple")`` overload of ``ops``.

    When the Tensor operand ``x`` is an adapter Tensor, it is unwrapped before calling the
    original implementation, and a Tensor result is re-wrapped. Otherwise the call is forwarded
    unchanged.
    """
    origin_fn = get_registed_fn(ops, "Tensor", "Tuple")

    @ops.register("Tensor", "Tuple")
    def _tensor_and_tuple(x, y):
        if not isinstance(x, adapter_Tensor):
            return origin_fn(x, y)
        ms_x = convert_to_ms_tensor(x)
        return convert_output(origin_fn(ms_x, y))
131
132
def update_multitype_ops_list_tensor(ops):
    """Register an adapter-aware wrapper for the ``("List", "Tensor")`` overload of ``ops``.

    When the Tensor operand ``y`` is an adapter Tensor, it is unwrapped before calling the
    original implementation, and a Tensor result is re-wrapped. Otherwise the call is forwarded
    unchanged.
    """
    origin_fn = get_registed_fn(ops, "List", "Tensor")

    @ops.register("List", "Tensor")
    def _list_and_tensor(x, y):
        if not isinstance(y, adapter_Tensor):
            return origin_fn(x, y)
        ms_y = convert_to_ms_tensor(y)
        return convert_output(origin_fn(x, ms_y))
145
146
def update_multitype_ops_tensor_list(ops):
    """Register an adapter-aware wrapper for the ``("Tensor", "List")`` overload of ``ops``.

    When the Tensor operand ``x`` is an adapter Tensor, it is unwrapped before calling the
    original implementation, and a Tensor result is re-wrapped. Otherwise the call is forwarded
    unchanged.
    """
    origin_fn = get_registed_fn(ops, "Tensor", "List")

    @ops.register("Tensor", "List")
    def _tensor_and_list(x, y):
        if not isinstance(x, adapter_Tensor):
            return origin_fn(x, y)
        ms_x = convert_to_ms_tensor(x)
        return convert_output(origin_fn(ms_x, y))
159
160
def update_multitype_ops_tensor_none(ops):
    """Register an adapter-aware wrapper for the ``("Tensor", "None")`` overload of ``ops``.

    When the Tensor operand ``x`` is an adapter Tensor, it is unwrapped before calling the
    original implementation, and a Tensor result is re-wrapped. Otherwise the call is forwarded
    unchanged.
    """
    origin_fn = get_registed_fn(ops, "Tensor", "None")

    @ops.register("Tensor", "None")
    def _tensor_and_none(x, y):
        if not isinstance(x, adapter_Tensor):
            return origin_fn(x, y)
        ms_x = convert_to_ms_tensor(x)
        return convert_output(origin_fn(ms_x, y))
173
174
def update_multitype_ops_tensor_slice(ops):
    """Register an adapter-aware wrapper for the ``("Tensor", "Slice")`` overload of ``ops``.

    When the Tensor operand ``x`` is an adapter Tensor, it is unwrapped before calling the
    original implementation, and a Tensor result is re-wrapped. Otherwise the call is forwarded
    unchanged.
    """
    origin_fn = get_registed_fn(ops, "Tensor", "Slice")

    @ops.register("Tensor", "Slice")
    def _tensor_and_slice(x, y):
        if not isinstance(x, adapter_Tensor):
            return origin_fn(x, y)
        ms_x = convert_to_ms_tensor(x)
        return convert_output(origin_fn(ms_x, y))
187
188
def update_multitype_ops_setitem_tensor(ops):
    """Register adapter-aware wrappers for every Tensor-first setitem overload of ``ops``.

    An entry qualifies when its signature has exactly three types, matching the
    ``(data, index, value)`` wrapper below, and its first type is a subclass of
    ``mstype.tensor_type``. For each qualifying entry, a wrapper is registered under the same
    signature. When ``data`` is an adapter Tensor, the wrapper unwraps it, calls the original
    ``fn``, and re-wraps a Tensor result via ``convert_output``. Otherwise it forwards to ``fn``
    unchanged.
    """
    def register_for_setitem(sigs, fn):
        # ``sigs``/``fn`` are parameters of this helper, so each wrapper binds its own overload
        # instead of late-binding the loop variables below.
        @ops.register(*sigs)
        def _tensor_setitem(data, index, value):
            if isinstance(data, adapter_Tensor):
                data = convert_to_ms_tensor(data)
                out = fn(data, index, value)
                # Same result conversion as the other wrappers in this module: only a
                # MindSpore Tensor result is converted; anything else is returned as-is.
                out = convert_output(out)
            else:
                out = fn(data, index, value)
            return out

    # Iterate over a snapshot, presumably because ops.register adds to ops.entries while looping.
    entries = ops.entries.copy()
    for sigs, fn in entries:
        # The wrapper takes exactly (data, index, value); skip entries of any other arity.
        # This also guards the sigs[0] lookup against empty signatures.
        if len(sigs) != 3:
            continue
        if typing.is_subclass(sigs[0], mstype.tensor_type):
            register_for_setitem(sigs, fn)
205