• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2023 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""Op Proto."""
16from pyboost_utils import convert_python_func_name_to_c
17
18
class Arg:
    """
    Description of a single op input or output as read from ops.yaml.

    :param arg_name: name of the argument (the key under 'args' or 'returns').
    :param arg_dtype: declared dtype string ('TypeId' is already mapped to 'int' by the loader).
    :param type_cast: list of type names the argument may be cast from (empty when none).
    :param is_type_id: True when the arg uses the 'dtype_to_type_id' arg handler.
    :param as_init_arg: True when the yaml gives the arg a 'default'.
    :param default: the yaml 'default' value; -1 unless the caller passes one.
    :param inplace: for return values, the yaml 'inplace' value ('' when absent).
    """

    def __init__(self, arg_name, arg_dtype, type_cast, is_type_id=False, as_init_arg=False, default=-1, inplace=''):
        # Identity and declared type of the argument.
        self.arg_name, self.arg_dtype = arg_name, arg_dtype
        # Conversion hints consumed by the code generator.
        self.type_cast, self.is_type_id = type_cast, is_type_id
        # Default-value handling.
        self.as_init_arg, self.default = as_init_arg, default
        # Only meaningful for return values.
        self.inplace = inplace
28
29
class OpProto:
    """
    MindSpore op prototype parsed from one entry of ops.yaml; it drives auto-generation of the
    primitive and the pyboost function.
    """

    def __init__(self,
                 operator_name,
                 op_args,
                 returns,
                 class_name,
                 is_pyboost,
                 is_view,
                 cpu,
                 gpu,
                 ascend,
                 prim_init,
                 is_dispatch):
        """
        Store the parsed op description (values below are as filled in by load_from_yaml).

        :param operator_name: op name (the key of the entry in ops.yaml).
        :param op_args: list of Arg for the op inputs, in declaration order.
        :param returns: list of Arg for the op outputs.
        :param class_name: name of the generated primitive class.
        :param is_pyboost: True when the 'dispatch' section has a truthy 'enable'.
        :param is_view: True when the entry declares a 'view' key.
        :param cpu: CPU-specific name from 'dispatch', or 'default'.
        :param gpu: GPU-specific name from 'dispatch', or 'default'.
        :param ascend: Ascend-specific name from 'dispatch', or 'default'.
        :param prim_init: value of the last 'prim_init' found among the args (False if none).
        :param is_dispatch: True when the entry has a 'dispatch' section.
        """
        self.operator_name = operator_name
        self.class_name = class_name
        self.op_args = op_args
        self.returns = returns
        # Position of each input, looked up by argument name.
        self.indexes = {arg.arg_name: index for index, arg in enumerate(op_args)}
        self.pyboost_function_name = "Pyboost_" + self.class_name
        self.is_pyboost = is_pyboost
        self.is_view = is_view
        self.cpu = cpu
        self.gpu = gpu
        self.ascend = ascend
        self.prim_init = prim_init
        self.is_dispatch = is_dispatch

    @staticmethod
    def get_device_special_name(dispatch, gpu, cpu, ascend):
        """
        Override the per-device names with any 'GPU' / 'CPU' / 'Ascend' entries of ``dispatch``.

        :param dispatch: the op's 'dispatch' mapping.
        :param gpu: fallback GPU name.
        :param cpu: fallback CPU name.
        :param ascend: fallback Ascend name.
        :return: tuple (gpu, cpu, ascend), each taken from ``dispatch`` when the key is present.
        """
        return dispatch.get('GPU', gpu), dispatch.get('CPU', cpu), dispatch.get('Ascend', ascend)

    @staticmethod
    def _parse_args(args_dict):
        """Convert the 'args' mapping into (list of Arg, op-level prim_init value)."""
        op_args = []
        prim_init = False
        for arg_name, arg_spec in args_dict.items():
            arg_dtype = arg_spec['dtype']
            # TypeId arguments are represented as plain ints.
            if arg_dtype == 'TypeId':
                arg_dtype = 'int'
            # Any arg carrying an explicit 'default' is flagged as an init arg.
            as_init_arg = 'default' in arg_spec
            default = arg_spec['default'] if as_init_arg else None
            # prim_init is op-wide: the last arg that specifies it determines the value.
            if 'prim_init' in arg_spec:
                prim_init = arg_spec['prim_init']
            type_cast = []
            if 'type_cast' in arg_spec:
                type_cast = [cast_type.strip() for cast_type in arg_spec['type_cast'].split(',')]
            is_type_id = arg_spec.get('arg_handler') == 'dtype_to_type_id'
            op_args.append(Arg(arg_name, arg_dtype, type_cast, is_type_id, as_init_arg, default))
        return op_args, prim_init

    @staticmethod
    def _parse_returns(return_dict):
        """Convert the 'returns' mapping into a list of Arg, keeping each 'inplace' value ('' if absent)."""
        return [Arg(return_name, return_spec['dtype'], type_cast=[], inplace=return_spec.get('inplace', ''))
                for return_name, return_spec in return_dict.items()]

    @staticmethod
    def load_from_yaml(op_name, yaml):
        """
        Build an OpProto from one op entry of ops.yaml.

        :param op_name: the op's key in ops.yaml; also used to derive the class name when the
            entry has no 'class: name'.
        :param yaml: the parsed mapping for this op; must contain 'args' and 'returns'.
        :return: the constructed OpProto.
        :raises TypeError: if 'args' or 'returns' is missing.
        """
        if 'args' not in yaml:
            raise TypeError("op define need key 'args'")
        op_args, prim_init = OpProto._parse_args(yaml.get('args'))
        if 'returns' not in yaml:
            raise TypeError("op define need key 'returns'")

        default_str = 'default'
        is_pyboost = False
        is_dispatch = False
        gpu = cpu = ascend = default_str
        if 'dispatch' in yaml:
            is_dispatch = True
            # An empty `dispatch:` in YAML loads as None; treat it like an empty mapping
            # instead of crashing on attribute access.
            dispatch = yaml['dispatch'] or {}
            # Normalise to a bool so a missing 'enable' yields False rather than None.
            is_pyboost = bool(dispatch.get('enable', False))
            gpu, cpu, ascend = OpProto.get_device_special_name(dispatch, gpu, cpu, ascend)

        # An explicit 'class: name' overrides the name derived from op_name; an empty
        # `class:` section (None) is treated as having no override.
        class_spec = yaml.get('class') or {}
        if 'name' in class_spec:
            class_name = class_spec['name']
        else:
            class_name = convert_python_func_name_to_c(op_name)

        return_args = OpProto._parse_returns(yaml['returns'])
        # NOTE(review): only the presence of 'view' is checked; its value (even False) is
        # ignored — confirm this is intended.
        is_view = 'view' in yaml
        return OpProto(op_name, op_args, return_args, class_name,
                       is_pyboost, is_view, cpu, gpu, ascend, prim_init, is_dispatch)
139