# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Op Proto."""
from pyboost_utils import convert_python_func_name_to_c


class Arg:
    """
    One operator input or output parsed from ops.yaml.

    Attributes:
        arg_name: argument name as written in the yaml.
        arg_dtype: dtype string ('TypeId' inputs are recorded as 'int' by OpProto.load_from_yaml).
        type_cast: list of type names taken from the yaml 'type_cast' entry (comma-separated there).
        is_type_id: True when the yaml 'arg_handler' entry is 'dtype_to_type_id'.
        as_init_arg: True when the yaml provides a 'default' for the argument.
        default: the yaml default value (-1 when the constructor is not given one).
        inplace: for outputs, the yaml 'inplace' value ('' when absent).
    """

    def __init__(self, arg_name, arg_dtype, type_cast, is_type_id=False, as_init_arg=False, default=-1, inplace=''):
        self.arg_name = arg_name
        self.arg_dtype = arg_dtype
        self.type_cast = type_cast
        self.is_type_id = is_type_id
        self.as_init_arg = as_init_arg
        self.default = default
        self.inplace = inplace


class OpProto:
    """
    MindSpore operator prototype. Each ops.yaml entry is parsed into one of these objects, which is then used
    to auto-generate the primitive and the pyboost function.

    Attributes:
        operator_name: the yaml key of the operator.
        class_name: generated class name (yaml 'class.name' or derived from operator_name).
        op_args: list of input Arg objects, in yaml order.
        returns: list of output Arg objects, in yaml order.
        indexes: mapping from input name to its position in op_args.
        pyboost_function_name: "Pyboost_" + class_name.
        is_pyboost: value of the yaml 'dispatch.enable' entry (False when there is no 'dispatch').
        is_view: whether the yaml marks the op as a view.
        cpu, gpu, ascend: per-device names from 'dispatch' ('default' when not overridden).
        prim_init: op-level flag taken from the inputs' 'prim_init' entries.
        is_dispatch: True when the yaml has a 'dispatch' section.
    """

    def __init__(self,
                 operator_name,
                 op_args,
                 returns,
                 class_name,
                 is_pyboost,
                 is_view,
                 cpu,
                 gpu,
                 ascend,
                 prim_init,
                 is_dispatch):
        self.operator_name = operator_name
        self.class_name = class_name
        self.op_args = op_args
        self.returns = returns
        self.indexes = {arg.arg_name: index for index, arg in enumerate(op_args)}
        self.pyboost_function_name = "Pyboost_" + self.class_name
        self.is_pyboost = is_pyboost
        self.is_view = is_view
        self.cpu = cpu
        self.gpu = gpu
        self.ascend = ascend
        self.prim_init = prim_init
        self.is_dispatch = is_dispatch

    @staticmethod
    def get_device_special_name(dispatch, gpu, cpu, ascend):
        """
        Override per-device names with the entries present in a 'dispatch' mapping.

        :param dispatch: the op's yaml 'dispatch' mapping; may contain 'GPU', 'CPU' and 'Ascend' keys.
        :param gpu: value kept when 'GPU' is absent.
        :param cpu: value kept when 'CPU' is absent.
        :param ascend: value kept when 'Ascend' is absent.
        :return: tuple (gpu, cpu, ascend).
        """
        return dispatch.get('GPU', gpu), dispatch.get('CPU', cpu), dispatch.get('Ascend', ascend)

    @staticmethod
    def _parse_args(args_dict):
        """
        Parse the yaml 'args' mapping.

        :param args_dict: mapping from input name to its yaml attributes (must contain 'dtype').
        :return: tuple (list of Arg in yaml order, op-level prim_init flag).
        """
        op_args = []
        prim_init = False
        for arg_name, arg_info in args_dict.items():
            arg_dtype = arg_info['dtype']
            if arg_dtype == 'TypeId':
                # TypeId inputs are recorded with dtype 'int'.
                arg_dtype = 'int'
            # An input that declares a default is flagged as an init argument.
            as_init_arg = 'default' in arg_info
            default = arg_info['default'] if as_init_arg else None
            if 'prim_init' in arg_info:
                # NOTE(review): the last input that sets 'prim_init' wins (an explicit False can reset an
                # earlier True) -- confirm whether "any input sets it" was intended.
                prim_init = arg_info['prim_init']
            type_cast = []
            if 'type_cast' in arg_info:
                type_cast = [cast_type.strip() for cast_type in arg_info['type_cast'].split(',')]
            is_type_id = arg_info.get('arg_handler') == 'dtype_to_type_id'
            op_args.append(Arg(arg_name, arg_dtype, type_cast,
                               is_type_id=is_type_id, as_init_arg=as_init_arg, default=default))
        return op_args, prim_init

    @staticmethod
    def _parse_returns(return_dict):
        """
        Parse the yaml 'returns' mapping.

        :param return_dict: mapping from output name to its yaml attributes (must contain 'dtype').
        :return: list of Arg in yaml order, carrying each output's optional 'inplace' value.
        """
        return [Arg(return_name, return_info['dtype'], type_cast=[], inplace=return_info.get('inplace', ''))
                for return_name, return_info in return_dict.items()]

    @staticmethod
    def load_from_yaml(op_name, yaml):
        """
        Build an OpProto from one operator entry of ops.yaml.

        :param op_name: operator name; also used to derive the class name when the yaml gives none.
        :param yaml: the mapping for this operator. Must contain 'args' and 'returns'; may contain
            'dispatch', 'class' and 'view'.
        :return: the parsed OpProto.
        :raises TypeError: if 'args' or 'returns' is missing.
        """
        if 'args' not in yaml:
            raise TypeError("op define need key 'args'")
        op_args, prim_init = OpProto._parse_args(yaml.get('args'))
        if 'returns' not in yaml:
            raise TypeError("op define need key 'returns'")

        default_str = 'default'
        is_pyboost = False
        is_dispatch = False
        gpu = cpu = ascend = default_str
        if 'dispatch' in yaml:
            dispatch = yaml['dispatch']
            is_dispatch = True
            is_pyboost = dispatch.get('enable')
            gpu, cpu, ascend = OpProto.get_device_special_name(dispatch, gpu, cpu, ascend)

        class_name = convert_python_func_name_to_c(op_name)
        if 'class' in yaml and 'name' in yaml['class']:
            class_name = yaml['class']['name']

        return_args = OpProto._parse_returns(yaml['returns'])

        # The 'view' key marks a view op unless it is explicitly set to False.
        is_view = 'view' in yaml and yaml['view'] is not False

        return OpProto(op_name, op_args, return_args, class_name,
                       is_pyboost, is_view, cpu, gpu, ascend, prim_init, is_dispatch)