# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Templates used to generate C++ and Python source code."""
import re
import os


class CppTemplate:
    """
    Template for generating C++ (and Python) code.

    Placeholders in ``code_pattern`` are written ``$name`` or ``${name}`` and are
    substituted by :meth:`replace`. Two extra braced forms only matter when the value
    is a non-empty list substituted inside a line: ``${,name}`` prefixes ``","`` and
    ``${name,}`` appends ``", "`` to the comma-joined items.
    """
    # Group 1: the spaces/tabs between the start of a line and the '$'; it is None when
    #          the placeholder is not preceded by only horizontal whitespace on its line.
    # Group 2: the placeholder key, a bare identifier or a braced form with optional
    #          leading/trailing commas.
    regular_str = r"(^[^\n\S]*)?\$([^\d\W]\w*|\{,?[^\d\W]\w*\,?})"
    regular_match = re.compile(regular_str, re.MULTILINE)

    def __init__(self, code_pattern):
        self.code_pattern = code_pattern

    @staticmethod
    def load_from_file(file_path):
        """
        Create a template from the full text of a file.

        :param file_path: path of the template file.
        :return: a ``CppTemplate`` whose pattern is the file's contents.
        """
        # Decode explicitly as UTF-8 so generation does not depend on the host locale.
        with open(file_path, "r", encoding="utf-8") as f:
            return CppTemplate(f.read())

    def replace(self, **kwargs):
        """
        Substitute every placeholder in the pattern with the matching keyword argument.

        A placeholder that starts a line (only spaces/tabs before the ``$``) is expanded
        block-wise: the value, or each item when the value is a list, is converted with
        ``str`` and split into lines; every line is prefixed with the placeholder's
        leading whitespace, and trailing whitespace of the whole expansion is stripped.
        If that expansion is empty, the leading whitespace is consumed as well.

        A placeholder inside a line is replaced by ``str(value)``; a list value is joined
        with ``", "`` and, when non-empty, the ``${,name}`` / ``${name,}`` forms add a
        leading ``","`` / trailing ``", "``.

        :param kwargs: placeholder name -> value to substitute.
        :return: the rendered code as a ``str``.
        :raises TypeError: if the pattern references a name missing from ``kwargs``.
        """

        def find(key: str):
            # Look up a placeholder value; a missing key is a caller error.
            if key in kwargs:
                return kwargs[key]
            raise TypeError(f"{key} should be in kwargs!")

        def add_indent(indent, var):
            # Prefix every line of every item with `indent`, one line per output line.
            return "".join([indent + line + "\n" for data in var for line in str(data).splitlines()]).rstrip()

        def extract_variable(key):
            # Strip the "{...}" wrapper and the optional "," markers, remembering which
            # separators they request.
            start = ""
            end = ""
            if key[0] == "{":
                key = key[1:-1]
                if key[0] == ",":
                    start = ","
                    key = key[1:]
                if key[-1] == ",":
                    end = ", "
                    key = key[:-1]
            return find(key), start, end

        def match_rule(match):
            indent = match.group(1)
            key = match.group(2)
            var, start, end = extract_variable(key)
            if indent is not None:
                # Line-start placeholder: block expansion with the captured indentation.
                if not isinstance(var, list):
                    return add_indent(indent, [var])
                return add_indent(indent, var)
            if isinstance(var, list):
                code = ", ".join(str(x) for x in var)
                # An empty list renders as nothing, without the optional separators.
                if not var:
                    return code
                return start + code + end
            return str(var)

        return self.regular_match.sub(match_rule, self.code_pattern)


NEW_LINE = "\n"
# Directory four levels above this file; every template path below is relative to it.
WORK_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../../")


def _load_template(relative_path):
    """Load a template file whose path is given relative to ``WORK_PATH``."""
    return CppTemplate.load_from_file(os.path.join(WORK_PATH, relative_path))


PYTHON_PRIM_TEMPLATE = CppTemplate("""

class _Pyboost${class_name}Prim(${class_name}Prim_):
    def __call__(self, ${input_args}):
        ${process_func}
        return _convert_stub(super().__call__(${input_args}))


${func_impl_name}_impl = _Pyboost${class_name}Prim()
""")

IMPORT_PYBOOST_PRIM_HEADER = """
from mindspore.common._stub_tensor import _convert_stub
from mindspore.ops.auto_generate.gen_arg_handler import *
"""

IMPORT_PYBOOST_FUNC_HEADER = """
from mindspore.common import dtype as mstype
from mindspore.ops.auto_generate.pyboost_inner_prim import *

"""

REGISTER_DEFINE_TEMPLATE = CppTemplate(
    """
    (void)py::class_<${class_name}PrimAdapter, PrimitiveFunctionAdapter, std::shared_ptr<${class_name}PrimAdapter>>(
      *m, "${class_name}Prim_")
      .def(py::init<>())
      .def("__call__", &${class_name}PrimAdapter::Call, "Call ${class_name} op.");
    m->def("${pyboost_op_name}", &mindspore::pynative::${pyboost_cfunc_name}, "Call ${class_name} op.");""")
REGISTER_TEMPLATE = CppTemplate("void RegisterPyBoostFunction(py::module *m) {${register_func}\n}")

REGISTER_PYBOOST_GRAD_DEFINE_TEMPLATE = CppTemplate(
    "MS_REG_PYBOOST_GRAD_OP(${pyboost_op_name}, mindspore::runtime::${pyboost_cfunc_name});\n")
REGISTER_PYBOOST_GRAD_TEMPLATE = CppTemplate("${register_func}")

PYBOOST_FUNCTION_TEMPLATE = _load_template(
    './mindspore/ccsrc/pipeline/pynative/op_function/template/pyboost_function.tpl')

PYBOOST_HEADER_TEMPLATE = _load_template(
    './mindspore/ccsrc/pipeline/pynative/op_function/template/pyboost_function_header.tpl')

PYBOOST_GRAD_FUNCTION_TEMPLATE = _load_template(
    './mindspore/ccsrc/runtime/pynative/op_function/template/pyboost_grad_function.tpl')

PYBOOST_GRAD_HEADER_TEMPLATE = _load_template(
    './mindspore/ccsrc/runtime/pynative/op_function/template/pyboost_grad_function_header.tpl')

GEN_OPS_DEF_HEADER_TEMPLATE = _load_template(
    './mindspore/python/mindspore/ops_generate/gen_ops_def_header.tpl')

PYBOOST_BASE_OP_DEFINE_TEMPLATE = _load_template(
    './mindspore/ccsrc/kernel/pyboost/template/pyboost_op_header.tpl')

PYBOOST_OP_REGISTER_TEMPLATE = _load_template(
    './mindspore/ccsrc/kernel/pyboost/template/pyboost_op_register.tpl')

# Ascend op generate
PYBOOST_ASCEND_OP_HEADER_TEMPLATE = _load_template(
    './mindspore/ccsrc/plugin/device/ascend/kernel/pyboost/template/pyboost_aclnn_header_template.tpl')

PYBOOST_ASCEND_OP_SOURCE_TEMPLATE = _load_template(
    './mindspore/ccsrc/plugin/device/ascend/kernel/pyboost/template/pyboost_aclnn_source_template.tpl')

PYBOOST_ASCEND_CALL_TEMPLATE = _load_template(
    './mindspore/ccsrc/plugin/device/ascend/kernel/pyboost/template/pyboost_ascend_call_template.tpl')

PYBOOST_ASCEND_VIEW_CALL_TEMPLATE = _load_template(
    './mindspore/ccsrc/kernel/pyboost/template/'
    'pyboost_view_template.tpl')

PYBOOST_ASCEND_CUSTOMIZE_CALL_TEMPLATE = _load_template(
    './mindspore/ccsrc/plugin/device/ascend/kernel/pyboost/template'
    '/pyboost_ascend_customize_call_template.tpl')

# GPU op generate
PYBOOST_GPU_OP_HEADER_TEMPLATE = _load_template(
    './mindspore/ccsrc/plugin/device/gpu/kernel/pyboost/template/pyboost_gpu_header_template.tpl')

PYBOOST_GPU_OP_SOURCE_TEMPLATE = _load_template(
    './mindspore/ccsrc/plugin/device/gpu/kernel/pyboost/template/pyboost_gpu_source_template.tpl')

PYBOOST_GPU_CALL_TEMPLATE = _load_template(
    './mindspore/ccsrc/plugin/device/gpu/kernel/pyboost/template/pyboost_gpu_call_template.tpl')

PYBOOST_GPU_VIEW_CALL_TEMPLATE = _load_template(
    './mindspore/ccsrc/kernel/pyboost/template/pyboost_view_template.tpl')

PYBOOST_GPU_CUSTOMIZE_CALL_TEMPLATE = _load_template(
    './mindspore/ccsrc/plugin/device/gpu/kernel/pyboost/template'
    '/pyboost_gpu_customize_call_template.tpl')

# CPU op generate
PYBOOST_CPU_OP_HEADER_TEMPLATE = _load_template(
    './mindspore/ccsrc/plugin/device/cpu/kernel/pyboost/template/pyboost_cpu_header_template.tpl')

PYBOOST_CPU_OP_SOURCE_TEMPLATE = _load_template(
    './mindspore/ccsrc/plugin/device/cpu/kernel/pyboost/template/pyboost_cpu_source_template.tpl')

PYBOOST_CPU_CALL_TEMPLATE = _load_template(
    './mindspore/ccsrc/plugin/device/cpu/kernel/pyboost/template/pyboost_cpu_call_template.tpl')

PYBOOST_CPU_VIEW_CALL_TEMPLATE = _load_template(
    './mindspore/ccsrc/kernel/pyboost/template/pyboost_view_template.tpl')

PYBOOST_CPU_CUSTOMIZE_CALL_TEMPLATE = _load_template(
    './mindspore/ccsrc/plugin/device/cpu/kernel/pyboost/template'
    '/pyboost_cpu_customize_call_template.tpl')

PYBOOST_PY_FUNC_IMPORT_HEADEAR = CppTemplate(
    """from mindspore._c_expression import ${class_name}Prim_\n"""
)
# Correctly spelled alias; the misspelled name above is kept for existing callers.
PYBOOST_PY_FUNC_IMPORT_HEADER = PYBOOST_PY_FUNC_IMPORT_HEADEAR

PYBOOST_PY_FUNC_TEMPLATE = CppTemplate("""
def ${func_name}(${func_args}):
    r\"\"\"
    ${description}
    \"\"\"
    return ${func_impl_name}_impl(${input_args})\n\n""")

OP_PROTO_TEMPLATE = CppTemplate("""
${class_name}FuncImpl g${class_name}FuncImpl;
OpDef g${class_name} = {
  /*.name_=*/"${class_name}",
  /*.args_=*/ {
    ${input_args}
  },
  /* .returns_ = */ {
    ${return_args}
  },
  /*.signatures_ =*/ {
    ${signatures}
  },
  /*.indexes_ =*/ {
    ${indexes}
  },
  /*.func_impl_=*/g${class_name}FuncImpl,
  /*.enable_dispatch_ =*/${enable_dispatch},
  /*.is_view_ =*/${is_view},
};
""")