# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Generate operator definition from ops.yaml
"""
import os
import re
import shutil
import pathlib
import logging
import gen_utils
from gen_utils import (py_licence_str, cc_license_str, check_change_and_replace_file, merge_files,
                       merge_files_append, safe_load_yaml, convert_dtype_str, write_file)
from pyboost_utils import get_pyboost_name, is_pyboost_enable, AclnnUtils, get_dtypes
import template
from template import CppTemplate
from gen_pyboost_func import gen_pyboost_code
from gen_aclnn_implement import gen_aclnn_kernel

# Directories (relative to the repository root given as work_path) of the generated files.
_PY_AUTO_GENERATE_DIR = 'mindspore/python/mindspore/ops/auto_generate'
_CC_AUTO_GENERATE_DIR = 'mindspore/core/ops/auto_generate'

# ops.yaml dtype -> c++ type used by the generated lite BaseOperator setters/getters.
# Dtypes not listed here (e.g. ``float``, ``bool``) are emitted verbatim.
_LITE_OPS_CC_DTYPE_MAP = {
    "str": "std::string",
    "tuple[str]": "std::vector<std::string>",
    "list[str]": "std::vector<std::string>",
    "tuple[int]": "std::vector<int64_t>",
    "list[int]": "std::vector<int64_t>",
    "tuple[float]": "std::vector<float>",
    "list[float]": "std::vector<float>",
    "tuple[bool]": "std::vector<bool>",
    "list[bool]": "std::vector<bool>",
    "int": "int64_t",
}


def _get_op_name(yaml_key, yaml_value):
    """
    Get op name for python class Primitive or c++ OpDef name.

    Args:
        yaml_key (str): The operator key in ops.yaml, e.g. ``add_ext``.
        yaml_value (dict): The operator definition under that key.

    Returns:
        str: ``class.name`` from the yaml when it is set, otherwise ``yaml_key`` converted to
        PascalCase (``add_ext`` -> ``AddExt``).
    """
    class_def = yaml_value.get("class")
    if class_def is not None:
        class_name_specify = class_def.get("name")
        if class_name_specify is not None:
            return class_name_specify
    # Default rule: snake_case key -> PascalCase class name.
    return ''.join(word.capitalize() for word in yaml_key.split('_'))


def _get_op_func_name(yaml_key, yaml_value):
    """
    Get the python function api name of an op: ``function.name`` from the yaml if set, else ``yaml_key``.
    """
    func_def = yaml_value.get('function')
    if func_def is not None:
        func_name_specify = func_def.get("name")
        if func_name_specify is not None:
            return func_name_specify
    return yaml_key


def _is_auto_generate_disabled(yaml_value, section):
    """
    Check the ``disable`` label of ``yaml_value[section]``.

    Args:
        yaml_value (dict): The operator definition in ops.yaml.
        section (str): ``'class'`` or ``'function'``.

    Returns:
        bool: True only when ``<section>.disable`` is set to True.

    Raises:
        TypeError: If ``disable`` is present but is not a bool.
    """
    section_def = yaml_value.get(section)
    if not section_def or 'disable' not in section_def:
        return False
    disable_item = section_def.get("disable")
    if isinstance(disable_item, bool):
        return disable_item
    raise TypeError(f"The disable label for {section} should be True or False, but get {disable_item}.")


def _auto_generate_class_disabled(yaml_value):
    """Check whether class can be auto generated."""
    return _is_auto_generate_disabled(yaml_value, 'class')


def _auto_generate_func_disabled(yaml_value):
    """Check whether function can be auto generated."""
    return _is_auto_generate_disabled(yaml_value, 'function')


def _split_arg_names(arg_names_str):
    """Split a comma separated arg-name string from ops.yaml into a list, ignoring all spaces."""
    return arg_names_str.replace(' ', '').split(",")


def _get_signature_rw_lists(args_signature):
    """
    Parse the ``rw_write``/``rw_read``/``rw_ref`` items of ``args_signature``.

    Returns:
        tuple: (write_list, read_list, ref_list); a missing item (or a None ``args_signature``)
        gives an empty list.
    """
    if args_signature is None:
        return [], [], []
    rw_lists = []
    for key in ('rw_write', 'rw_read', 'rw_ref'):
        value = args_signature.get(key)
        rw_lists.append([] if value is None else _split_arg_names(value))
    return tuple(rw_lists)


def signature_get_rw_label(arg_name, write_list, read_list, ref_list):
    """
    Generate signature rw code for a python ``sig.make_sig`` call.

    The first list containing ``arg_name`` wins, checked in the order write, read, ref.
    Returns an empty string when the arg is in none of them.
    """
    if arg_name in write_list:
        return ', sig.sig_rw.RW_WRITE'
    if arg_name in read_list:
        return ', sig.sig_rw.RW_READ'
    if arg_name in ref_list:
        return ', sig.sig_rw.RW_REF'
    return ''


def signature_get_rw_label_cc(rw_op_name, write_list, read_list, ref_list):
    """
    Generate cc signature rw code, e.g. ``SignatureEnumRW::kRWWrite``.

    The LAST matching list wins (ref over read over write).
    NOTE(review): this priority is the reverse of ``signature_get_rw_label``; it only matters when
    the yaml lists an arg in several rw items — confirm which priority is intended.
    """
    rw_label = 'kRWDefault'
    if rw_op_name in write_list:
        rw_label = 'kRWWrite'
    if rw_op_name in read_list:
        rw_label = 'kRWRead'
    if rw_op_name in ref_list:
        rw_label = 'kRWRef'
    return 'SignatureEnumRW::' + rw_label


def signature_get_enum_dtype_cc(index):
    """
    Generate cc enum dtype code for a dtype-group index.

    Index 0 maps to ``kDType``, 1..9 to ``kDType1``..``kDType9``; anything else (including None for
    an arg without group) maps to ``kDTypeEmptyDefaultValue``.
    """
    type_map = {0: 'kDType',
                1: 'kDType1',
                2: 'kDType2',
                3: 'kDType3',
                4: 'kDType4',
                5: 'kDType5',
                6: 'kDType6',
                7: 'kDType7',
                8: 'kDType8',
                9: 'kDType9'}
    return 'SignatureEnumDType::' + type_map.get(index, 'kDTypeEmptyDefaultValue')


def signature_get_dtype_label(index):
    """
    Generate signature dtype code: ``dtype=sig.sig_dtype.T`` for index 0, ``...T<index>`` otherwise.
    """
    dtype_index = str(index) if index > 0 else ''
    return f"dtype=sig.sig_dtype.T{dtype_index}"


def get_same_dtype_groups(args_signature, args_name):
    """
    Assign a dtype-group index to the input args.

    ``args_signature['dtype_group']`` is a string like ``"(x, y), (z, w)"``: args listed inside the
    same parentheses share one index. Walking ``args_name`` in order, the first time an arg of a
    group is met the whole group gets the next index; an arg that is in no group gets its own index.

    Args:
        args_signature (dict): The ``args_signature`` item of the op in ops.yaml, may be None.
        args_name (list[str]): Names of the op's input args.

    Returns:
        tuple: (dict mapping arg name to group index, number of indexes assigned). Both are
        empty/0 when there is no ``dtype_group``. Group members that are not in ``args_name`` are
        put into the mapping too; callers are expected to validate them.
    """
    same_dtype_groups = {}
    dtype_count = 0
    if args_signature is None:
        return same_dtype_groups, dtype_count

    dtype_group = args_signature.get('dtype_group')
    if dtype_group is None:
        return same_dtype_groups, dtype_count

    groups = [_split_arg_names(item) for item in re.findall(r'\((.*?)\)', dtype_group)]
    for arg_name in args_name:
        if arg_name in same_dtype_groups:
            continue
        # The first group containing the arg, or the arg alone when it is in no group.
        group = next((group for group in groups if arg_name in group), [arg_name])
        for item in group:
            same_dtype_groups[item] = dtype_count
        dtype_count += 1
    return same_dtype_groups, dtype_count


def generate_py_op_signature(op_name, args_signature, args_name, args_default):
    """
    Generate the ``__mindspore_signature__`` class attribute of a python Primitive.

    Args:
        op_name (str): Primitive class name, used in error messages.
        args_signature (dict): The ``args_signature`` item of the op in ops.yaml, may be None.
        args_name (list[str]): Names of the primitive's call inputs.
        args_default (dict): Input name -> default value (as written in the yaml), may be empty/None.

    Returns:
        str: The indented class-body code, or ``''`` when there is neither a signature nor defaults.

    Raises:
        ValueError: If the signature references an arg that is not an input of the op.
    """
    def _check_signature_arg_valid(sig_arg_names):
        for sig_arg_name in sig_arg_names:
            if sig_arg_name not in args_name:
                raise ValueError(f"Op {op_name} has no input arg named '{sig_arg_name}'!")

    if args_signature is None and not args_default:
        return ''
    args_default = args_default or {}

    write_list, read_list, ref_list = _get_signature_rw_lists(args_signature)
    for rw_list in (write_list, read_list, ref_list):
        _check_signature_arg_valid(rw_list)
    same_dtype_groups, dtype_count = get_same_dtype_groups(args_signature, args_name)
    _check_signature_arg_valid(same_dtype_groups.keys())

    signature_code = "    __mindspore_signature__ = "
    # Only one dtype_group is set: a plain tuple of sig_dtype.T is enough.
    if dtype_count == 1 and not any([write_list, read_list, ref_list, args_default]):
        dtype_items = ', '.join(['sig.sig_dtype.T'] * len(args_name))
        if len(args_name) == 1:
            # A parenthesized single value is not a tuple; the signature must be iterable.
            dtype_items += ','
        return signature_code + '(' + dtype_items + ')\n\n'

    # Set sig.make_sig.
    signature_code += " (\n"
    for arg_name in args_name:
        signature_code += f"        sig.make_sig('{arg_name}'"
        signature_code += signature_get_rw_label(arg_name, write_list, read_list, ref_list)
        if arg_name in same_dtype_groups:
            signature_code += ", " + signature_get_dtype_label(same_dtype_groups[arg_name])
        if arg_name in args_default:
            signature_code += ", default=" + str(args_default[arg_name])
        signature_code += "),\n"
    signature_code += "    )\n\n"
    return signature_code


def generate_cc_op_signature(args_signature, args_name):
    """
    Generate the c++ ``Signature(...)`` entries of an OpDef.

    :param args_signature: The ``args_signature`` item of the op in ops.yaml, may be None.
    :param args_name: Names of the op's input args.
    :return: One ``Signature(...),`` entry per input arg, or ``''`` when there is no signature.
    """
    if args_signature is None:
        return ''
    write_list, read_list, ref_list = _get_signature_rw_lists(args_signature)
    same_dtype_groups, _ = get_same_dtype_groups(args_signature, args_name)
    signature_code = ''
    for arg_name in args_name:
        enum_rw = signature_get_rw_label_cc(arg_name, write_list, read_list, ref_list)
        enum_dtype = signature_get_enum_dtype_cc(same_dtype_groups.get(arg_name))
        signature_code += (f'Signature("{arg_name}", {enum_rw}, '
                           f'SignatureEnumKind::kKindPositionalKeyword, nullptr, {enum_dtype}),\n    ')
    return signature_code


def generate_py_op_deprecated(deprecated):
    """
    Generate the ``@deprecated(...)`` decorator line for a Primitive's ``__init__``.

    Args:
        deprecated (dict): The ``deprecated`` item of the op in ops.yaml, may be None.

    Returns:
        str: The decorator line, or ``''`` when ``deprecated`` is None.

    Raises:
        ValueError: If ``version``/``substitute``/``use_substitute`` is missing, or
            ``use_substitute`` is not a bool.
    """
    if deprecated is None:
        return ''
    version = deprecated.get("version")
    if version is None:
        raise ValueError("The version of deprecated can't be None.")
    substitute = deprecated.get("substitute")
    if substitute is None:
        raise ValueError("The substitute of deprecated can't be None.")
    use_substitute = deprecated.get("use_substitute")
    if use_substitute is None:
        raise ValueError("The use_substitute of deprecated can't be None.")
    if not isinstance(use_substitute, bool):
        raise ValueError(f"The use_substitute must be True or False, but got {use_substitute}")
    return f"""    @deprecated("{version}", "{substitute}", {use_substitute})\n"""


def _normalize_func_description_fromat(description):
    """
    Process description so it can be embedded in a 4-space indented docstring.

    Every non-empty line after the first is indented by 4 spaces and trailing blank lines are removed.
    """
    if not description:
        return description
    lines = description.split("\n")
    if len(lines) == 1:
        return description
    # The first line follows the docstring's own indentation; the rest need it added.
    lines = [lines[0]] + [f"    {line}" if line else "" for line in lines[1:]]
    # Remove trailing blank lines.
    while len(lines) > 1 and lines[-1] == "":
        lines.pop()
    return "\n".join(lines)


def _get_op_description(operator_name, doc_str):
    """
    Generate ops api description from ``doc_str[operator_name]['description']``.

    Prints a notice and returns ``''`` when any level of that lookup is missing.
    """
    description = None
    if doc_str is not None:
        op_doc = doc_str.get(operator_name)
        if op_doc is not None:
            description = op_doc.get("description")
    if description is None:
        print(f"Description is None, op_name: {operator_name}")
        return ""
    return _normalize_func_description_fromat(description)


def _generate_py_func_code(func_name, func_args, description, body):
    """
    Render one python function api of the generated ops def file.

    Args:
        func_name (str): Name of the python function.
        func_args (list[str]): Parameter entries, e.g. ``["x", "axis=0"]``.
        description (str): Docstring text, already normalized for a 4-space indent.
        body (str): Statement lines, each already indented by 4 spaces.
    """
    return f"""\n
def {func_name}({', '.join(func_args)}):
    r\"\"\"
    {description}
    \"\"\"
{body}\n"""


def generate_py_op_func(yaml_data, doc_data):
    """
    Generate operator python function api.

    - Ops without prim-init args call the module level ``<op>_op`` instance (which is created here
      when the class is manually defined, i.e. its auto generation is disabled).
    - Ops with prim-init args call ``<op>_impl`` when dispatch is enabled, otherwise they build a
      cached primitive with ``_get_cache_prim``.
    """
    gen_py = ''
    for operator_name, operator_data in yaml_data.items():
        if _auto_generate_func_disabled(operator_data):
            continue
        func_name = _get_op_func_name(operator_name, operator_data)
        args = operator_data.get('args')
        class_name = _get_op_name(operator_name, operator_data)
        func_args = []
        prim_init_args = []
        prim_call_args = []
        for arg_name, arg_info in args.items():
            if 'default' in arg_info:
                func_args.append(f"{arg_name}={arg_info.get('default')}")
            else:
                func_args.append(arg_name)
            if arg_info.get('prim_init'):
                prim_init_args.append(arg_name)
            else:
                prim_call_args.append(arg_name)
        description = _get_op_description(operator_name, doc_data)
        call_args_str = ', '.join(prim_call_args)

        if not prim_init_args:
            if _auto_generate_class_disabled(operator_data):
                gen_py += f"\n{operator_name}_op={class_name}()"
            body = f"    return {operator_name}_op({call_args_str})"
        else:
            dispatch = operator_data.get("dispatch")
            if dispatch is not None and dispatch.get("enable"):
                body = f"    return {operator_name}_impl({', '.join(args)})"
            else:
                body = (f"    {operator_name}_op = _get_cache_prim({class_name})({', '.join(prim_init_args)})\n"
                        f"    return {operator_name}_op({call_args_str})")
        gen_py += _generate_py_func_code(func_name, func_args, description, body)
    return gen_py


def get_dtype(arg_info):
    """Get the yaml dtype of an arg; ``TypeId`` is reported as ``int``."""
    dtype = arg_info.get('dtype')
    # Currently, TypeId is represented by int
    if dtype == 'TypeId':
        dtype = 'int'
    return dtype


def process_args(class_name, args):
    """
    Process arg for yaml, get arg_name, init value, type cast, arg_handler, etc.

    Returns:
        tuple: (inputs_name, inputs_default, args_name, args_assign, init_args_with_default,
        args_handlers) where ``inputs_*``/``args_handlers`` describe the primitive call inputs and
        ``args_name``/``args_assign``/``init_args_with_default`` describe the prim-init args.
    """
    inputs_name = []
    args_name = []
    args_assign = []
    inputs_default = {}
    init_args_with_default = []
    args_handlers = {}
    for arg_name, arg_info in args.items():
        dtype = get_dtype(arg_info)
        default_value = arg_info.get('default')
        has_default = 'default' in arg_info
        arg_handler = arg_info.get('arg_handler')

        if arg_info.get('prim_init'):
            # Primitive __init__ arg: stored on the primitive through _set_prim_arg*.
            args_name.append(arg_name)
            init_args_with_default.append(f"{arg_name}={default_value}" if has_default else arg_name)
            assign_str = gen_utils.get_assign_str_by_type_it(class_name, arg_info, arg_name, dtype)
            if arg_handler:
                assign_str = f'        self._set_prim_arg_with_handler("{arg_name}", {assign_str}, {arg_handler})'
            else:
                assign_str = f'        self._set_prim_arg("{arg_name}", {assign_str})'
            args_assign.append(assign_str)
        else:
            # Primitive __call__ input.
            inputs_name.append(arg_name)
            if has_default:
                inputs_default[arg_name] = default_value
            if arg_handler:
                args_handlers[arg_name] = arg_handler

    return inputs_name, inputs_default, args_name, args_assign, init_args_with_default, args_handlers


def generate_pyboost_import_header(yaml_data):
    """
    Generate the ``from mindspore._c_expression import ...`` lines of pyboost-enabled ops.
    """
    import_pyboost = CppTemplate("from mindspore._c_expression import $var\n")
    return ''.join(import_pyboost.replace(var=get_pyboost_name(operator_name))
                   for operator_name, operator_data in yaml_data.items()
                   if is_pyboost_enable(operator_data))


def _generate_class_description(class_name, func_name, input_args, init_args, func_disabled, doc_str):
    """Generate description for every primitive definition."""
    if func_disabled:
        # If function disabled, the primitive carries the op description itself.
        description = _get_op_description(func_name, doc_str)
        return f"""    r\"\"\"
    {description}
    \"\"\"
"""

    # If function is a released API, refer to the function doc.
    return f"""    r\"\"\"
    .. code-block::

        prim = ops.{class_name}({', '.join(init_args)})
        out = prim({', '.join(input_args)})

    is equivalent to

    .. code-block::

        ops.{func_name}({', '.join(input_args + init_args)})

    Refer to :func:`mindspore.ops.{func_name}` for more details.
    \"\"\"
"""


def get_init_code(init_code, operator_data):
    """
    Generate init code for primitive: ``init_code`` plus one ``add_prim_attr`` per yaml label,
    or ``pass`` when both are empty.
    """
    labels = operator_data.get('labels')
    if labels is not None:
        label_code = '\n'.join(f'        self.add_prim_attr("{key}", {value})' for key, value in labels.items())
        init_code = f"{init_code}\n{label_code}" if init_code else label_code
    return init_code if init_code else "        pass"


def generate_py_primitive(yaml_data, doc_str):
    """
    Generate python primitive classes (and ``<op>_op`` instances for ops without init args).
    """

    def _generate_arg_handler(class_name, arg, arg_handler, is_optional):
        """Generate arg_handler call; optional args skip the handler when they are None."""
        arg_handler_call = f"{arg_handler}('{class_name}', '{arg}', {arg})"
        if is_optional:
            arg_handler_call = f"{arg} if {arg} is None else {arg_handler_call}"
        return arg_handler_call

    gen_py = ''
    for operator_name, operator_data in yaml_data.items():
        if _auto_generate_class_disabled(operator_data):
            continue
        class_name = _get_op_name(operator_name, operator_data)
        func_name = _get_op_func_name(operator_name, operator_data)
        pyboost_func_name = get_pyboost_name(operator_name)
        args = operator_data.get('args')
        inputs_args, inputs_default, init_args, args_assign, init_args_with_default, args_handlers = \
            process_args(class_name, args)
        signature_code = generate_py_op_signature(class_name, operator_data.get('args_signature'), inputs_args,
                                                  inputs_default)
        deprecated_code = generate_py_op_deprecated(operator_data.get('deprecated'))
        init_code = get_init_code('\n'.join(args_assign), operator_data)
        func_disabled = _auto_generate_func_disabled(operator_data)

        primitive_code = f"""\n
class {class_name}(Primitive):\n"""
        primitive_code += _generate_class_description(class_name, func_name, inputs_args, init_args, func_disabled,
                                                      doc_str)
        primitive_code += signature_code
        primitive_code += deprecated_code

        init_params = ', '.join(['self'] + init_args_with_default)
        call_params = ', '.join(['self'] + [f"{name}={inputs_default[name]}" if name in inputs_default else name
                                            for name in inputs_args])
        primitive_code += f"""    @prim_arg_register
    def __init__({init_params}):
{init_code}

    def __call__({call_params}):"""

        # Inputs (through their arg handlers) first, then the stored init args. Building one list
        # avoids emitting a leading ", " when the op has init args but no inputs.
        forward_args = []
        for arg in inputs_args:
            if arg in args_handlers:
                is_optional = inputs_default.get(arg) == "None"
                forward_args.append(_generate_arg_handler(class_name, arg, args_handlers[arg], is_optional))
            else:
                forward_args.append(arg)
        forward_args.extend(f'self.{arg}' for arg in init_args)
        forward_args_str = ', '.join(forward_args)

        if is_pyboost_enable(operator_data):
            primitive_code += f"""
        return _convert_stub({pyboost_func_name}(self, [{forward_args_str}]))"""
        else:
            primitive_code += f"""
        return super().__call__({forward_args_str})
"""
        gen_py += primitive_code
        if not init_args:
            gen_py += f"""\n
{operator_name}_op={class_name}()
"""
    return gen_py


def generate_op_name_opdef(yaml_data):
    """
    Generate op name header: one ``constexpr auto kName<Op> = "<Op>";`` per op.
    """
    op_name_head = """
#ifndef MINDSPORE_CORE_OP_NAME_H_
#define MINDSPORE_CORE_OP_NAME_H_

namespace mindspore::ops {
"""

    op_name_end = """}  // namespace mindspore::ops

#endif  // MINDSPORE_CORE_OP_NAME_H_
"""

    op_name_gen = op_name_head
    for operator_name, operator_data in yaml_data.items():
        k_name_op = _get_op_name(operator_name, operator_data)
        op_name_gen += f'constexpr auto kName{k_name_op} = "{k_name_op}";\n'
    op_name_gen += op_name_end
    return op_name_gen


def generate_op_prim_opdef(yaml_data):
    """
    Generate primitive c++ definition: one ``GVAR_DEF(PrimitivePtr, kPrim<Op>, ...)`` per op.
    """
    ops_prim_head = """
#ifndef MINDSPORE_CORE_OPS_GEN_OPS_PRIMITIVE_H_
#define MINDSPORE_CORE_OPS_GEN_OPS_PRIMITIVE_H_

#include <memory>
#include "ir/anf.h"
#include "ir/primitive.h"
#include "ops/auto_generate/gen_ops_name.h"
#include "mindapi/base/macros.h"

namespace mindspore::prim {
"""

    ops_prim_end = """}  // namespace mindspore::prim
#endif  // MINDSPORE_CORE_OPS_GEN_OPS_PRIMITIVE_H_
"""

    ops_prim_gen = ops_prim_head
    for operator_name, operator_data in yaml_data.items():
        k_name_op = _get_op_name(operator_name, operator_data)
        ops_prim_gen += (f"GVAR_DEF(PrimitivePtr, kPrim{k_name_op}, "
                         f"std::make_shared<Primitive>(ops::kName{k_name_op}))\n")
    ops_prim_gen += ops_prim_end
    return ops_prim_gen


def generate_lite_ops(yaml_data):
    """
    Generate BaseOperator parameter set and get func.

    Returns:
        tuple: (header code, source code) of the lite ops; each prim-init arg gets a
        ``set_<arg>``/``get_<arg>`` pair backed by the primitive attribute of the same name.
    """
    lite_ops_h_head = """
#ifndef MINDSPORE_CORE_OPS_GEN_LITE_OPS_H_
#define MINDSPORE_CORE_OPS_GEN_LITE_OPS_H_

#include <string>
#include <vector>
#include "ops/base_operator.h"
#include "ops/auto_generate/gen_ops_name.h"

namespace mindspore::ops {
"""

    lite_ops_h_end = """}  // namespace mindspore::ops
#endif  // MINDSPORE_CORE_OPS_GEN_LITE_OPS_H_
"""

    lite_ops_cc_head = """
#include "ops/auto_generate/gen_lite_ops.h"
#include "mindapi/src/helper.h"
#include "ops/primitive_c.h"
#include "ops/base_operator.h"
#include "abstract/abstract_value.h"

namespace mindspore::ops {
"""

    lite_ops_cc_end = """}  // namespace mindspore::ops
"""

    lite_ops_h_gen = lite_ops_h_head
    lite_ops_cc_gen = lite_ops_cc_head
    for operator_name, operator_data in yaml_data.items():
        op_name = _get_op_name(operator_name, operator_data)
        lite_ops_h_gen += f"""class MIND_API {op_name} : public BaseOperator {{
 public:
  MIND_API_BASE_MEMBER({op_name});
  {op_name}() : BaseOperator(kName{op_name}) {{}}\n"""
        for arg_name, arg_info in operator_data.get('args').items():
            if not arg_info.get('prim_init'):
                continue
            dtype = get_dtype(arg_info)
            dtype = _LITE_OPS_CC_DTYPE_MAP.get(dtype, dtype)
            lite_ops_h_gen += f"  void set_{arg_name}(const {dtype} &{arg_name});\n"
            lite_ops_h_gen += f"  {dtype} get_{arg_name}() const;\n"

            lite_ops_cc_gen += (f'void {op_name}::set_{arg_name}(const {dtype} &{arg_name}) '
                                f'{{ (void)this->AddAttr("{arg_name}", api::MakeValue({arg_name})); }}\n\n')
            lite_ops_cc_gen += (f'{dtype} {op_name}::get_{arg_name}() const '
                                f'{{ return GetValue<{dtype}>(GetAttr("{arg_name}")); }}\n\n')

        lite_ops_cc_gen += f"REGISTER_PRIMITIVE_C(kName{op_name}, {op_name});\n"
        lite_ops_cc_gen += f"MIND_API_OPERATOR_IMPL({op_name}, BaseOperator);\n\n"
        lite_ops_h_gen += "};\n\n"
    lite_ops_h_gen += lite_ops_h_end
    lite_ops_cc_gen += lite_ops_cc_end
    return lite_ops_h_gen, lite_ops_cc_gen


def _to_cc_dtype_enum(dtype):
    """Map a yaml dtype string to its c++ enum name, e.g. ``tuple[int]`` -> ``DT_TUPLE_INT``."""
    return 'DT_' + dtype.replace('[', '_').replace(']', '').upper()


def _generate_cc_input_args(args):
    """
    Build the OpDef input arg entries and the arg-name -> index entries of one op.

    Returns:
        tuple: (input_args_str, cc_index_str, args_dict) where ``args_dict`` maps arg name to index.
    """
    input_args_str = ''
    cc_index_str = ''
    args_dict = {}
    for i, (arg_name, arg_info) in enumerate(args.items()):
        args_dict[arg_name] = i
        cc_index_str += f"""{{"{arg_name}", {i}}},\n"""
        cc_dtype_str = convert_dtype_str(get_dtype(arg_info))

        is_prim_init = 1 if arg_info.get('prim_init') else 0
        arg_handler = arg_info.get('arg_handler')
        arg_handler_str = "" if arg_handler is None else arg_handler

        type_cast = arg_info.get('type_cast')
        type_cast_str = "" if type_cast is None else \
            ', '.join(_to_cc_dtype_enum(cast_type.strip()) for cast_type in type_cast.split(","))

        # default: None is regarded as an optional argument.
        is_optional_str = "true" if arg_info.get('default') == "None" else "false"

        input_args_str += f"""\n  {{/*.arg_name_=*/"{arg_name}", /*.arg_dtype_=*/{cc_dtype_str}, """ + \
            f"""/*.as_init_arg_=*/{is_prim_init}, /*.arg_handler_=*/"{arg_handler_str}", """ + \
            f"""/*.cast_dtype_ =*/{{{type_cast_str}}}, /*.is_optional_=*/{is_optional_str}}},"""
    return input_args_str, cc_index_str, args_dict


def _generate_cc_return_args(op_name, returns, args_dict):
    """
    Build the OpDef output entries of one op.

    Raises:
        ValueError: If an ``inplace`` return refers to an arg the op does not have (this used to
            emit ``None`` into the generated c++ code).
    """
    return_args_str = ''
    for return_name, return_info in returns.items():
        ref_name = return_info.get('inplace')
        if ref_name is None:
            ref_index = -1
        elif ref_name in args_dict:
            ref_index = args_dict[ref_name]
        else:
            raise ValueError(f"Op {op_name} output '{return_name}' is inplace of '{ref_name}', "
                             f"but op has no input arg named '{ref_name}'!")
        cc_return_type_str = _to_cc_dtype_enum(return_info.get('dtype'))
        return_args_str += f"""{{/*.arg_name_=*/"{return_name}", /*.arg_dtype_=*/{cc_return_type_str},
    /*.inplace_input_index_=*/{ref_index}}},\n"""
    return return_args_str


def generate_cc_opdef(yaml_data):
    """
    Generate c++ OpDef: includes, one OpDef per op from ``template.OP_PROTO_TEMPLATE`` and the
    ``gOpDefTable`` name -> OpDef map.
    """
    gen_cc_code = "\n\nnamespace mindspore::ops {"
    gen_opdef_map = "\nstd::unordered_map<std::string, OpDefPtr> gOpDefTable = {"
    gen_include = '\n\n#include "ops/auto_generate/gen_ops_def.h"'
    gen_include += '\n#include "mindspore/core/ir/signature.h"'

    for operator_name, operator_data in yaml_data.items():
        args = operator_data.get('args')
        class_name = _get_op_name(operator_name, operator_data)
        inputs_args, _, _, _, _, _ = process_args(class_name, args)
        signature_code = generate_cc_op_signature(operator_data.get('args_signature'), inputs_args)
        returns = operator_data.get('returns')
        dispatch = operator_data.get("dispatch")
        # dispatch not defined in yaml or dispatch.enable==False
        enable_dispatch_str = "true" if dispatch and dispatch.get("enable") else "false"
        is_view_str = "true" if operator_data.get('view') else "false"

        gen_include += f'\n#include "ops/ops_func_impl/{operator_name}.h"'
        gen_opdef_map += f"""\n  {{"{class_name}", &g{class_name}}},"""

        input_args_str, cc_index_str, args_dict = _generate_cc_input_args(args)
        # Process outputs.
        return_args_str = _generate_cc_return_args(class_name, returns, args_dict)

        op_def_cc = template.OP_PROTO_TEMPLATE.replace(class_name=class_name, input_args=input_args_str,
                                                       return_args=return_args_str, signatures=signature_code,
                                                       indexes=cc_index_str, enable_dispatch=enable_dispatch_str,
                                                       is_view=is_view_str)
        gen_cc_code += op_def_cc
    gen_opdef_map += "\n};"
    gen_cc_code += gen_opdef_map

    cc_opdef_end = "\n}  // namespace mindspore::ops\n"
    return gen_include + gen_cc_code + cc_opdef_end


ops_py_prim_header = """
\"\"\"Operators definition generated by gen_ops.py, includes primitive classes.\"\"\"

from mindspore.ops.primitive import Primitive, prim_arg_register
from mindspore.ops import signature as sig
from mindspore.common import dtype as mstype
from mindspore.common._decorator import deprecated
from mindspore.ops._primitive_cache import _get_cache_prim
from mindspore.ops.auto_generate.gen_arg_dtype_cast import type_it
from mindspore.ops.auto_generate.gen_arg_handler import *
from mindspore._c_expression import OpDtype
from mindspore.common._stub_tensor import _convert_stub
"""


ops_py_def_header = """
\"\"\"Operators definition generated by gen_ops.py, includes functions.\"\"\"

from .gen_ops_prim import *
from .pyboost_inner_prim import *
from mindspore.ops.operations.manually_defined.ops_def import *
from mindspore.ops._primitive_cache import _get_cache_prim
"""


def _write_generated_file(dst_dir, file_name, content):
    """
    Write ``content`` to ``dst_dir/tmp_<file_name>`` and let ``check_change_and_replace_file``
    update ``dst_dir/<file_name>`` from it.
    """
    file_path = os.path.join(dst_dir, file_name)
    tmp_file_path = os.path.join(dst_dir, f'tmp_{file_name}')
    write_file(tmp_file_path, content)
    check_change_and_replace_file(file_path, tmp_file_path)


def generate_ops_prim_file(work_path, yaml_str, doc_str, file_pre):
    """Generate ``<file_pre>_ops_prim.py`` (python primitive classes)."""
    pyboost_import_header = generate_pyboost_import_header(yaml_str)
    py_prim = generate_py_primitive(yaml_str, doc_str)
    _write_generated_file(os.path.join(work_path, _PY_AUTO_GENERATE_DIR), f'{file_pre}_ops_prim.py',
                          py_licence_str + ops_py_prim_header + pyboost_import_header + py_prim)


def generate_ops_def_file(work_path, yaml_str, doc_str, file_pre):
    """Generate ``<file_pre>_ops_def.py`` (python function apis)."""
    py_func = generate_py_op_func(yaml_str, doc_str)
    _write_generated_file(os.path.join(work_path, _PY_AUTO_GENERATE_DIR), f'{file_pre}_ops_def.py',
                          py_licence_str + ops_py_def_header + py_func)


def generate_ops_py_files(work_path, yaml_str, doc_str, file_pre):
    """
    Generate ops python file from yaml.
    """
    generate_ops_prim_file(work_path, yaml_str, doc_str, file_pre)
    generate_ops_def_file(work_path, yaml_str, doc_str, file_pre)


def generate_ops_cc_files(work_path, yaml_str):
    """
    Generate ops c++ file from yaml: gen_ops_def.cc, gen_ops_primitive.h, gen_lite_ops.h/.cc and
    gen_ops_name.h.
    """
    cc_dir = os.path.join(work_path, _CC_AUTO_GENERATE_DIR)

    # ops_def
    _write_generated_file(cc_dir, 'gen_ops_def.cc', cc_license_str + generate_cc_opdef(yaml_str))

    # ops_primitive
    _write_generated_file(cc_dir, 'gen_ops_primitive.h', cc_license_str + generate_op_prim_opdef(yaml_str))

    # lite_h_ops and lite_cc_ops
    lite_ops_h_code, lite_ops_cc_code = generate_lite_ops(yaml_str)
    _write_generated_file(cc_dir, 'gen_lite_ops.h', cc_license_str + lite_ops_h_code)
    _write_generated_file(cc_dir, 'gen_lite_ops.cc', cc_license_str + lite_ops_cc_code)

    # ops_names
    _write_generated_file(cc_dir, 'gen_ops_name.h', cc_license_str + generate_op_name_opdef(yaml_str))


def generate_op_labels(yaml_data):
    """
    Generate the python source of the ``op_labels`` dict.

    Args:
        yaml_data (dict): Parsed ops yaml, mapping operator name to its definition dict.

    Returns:
        str: Source text ``op_labels = {...}`` with one entry per operator that has a ``labels`` item,
        keyed by the primitive class name returned by ``_get_op_name``.
    """
    gen_label_py = "op_labels = {"
    for operator_name, operator_data in yaml_data.items():
        labels = operator_data.get('labels')
        if labels is None:
            continue
        class_name = _get_op_name(operator_name, operator_data)
        # Label values are written verbatim via str(), i.e. as python expressions in the generated file.
        entries = ", ".join(f'"{key}": {value}' for key, value in labels.items())
        gen_label_py += f'\n    "{class_name}": {{{entries}}},'
    gen_label_py += "\n}"
    return gen_label_py


def generate_op_arg_default_value(yaml_data):
    """
    Generate the python source of the ``op_args_default_value`` dict.

    Args:
        yaml_data (dict): Parsed ops yaml, mapping operator name to its definition dict.

    Returns:
        str: A module docstring, an ``mstype`` import and ``op_args_default_value = {...}`` with one entry
        per operator that has at least one arg whose ``default`` is not None, keyed by the primitive class name.
    """
    default_py_header = ('"""Operator labels and args default value."""\n'
                         'from mindspore.common import dtype as mstype\n\n')
    gen_default_py = default_py_header + "op_args_default_value = {"
    for operator_name, operator_data in yaml_data.items():
        # An operator with a missing or empty ``args`` item simply has no defaults (previously crashed).
        args = operator_data.get('args') or {}
        arg_default_dict = {}
        for arg_name, arg_info in args.items():
            arg_default = arg_info.get('default')
            if arg_default is not None:
                arg_default_dict[arg_name] = arg_default
        if not arg_default_dict:
            continue
        class_name = _get_op_name(operator_name, operator_data)
        # Defaults are written verbatim via str(); the ``mstype`` import in the header presumably
        # serves defaults that reference it.
        entries = ", ".join(f'"{key}": {value}' for key, value in arg_default_dict.items())
        gen_default_py += f'\n    "{class_name}": {{{entries}}},'
    gen_default_py += "\n}"
    return gen_default_py


def generate_create_instance_helper_file(work_path, yaml_str):
    """
    Generate ``cpp_create_prim_instance_helper.py`` (arg default values and op labels) from yaml.

    The content is first written to a tmp file, then ``check_change_and_replace_file`` decides whether
    the real file is replaced.

    Args:
        work_path (str): Repository root.
        yaml_str (dict): Parsed ops yaml.
    """
    dst_dir = os.path.join(work_path, 'mindspore/python/mindspore/ops/auto_generate')
    op_py_path = os.path.join(dst_dir, 'cpp_create_prim_instance_helper.py')
    tmp_op_py_path = os.path.join(dst_dir, 'tmp_cpp_create_prim_instance_helper.py')
    py_labels = generate_op_labels(yaml_str)
    py_arg_default = generate_op_arg_default_value(yaml_str)
    write_file(tmp_op_py_path, py_licence_str + "\n" + py_arg_default + "\n\n" + py_labels + "\n")
    check_change_and_replace_file(op_py_path, tmp_op_py_path)


def generate_aclnn_reg_code(yaml_data):
    """
    Generate the C++ source registering the common aclnn KernelMods.

    An operator is skipped unless ``dispatch.enable`` is truthy and no ``dispatch.Ascend`` KernelMod is given.
    Operators for which ``get_dtypes`` reports ``none_tensor_exist`` are handed to ``gen_aclnn_kernel``
    instead of getting a factory registration line.

    Args:
        yaml_data (dict): Parsed ops yaml, mapping operator name to its definition dict.

    Returns:
        str: C++ source with one ``MS_ACLNN_COMMON_KERNEL_FACTORY_REG`` line per remaining operator.
    """
    # gen_aclnn_kernel is given a freshly loaded ops.yaml rather than ``yaml_data``.
    current_path = os.path.dirname(os.path.abspath(__file__))
    work_path = os.path.join(current_path, '../../../../')
    ops_yaml_path = os.path.join(work_path, 'mindspore/python/mindspore/ops_generate/ops.yaml')
    yaml_str = gen_utils.safe_load_yaml(ops_yaml_path)

    reg_code = """
#include "plugin/device/ascend/kernel/opapi/aclnn_kernel_mod.h"

namespace mindspore {
namespace kernel {
"""
    for operator_name, operator_data in yaml_data.items():
        dispatch = operator_data.get("dispatch")
        if not dispatch or not dispatch.get("enable"):
            continue
        if dispatch.get("Ascend") is not None:  # KernelMod is provided by yaml, don't auto generate it.
            continue
        _, _, none_tensor_exist = get_dtypes(operator_data)
        if none_tensor_exist:
            gen_aclnn_kernel(operator_name, yaml_str, auto=True)
            continue
        class_name = _get_op_name(operator_name, operator_data)
        # The registration macro takes the total count of inputs (args) and outputs (returns).
        inputs_outputs_num = len(operator_data.get("args")) + len(operator_data.get("returns"))
        aclnn_name = AclnnUtils.get_aclnn_interface(class_name)
        reg_code += f"""
MS_ACLNN_COMMON_KERNEL_FACTORY_REG({class_name}, {aclnn_name}, {inputs_outputs_num});"""
    reg_code += """
} // namespace kernel
} // namespace mindspore
"""
    return reg_code


def generate_aclnn_reg_file(work_path, yaml_str):
    """
    Generate the aclnn KernelMod register file ``aclnn_kernel_register_auto.cc``.

    Args:
        work_path (str): Repository root. Joined with os.path.join, so a trailing separator is optional.
        yaml_str (dict): Parsed ops yaml.
    """
    opapi_dir = os.path.join(work_path, 'mindspore/ccsrc/plugin/device/ascend/kernel/opapi')
    tmp_register_file = os.path.join(opapi_dir, 'tmp_aclnn_kernel_register.cc')
    register_file = os.path.join(opapi_dir, 'aclnn_kernel_register_auto.cc')
    reg_code = generate_aclnn_reg_code(yaml_str)
    write_file(tmp_register_file, cc_license_str + reg_code)
    check_change_and_replace_file(register_file, tmp_register_file)


def generate_arg_handler_files(work_path):
    """
    Copy ``arg_handler.py`` and ``arg_dtype_cast.py`` into ``ops/auto_generate`` as generated files.

    Each source is copied to a tmp file first; ``check_change_and_replace_file`` then decides whether
    the destination is replaced.

    Args:
        work_path (str): Repository root.
    """
    dst_dir = os.path.join(work_path, 'mindspore/python/mindspore/ops/auto_generate')
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(dst_dir, exist_ok=True)

    src_arg_handler_path = os.path.join(work_path, 'mindspore/python/mindspore/ops_generate/arg_handler.py')
    dst_arg_handler_path = os.path.join(dst_dir, 'gen_arg_handler.py')
    tmp_dst_arg_handler_path = os.path.join(dst_dir, 'tmp_gen_arg_handler.py')
    shutil.copy(src_arg_handler_path, tmp_dst_arg_handler_path)
    check_change_and_replace_file(dst_arg_handler_path, tmp_dst_arg_handler_path)

    src_arg_dtype_cast_path = os.path.join(work_path, 'mindspore/python/mindspore/ops_generate/arg_dtype_cast.py')
    dst_arg_dtype_cast_path = os.path.join(dst_dir, 'gen_arg_dtype_cast.py')
    tmp_arg_dtype_cast_path = os.path.join(dst_dir, 'tmp_arg_dtype_cast.py')
    shutil.copy(src_arg_dtype_cast_path, tmp_arg_dtype_cast_path)
    check_change_and_replace_file(dst_arg_dtype_cast_path, tmp_arg_dtype_cast_path)


def main():
    """
    Run the whole generation pipeline.

    Steps: merge the per-op yaml files into ops.yaml / ops_doc.yaml, then generate the arg handler copies,
    the python op files, the C++ op files, the create-instance helper, the pyboost code and the aclnn
    KernelMod register.
    """
    current_path = os.path.dirname(os.path.abspath(__file__))
    work_path = os.path.join(current_path, '../../../../')

    # merge ops yaml
    ops_yaml_path = os.path.join(work_path, 'mindspore/python/mindspore/ops_generate/ops.yaml')
    doc_yaml_path = os.path.join(work_path, 'mindspore/python/mindspore/ops_generate/ops_doc.yaml')

    ops_yaml_dir_path = os.path.join(work_path, 'mindspore/core/ops/ops_def/')
    infer_ops_yaml_dir_path = os.path.join(work_path, 'mindspore/core/ops/ops_def/infer/')
    doc_yaml_dir_path = os.path.join(work_path, 'mindspore/core/ops/ops_def/doc/')
    merge_files(ops_yaml_dir_path, ops_yaml_path, '*op.yaml')
    merge_files_append(infer_ops_yaml_dir_path, ops_yaml_path, '*op.yaml')
    merge_files(doc_yaml_dir_path, doc_yaml_path, '*doc.yaml')

    # make auto_generate dir
    cc_path = os.path.join(work_path, 'mindspore/core/ops/auto_generate/')
    pathlib.Path(cc_path).mkdir(parents=True, exist_ok=True)

    # generate arg_handler files
    generate_arg_handler_files(work_path)

    # read ops definition str and doc str
    ops_yaml_str = safe_load_yaml(ops_yaml_path)
    doc_yaml_str = safe_load_yaml(doc_yaml_path)

    # generate ops python files
    generate_ops_py_files(work_path, ops_yaml_str, doc_yaml_str, "gen")

    # generate ops c++ files
    generate_ops_cc_files(work_path, ops_yaml_str)
    # generate create prim instance helper file
    generate_create_instance_helper_file(work_path, ops_yaml_str)
    # generate pyboost code
    gen_pyboost_code(work_path, ops_yaml_str, doc_yaml_str)
    # generate aclnn kernelmod register
    generate_aclnn_reg_file(work_path, ops_yaml_str)


if __name__ == "__main__":
    try:
        main()
    # pylint: disable=broad-except
    except Exception as e:
        logging.critical("Auto generate failed, err info: %s", e)
        # Re-raise so the process exits non-zero and the build stops instead of using stale generated files.
        raise