#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# Script for updating SPIR-V dialect by scraping information from SPIR-V
# HTML and JSON specs from the Internet.
#
# For example, to define the enum attribute for SPIR-V memory model:
#
# ./gen_spirv_dialect.py --base-td-path /path/to/SPIRVBase.td \
#                        --new-enum MemoryModel
#
# The 'operand_kinds' dict of spirv.core.grammar.json contains all supported
# SPIR-V enum classes.
#
# NOTE(review): leading whitespace inside the TableGen snippets emitted below
# is kept exactly as found in this file; confirm that it matches the
# indentation used by the checked-in .td files.

import itertools
import re
import textwrap

# Third-party packages (requests, bs4, yaml) are imported inside the functions
# that need them, so the pure helpers in this file stay importable without
# them and PyYAML is only required for the coverage report.

SPIRV_HTML_SPEC_URL = 'https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html'
SPIRV_JSON_SPEC_URL = 'https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/spirv.core.grammar.json'

AUTOGEN_OP_DEF_SEPARATOR = '\n// -----\n\n'
AUTOGEN_ENUM_SECTION_MARKER = 'enum section. Generated from SPIR-V spec; DO NOT MODIFY!'
AUTOGEN_OPCODE_SECTION_MARKER = (
    'opcode section. Generated from SPIR-V spec; DO NOT MODIFY!')


def get_spirv_doc_from_html_spec():
    """Extracts instruction documentation from SPIR-V HTML spec.

    Returns:
      - A dict mapping from instruction opcode to documentation.
    """
    import requests
    from bs4 import BeautifulSoup

    response = requests.get(SPIRV_HTML_SPEC_URL)
    # Fail loudly on an HTTP error instead of scraping an error page.
    response.raise_for_status()
    spirv = BeautifulSoup(response.content, 'html.parser')

    section_anchor = spirv.find('h3', {'id': '_a_id_instructions_a_instructions'})

    doc = {}

    for section in section_anchor.parent.find_all('div', {'class': 'sect3'}):
        for table in section.find_all('table'):
            inst_html = table.tbody.tr.td.p
            opname = inst_html.a['id']
            # Ignore the first line, which is just the opname.
            doc[opname] = inst_html.text.split('\n', 1)[1].strip()

    return doc


def get_spirv_grammar_from_json_spec():
    """Extracts operand kind and instruction grammar from SPIR-V JSON spec.

    Returns:
      - A list containing all operand kinds' grammar
      - A list containing all instructions' grammar
    """
    import json
    import requests

    response = requests.get(SPIRV_JSON_SPEC_URL)
    # Fail loudly on an HTTP error instead of parsing an error page as JSON.
    response.raise_for_status()
    spirv = json.loads(response.content)

    return spirv['operand_kinds'], spirv['instructions']


def split_list_into_sublists(items):
    """Split the list of items into multiple sublists.

    This is to make sure the string composed from each sublist won't exceed
    80 characters. An item that is longer than the limit on its own gets a
    sublist to itself.

    Arguments:
      - items: a list of strings

    Returns:
      - A list of non-empty sublists that together hold all items in order.
    """
    chunks = []
    chunk = []
    chunk_len = 0

    for item in items:
        # Each item is accounted for together with its ', ' separator.
        chunk_len += len(item) + 2
        # Only flush a non-empty chunk; otherwise an over-long first item
        # would produce an empty leading sublist (and an empty output line).
        if chunk_len > 80 and chunk:
            chunks.append(chunk)
            chunk = []
            chunk_len = len(item) + 2
        chunk.append(item)

    if chunk:
        chunks.append(chunk)

    return chunks


def uniquify_enum_cases(lst):
    """Prunes duplicate enum cases from the list.

    Arguments:
      - lst: List whose elements are to be uniqued. Each element is a
        (symbol, value) pair. The list itself is not modified.

    Returns:
      - A list with all duplicates removed. The elements are sorted according
        to value and, for each value, uniqued according to symbol.
      - A map from each dropped duplicate symbol to the symbol that was kept.
    """
    uniqued_cases = []
    duplicated_cases = {}

    # First sort according to the value. sorted() is used instead of
    # list.sort() so that the caller's list is left untouched.
    cases = sorted(lst, key=lambda x: x[1])

    # Then group them according to the value
    for _, groups in itertools.groupby(cases, key=lambda x: x[1]):
        # For each value, sort according to the enumerant symbol.
        sorted_group = sorted(groups, key=lambda x: x[0])
        # Keep the "smallest" case, which is typically the symbol without
        # extension suffix. But we have special cases that we want to fix.
        case = sorted_group[0]
        if case[0] == 'HlslSemanticGOOGLE':
            assert len(sorted_group) == 2, 'unexpected new variant for HlslSemantic'
            case = sorted_group[1]
        # Map every other symbol in the group to the one that is kept; the
        # kept symbol itself must never show up as a duplicate.
        for other in sorted_group:
            if other[0] != case[0]:
                duplicated_cases[other[0]] = case[0]
        uniqued_cases.append(case)

    return uniqued_cases, duplicated_cases


def toposort(dag, sort_fn):
    """Topologically sorts the given dag.

    Arguments:
      - dag: a dict mapping from a node to its incoming nodes.
      - sort_fn: a function for sorting nodes in the same batch.

    Returns:
      A list containing topologically sorted nodes.
    """

    # Returns the next batch of nodes without incoming edges
    def get_next_batch(dag):
        while True:
            no_prev_nodes = set(node for node, prev in dag.items() if not prev)
            if not no_prev_nodes:
                break
            yield sorted(no_prev_nodes, key=sort_fn)
            # Drop the emitted nodes and the edges coming out of them.
            dag = {
                node: (prev - no_prev_nodes)
                for node, prev in dag.items()
                if node not in no_prev_nodes
            }
        # Whatever is left still has incoming edges that can never be
        # satisfied.
        assert not dag, 'found cyclic dependency'

    sorted_nodes = []
    for batch in get_next_batch(dag):
        sorted_nodes.extend(batch)

    return sorted_nodes


def toposort_capabilities(all_cases, capability_mapping):
    """Returns topologically sorted capability (symbol, value) pairs.

    Arguments:
      - all_cases: all capability cases (containing symbol, value, and implied
        capabilities).
      - capability_mapping: mapping from duplicated capability symbols to the
        canonicalized symbol chosen for SPIRVBase.td.

    Returns:
      A list containing topologically sorted capability (symbol, value) pairs.
    """
    dag = {}
    name_to_value = {}
    for case in all_cases:
        # Get the current capability.
        cur = case['enumerant']
        name_to_value[cur] = case['value']
        # Ignore duplicated symbols.
        if cur in capability_mapping:
            continue

        # Get capabilities implied by the current capability.
        prev = case.get('capabilities', [])
        dag[cur] = {capability_mapping.get(c, c) for c in prev}

    sorted_caps = toposort(dag, lambda x: name_to_value[x])
    # Attach the capability's value as the second component of the pair.
    return [(c, name_to_value[c]) for c in sorted_caps]


def get_capability_mapping(operand_kinds):
    """Returns the capability mapping from duplicated cases to canonicalized ones.

    Arguments:
      - operand_kinds: all operand kinds' grammar spec

    Returns:
      - A map mapping from duplicated capability symbols to the canonicalized
        symbol chosen for SPIRVBase.td.
    """
    # Find the operand kind for capability
    cap_kind = {}
    for kind in operand_kinds:
        if kind['kind'] == 'Capability':
            cap_kind = kind

    kind_cases = [
        (case['enumerant'], case['value']) for case in cap_kind['enumerants']
    ]
    _, capability_mapping = uniquify_enum_cases(kind_cases)

    return capability_mapping


def get_availability_spec(enum_case, capability_mapping, for_op, for_cap):
    """Returns the availability specification string for the given enum case.

    Arguments:
      - enum_case: the enum case to generate availability spec for. It may
        contain 'version', 'lastVersion', 'extensions', or 'capabilities'.
      - capability_mapping: mapping from duplicated capability symbols to the
        canonicalized symbol chosen for SPIRVBase.td.
      - for_op: bool value indicating whether this is the availability spec
        for an op itself.
      - for_cap: bool value indicating whether this is the availability spec
        for capabilities themselves.

    Returns:
      - A `let availability = [...];` string if with availability spec or
        empty string if without availability spec
    """
    assert not (for_op and for_cap), 'cannot set both for_op and for_cap'

    DEFAULT_MIN_VERSION = 'MinVersion<SPV_V_1_0>'
    DEFAULT_MAX_VERSION = 'MaxVersion<SPV_V_1_5>'
    DEFAULT_CAP = 'Capability<[]>'
    DEFAULT_EXT = 'Extension<[]>'

    min_version = enum_case.get('version', '')
    if min_version == 'None':
        min_version = ''
    elif min_version:
        min_version = 'MinVersion<SPV_V_{}>'.format(min_version.replace('.', '_'))
    # TODO: delete this once ODS can support dialect-specific content
    # and we can use omission to mean no requirements.
    if for_op and not min_version:
        min_version = DEFAULT_MIN_VERSION

    max_version = enum_case.get('lastVersion', '')
    if max_version:
        max_version = 'MaxVersion<SPV_V_{}>'.format(max_version.replace('.', '_'))
    # TODO: delete this once ODS can support dialect-specific content
    # and we can use omission to mean no requirements.
    if for_op and not max_version:
        max_version = DEFAULT_MAX_VERSION

    exts = enum_case.get('extensions', [])
    if exts:
        exts = 'Extension<[{}]>'.format(', '.join(sorted(set(exts))))
        # We need to strip the minimal version requirement if this symbol is
        # available via an extension, which means *any* SPIR-V version can
        # support it as long as the extension is provided. The grammar's
        # 'version' field under such case should be interpreted as this symbol
        # is introduced as a core symbol since the given version, rather than
        # a minimal version requirement.
        min_version = DEFAULT_MIN_VERSION if for_op else ''
    # TODO: delete this once ODS can support dialect-specific content
    # and we can use omission to mean no requirements.
    if for_op and not exts:
        exts = DEFAULT_EXT

    caps = enum_case.get('capabilities', [])
    implies = ''
    if caps:
        canonicalized_caps = [capability_mapping.get(c, c) for c in caps]
        prefixed_caps = [
            'SPV_C_{}'.format(c) for c in sorted(set(canonicalized_caps))
        ]
        if for_cap:
            # If this is generating the availability for capabilities, we need
            # to put the capability "requirements" in implies field because
            # now the "capabilities" field in the source grammar means so.
            caps = ''
            implies = 'list<I32EnumAttrCase> implies = [{}];'.format(
                ', '.join(prefixed_caps))
        else:
            caps = 'Capability<[{}]>'.format(', '.join(prefixed_caps))
            implies = ''
    # TODO: delete this once ODS can support dialect-specific content
    # and we can use omission to mean no requirements.
    if for_op and not caps:
        caps = DEFAULT_CAP

    avail = ''
    # Compose availability spec if any of the requirements is not empty.
    # For ops, because we have a default in SPV_Op class, omit if the spec
    # is the same.
    is_op_default = (
        for_op and min_version == DEFAULT_MIN_VERSION and
        max_version == DEFAULT_MAX_VERSION and caps == DEFAULT_CAP and
        exts == DEFAULT_EXT)
    if (min_version or max_version or caps or exts) and not is_op_default:
        joined_spec = ',\n '.join(
            [e for e in [min_version, max_version, exts, caps] if e])
        avail = '{} availability = [\n {}\n ];'.format(
            'let' if for_op else 'list<Availability>', joined_spec)

    return '{}{}{}'.format(implies, '\n ' if implies and avail else '', avail)


def gen_operand_kind_enum_attr(operand_kind, capability_mapping):
    """Generates the TableGen EnumAttr definition for the given operand kind.

    Arguments:
      - operand_kind: the operand kind's grammar spec
      - capability_mapping: mapping from duplicated capability symbols to the
        canonicalized symbol chosen for SPIRVBase.td.

    Returns:
      - The operand kind's name
      - A string containing the TableGen EnumAttr definition

      Both are empty strings if the operand kind has no enumerants.
    """
    if 'enumerants' not in operand_kind:
        return '', ''

    # Returns a symbol for the given case in the given kind. This function
    # handles Dim specially to avoid having numbers as the start of symbols,
    # which does not play well with C++ and the MLIR parser.
    def get_case_symbol(kind_name, case_name):
        if kind_name == 'Dim' and case_name in ('1D', '2D', '3D'):
            return 'Dim{}'.format(case_name)
        return case_name

    kind_name = operand_kind['kind']
    is_bit_enum = operand_kind['category'] == 'BitEnum'
    kind_category = 'Bit' if is_bit_enum else 'I32'
    # The acronym is made of the upper-case letters of the kind name.
    kind_acronym = ''.join([c for c in kind_name if 'A' <= c <= 'Z'])

    name_to_case_dict = {}
    for case in operand_kind['enumerants']:
        name_to_case_dict[case['enumerant']] = case

    if kind_name == 'Capability':
        # Special treatment for capability cases: we need to sort them
        # topologically because a capability can refer to another via the
        # 'implies' field.
        kind_cases = toposort_capabilities(operand_kind['enumerants'],
                                           capability_mapping)
    else:
        kind_cases = [(case['enumerant'], case['value'])
                      for case in operand_kind['enumerants']]
        kind_cases, _ = uniquify_enum_cases(kind_cases)
    max_len = max(len(symbol) for (symbol, _) in kind_cases)

    # Generate the definition for each enum case
    fmt_str = 'def SPV_{acronym}_{case} {colon:>{offset}} '\
              '{category}EnumAttrCase<"{symbol}", {value}>{avail}'
    case_defs = []
    for case in kind_cases:
        avail = get_availability_spec(name_to_case_dict[case[0]],
                                      capability_mapping,
                                      False, kind_name == 'Capability')
        case_def = fmt_str.format(
            category=kind_category,
            acronym=kind_acronym,
            case=case[0],
            symbol=get_case_symbol(kind_name, case[0]),
            value=case[1],
            avail=' {{\n {}\n}}'.format(avail) if avail else ';',
            colon=':',
            # Right-align the colons of all cases of this kind.
            offset=(max_len + 1 - len(case[0])))
        case_defs.append(case_def)
    case_defs = '\n'.join(case_defs)

    # Generate the list of enum case names
    fmt_str = 'SPV_{acronym}_{symbol}'
    case_names = [fmt_str.format(acronym=kind_acronym, symbol=case[0])
                  for case in kind_cases]

    # Split them into sublists and concatenate into multiple lines
    case_names = split_list_into_sublists(case_names)
    case_names = ['{:6}'.format('') + ', '.join(sublist)
                  for sublist in case_names]
    case_names = ',\n'.join(case_names)

    # Generate the enum attribute definition
    enum_attr = ('def SPV_{name}Attr :\n'
                 ' SPV_{category}EnumAttr<"{name}", "valid SPIR-V {name}", [\n'
                 '{cases}\n'
                 ' ]>;').format(
                     name=kind_name, category=kind_category, cases=case_names)
    return kind_name, case_defs + '\n\n' + enum_attr


def gen_opcode(instructions):
    """Generates the TableGen definition to map opname to opcode

    Arguments:
      - instructions: a list containing the grammar of the instructions to
        generate opcode cases for

    Returns:
      - A string containing the TableGen SPV_OpCode definition
    """

    max_len = max(len(inst['opname']) for inst in instructions)
    def_fmt_str = 'def SPV_OC_{name} {colon:>{offset}} '\
                  'I32EnumAttrCase<"{name}", {value}>;'
    opcode_defs = [
        def_fmt_str.format(
            name=inst['opname'],
            value=inst['opcode'],
            colon=':',
            # Right-align the colons of all opcode cases.
            offset=(max_len + 1 - len(inst['opname']))) for inst in instructions
    ]
    opcode_str = '\n'.join(opcode_defs)

    decl_fmt_str = 'SPV_OC_{name}'
    opcode_list = [
        decl_fmt_str.format(name=inst['opname']) for inst in instructions
    ]
    opcode_list = split_list_into_sublists(opcode_list)
    opcode_list = [
        '{:6}'.format('') + ', '.join(sublist) for sublist in opcode_list
    ]
    opcode_list = ',\n'.join(opcode_list)
    enum_attr = 'def SPV_OpcodeAttr :\n'\
                ' SPV_I32EnumAttr<"{name}", "valid SPIR-V instructions", [\n'\
                '{lst}\n'\
                ' ]>;'.format(name='Opcode', lst=opcode_list)
    return opcode_str + '\n\n' + enum_attr


def map_cap_to_opnames(instructions):
    """Maps capabilities to instructions enabled by those capabilities

    Arguments:
      - instructions: a list containing a subset of SPIR-V instructions'
        grammar

    Returns:
      - A map with keys representing capabilities and values of lists of
        instructions enabled by the corresponding key. Instructions without
        a 'capabilities' field are listed under the key '0_core_0'.
    """
    cap_to_inst = {}

    for inst in instructions:
        caps = inst['capabilities'] if 'capabilities' in inst else ['0_core_0']
        for cap in caps:
            cap_to_inst.setdefault(cap, []).append(inst['opname'])

    return cap_to_inst


def gen_instr_coverage_report(path, instructions):
    """Dumps to standard output a YAML report of current instruction coverage

    Only capabilities that still have at least one unsupported instruction
    appear in the report.

    Arguments:
      - path: the path to SPIRVBase.td
      - instructions: a list containing all SPIR-V instructions' grammar
    """
    import yaml

    with open(path, 'r') as f:
        content = f.read()

    content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)
    assert len(content) == 3

    # Strip the leading 'def SPV_OC_' (11 characters) to get the opname.
    existing_opcodes = {
        k[11:] for k in re.findall(r'def SPV_OC_\w+', content[1])}
    existing_instructions = [
        inst for inst in instructions if inst['opname'] in existing_opcodes]
    remaining_instructions = [
        inst for inst in instructions
        if inst['opname'] not in existing_opcodes]

    rem_cap_to_instr = map_cap_to_opnames(remaining_instructions)
    ex_cap_to_instr = map_cap_to_opnames(existing_instructions)

    report = {}

    # Calculate coverage for each capability and merge everything into one
    # report.
    for cap, unsupported in rem_cap_to_instr.items():
        supported = ex_cap_to_instr.get(cap, [])
        coverage = len(supported) / (len(supported) + len(unsupported))
        report[cap] = {
            'Supported Instructions': supported,
            'Unsupported Instructions': unsupported,
            'Coverage': '{}%'.format(int(coverage * 100)),
        }

    print(yaml.dump(report))


def update_td_opcodes(path, instructions, filter_list):
    """Updates SPIRVBase.td with new generated opcode cases.

    Arguments:
      - path: the path to SPIRVBase.td
      - instructions: a list containing all SPIR-V instructions' grammar
      - filter_list: a list containing new opnames to add
    """

    with open(path, 'r') as f:
        content = f.read()

    content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)
    assert len(content) == 3

    # Extend opcode list with existing list. Strip the leading 'def SPV_OC_'
    # (11 characters) to get the opname.
    existing_opcodes = [
        k[11:] for k in re.findall(r'def SPV_OC_\w+', content[1])]
    wanted_opnames = set(filter_list) | set(existing_opcodes)

    # Generate the opcode for all instructions in SPIR-V
    filter_instrs = [
        inst for inst in instructions if inst['opname'] in wanted_opnames]
    # Sort instruction based on opcode
    filter_instrs.sort(key=lambda inst: inst['opcode'])
    opcode = gen_opcode(filter_instrs)

    # Substitute the opcode
    content = content[0] + AUTOGEN_OPCODE_SECTION_MARKER + '\n\n' + \
        opcode + '\n\n// End ' + AUTOGEN_OPCODE_SECTION_MARKER \
        + content[2]

    with open(path, 'w') as f:
        f.write(content)


def update_td_enum_attrs(path, operand_kinds, filter_list):
    """Updates SPIRVBase.td with new generated enum definitions.

    Arguments:
      - path: the path to SPIRVBase.td
      - operand_kinds: a list containing all operand kinds' grammar
      - filter_list: a list containing new enums to add
    """
    with open(path, 'r') as f:
        content = f.read()

    content = content.split(AUTOGEN_ENUM_SECTION_MARKER)
    assert len(content) == 3

    # Extend filter list with existing enum definitions. Strip the leading
    # 'def SPV_' (8 characters) and the trailing 'Attr' to get the kind name.
    existing_kinds = [
        k[8:-4] for k in re.findall(r'def SPV_\w+Attr', content[1])]
    wanted_kinds = set(filter_list) | set(existing_kinds)

    capability_mapping = get_capability_mapping(operand_kinds)

    # Generate definitions for all enums in filter list
    defs = [
        gen_operand_kind_enum_attr(kind, capability_mapping)
        for kind in operand_kinds
        if kind['kind'] in wanted_kinds
    ]
    # Sort alphabetically according to enum name
    defs.sort(key=lambda enum: enum[0])
    # Only keep the definitions from now on
    # Put Capability's definition at the very beginning because capability
    # cases will be referenced later
    defs = [enum[1] for enum in defs if enum[0] == 'Capability'
           ] + [enum[1] for enum in defs if enum[0] != 'Capability']

    # Substitute the old section
    content = content[0] + AUTOGEN_ENUM_SECTION_MARKER + '\n\n' + \
        '\n\n'.join(defs) + "\n\n// End " + AUTOGEN_ENUM_SECTION_MARKER \
        + content[2]

    with open(path, 'w') as f:
        f.write(content)


def snake_casify(name):
    """Turns the given name to follow snake_case convention.

    Runs of non-word characters (quotes, spaces, punctuation) act as word
    separators; the remaining words are lower-cased and joined with '_'.
    """
    # Replace separators with a space rather than deleting them; deleting
    # them would glue the words together before they can be split.
    words = re.sub(r'\W+', ' ', name).split()
    return '_'.join(word.lower() for word in words)


def map_spec_operand_to_ods_argument(operand):
    """Maps an operand in SPIR-V JSON spec to an op argument in ODS.

    Arguments:
      - A dict containing the operand's kind, quantifier, and name

    Returns:
      - A string containing both the type and name for the argument

    Raises:
      - AssertionError for result operands, for operand kinds that are not
        implemented, and for unexpected quantifiers.
    """
    kind = operand['kind']
    quantifier = operand.get('quantifier', '')

    # These instruction "operands" are for encoding the results; they should
    # not be handled here.
    assert kind != 'IdResultType', 'unexpected to handle "IdResultType" kind'
    assert kind != 'IdResult', 'unexpected to handle "IdResult" kind'

    unimplemented_kinds = (
        'LiteralString',
        'LiteralContextDependentNumber',
        'LiteralExtInstInteger',
        'LiteralSpecConstantOpInteger',
        'PairLiteralIntegerIdRef',
        'PairIdRefLiteralInteger',
        'PairIdRefIdRef',
    )

    if kind == 'IdRef':
        if quantifier == '':
            arg_type = 'SPV_Type'
        elif quantifier == '?':
            arg_type = 'Optional<SPV_Type>'
        else:
            arg_type = 'Variadic<SPV_Type>'
    elif kind == 'IdMemorySemantics' or kind == 'IdScope':
        # TODO: Need to further constrain 'IdMemorySemantics'
        # and 'IdScope' given that they should be generated from OpConstant.
        assert quantifier == '', ('unexpected to have optional/variadic memory '
                                  'semantics or scope <id>')
        # Drop the leading 'Id' to get the enum attribute's name.
        arg_type = 'SPV_' + kind[2:] + 'Attr'
    elif kind == 'LiteralInteger':
        if quantifier == '':
            arg_type = 'I32Attr'
        elif quantifier == '?':
            arg_type = 'OptionalAttr<I32Attr>'
        else:
            arg_type = 'OptionalAttr<I32ArrayAttr>'
    elif kind in unimplemented_kinds:
        # Raised explicitly (rather than `assert False`) so that it still
        # fires when assertions are disabled.
        raise AssertionError('"{}" kind unimplemented'.format(kind))
    else:
        # The rest are all enum operands that we represent with op attributes.
        assert quantifier != '*', 'unexpected to have variadic enum attribute'
        arg_type = 'SPV_{}Attr'.format(kind)
        if quantifier == '?':
            arg_type = 'OptionalAttr<{}>'.format(arg_type)

    # Fall back to the kind when the spec gives no name, or a name that is
    # made of punctuation only.
    name = snake_casify(operand.get('name', '')) or kind.lower()

    return '{}:${}'.format(arg_type, name)


def get_description(text, appendix):
    """Generates the description for the given SPIR-V instruction.

    Arguments:
      - text: Textual description of the operation as string.
      - appendix: Additional contents to attach in description as string,
        including IR examples, and others.

    Returns:
      - A string that corresponds to the description of the Tablegen op.
    """
    fmt_str = '{text}\n\n <!-- End of AutoGen section -->\n{appendix}\n '
    return fmt_str.format(text=text, appendix=appendix)


def get_op_definition(instruction, doc, existing_info, capability_mapping):
    """Generates the TableGen op definition for the given SPIR-V instruction.

    Arguments:
      - instruction: the instruction's SPIR-V JSON grammar
      - doc: the instruction's SPIR-V HTML doc
      - existing_info: a dict containing potential manually specified sections
        for this instruction
      - capability_mapping: mapping from duplicated capability symbols to the
        canonicalized symbol chosen for SPIRVBase.td

    Returns:
      - A string containing the TableGen op definition
    """
    fmt_str = ('def SPV_{opname}Op : '
               'SPV_{inst_category}<"{opname}"{category_args}[{traits}]> '
               '{{\n let summary = {summary};\n\n let description = '
               '[{{\n{description}}}];{availability}\n')
    inst_category = existing_info.get('inst_category', 'Op')
    if inst_category == 'Op':
        fmt_str += '\n let arguments = (ins{args});\n\n'\
                   ' let results = (outs{results});\n'

    fmt_str += '{extras}'\
               '}}\n'

    # Drop the leading 'Op' from the SPIR-V opname.
    opname = instruction['opname'][2:]
    category_args = existing_info.get('category_args', '')

    # The first line of the doc is the summary; the rest is the description.
    if '\n' in doc:
        summary, text = doc.split('\n', 1)
    else:
        summary = doc
        text = ''
    wrapper = textwrap.TextWrapper(
        width=76, initial_indent=' ', subsequent_indent=' ')

    # Format summary. If the summary can fit in the same line, we print it out
    # as a "-quoted string; otherwise, wrap the lines using "[{...}]".
    summary = summary.strip()
    if len(summary) + len(' let summary = "";') <= 80:
        summary = '"{}"'.format(summary)
    else:
        summary = '[{{\n{}\n }}]'.format(wrapper.fill(summary))

    # Wrap text
    text = text.split('\n')
    text = [wrapper.fill(line) for line in text if line]
    text = '\n\n'.join(text)

    operands = instruction.get('operands', [])

    # Op availability
    avail = ''
    # We assume other instruction categories has a base availability spec, so
    # only add this if this is directly using SPV_Op as the base.
    if inst_category == 'Op':
        avail = get_availability_spec(instruction, capability_mapping, True,
                                      False)
    if avail:
        avail = '\n\n {0}'.format(avail)

    # Set op's result
    results = ''
    if len(operands) > 0 and operands[0]['kind'] == 'IdResultType':
        results = '\n SPV_Type:$result\n '
        operands = operands[1:]
    if 'results' in existing_info:
        results = existing_info['results']

    # Ignore the operand standing for the result <id>
    if len(operands) > 0 and operands[0]['kind'] == 'IdResult':
        operands = operands[1:]

    # Set op's argument
    arguments = existing_info.get('arguments', None)
    if arguments is None:
        arguments = [map_spec_operand_to_ods_argument(o) for o in operands]
        arguments = ',\n '.join(arguments)
        if arguments:
            # Prepend and append whitespace for formatting
            arguments = '\n {}\n '.format(arguments)

    description = existing_info.get('description', None)
    if description is None:
        # The first fenced block is a placeholder for the op's assembly form;
        # the block under "Example" is the one tagged as MLIR.
        assembly = '\n ```\n'\
                   ' [TODO]\n'\
                   ' ```\n\n'\
                   ' #### Example:\n\n'\
                   ' ```mlir\n'\
                   ' [TODO]\n' \
                   ' ```'
        description = get_description(text, assembly)

    return fmt_str.format(
        opname=opname,
        category_args=category_args,
        inst_category=inst_category,
        traits=existing_info.get('traits', ''),
        summary=summary,
        description=description,
        availability=avail,
        args=arguments,
        results=results,
        extras=existing_info.get('extras', ''))


def get_string_between(base, start, end):
    """Extracts a substring with a specified start and end from a string.

    Arguments:
      - base: string to extract from.
      - start: string to use as the start of the substring.
      - end: string to use as the end of the substring.

    Returns:
      - The substring if found
      - The part of the base after end of the substring. Is the base string
        itself if the substring wasn't found.
    """
    split = base.split(start, 1)
    if len(split) == 2:
        rest = split[1].split(end, 1)
        assert len(rest) == 2, \
            'cannot find end "{end}" while extracting substring '\
            'starting with {start}'.format(start=start, end=end)
        # NOTE(review): rstrip() treats `end` as a set of characters, so any
        # trailing run of those characters is removed from the substring as
        # well; kept as-is, confirm whether that is intended.
        return rest[0].rstrip(end), rest[1]
    return '', split[0]


def get_string_between_nested(base, start, end):
    """Extracts a substring with a nested start and end from a string.

    Arguments:
      - base: string to extract from.
      - start: string to use as the start of the substring.
      - end: string to use as the end of the substring.

    Returns:
      - The substring if found
      - The part of the base after end of the substring. Is the base string
        itself if the substring wasn't found.
    """
    split = base.split(start, 1)
    if len(split) == 2:
        # Handle nesting delimiters
        rest = split[1]
        unmatched_start = 1
        index = 0
        while unmatched_start > 0 and index < len(rest):
            # startswith() with an offset avoids copying the tail of the
            # string on every step.
            if rest.startswith(end, index):
                unmatched_start -= 1
                if unmatched_start == 0:
                    break
                index += len(end)
            elif rest.startswith(start, index):
                unmatched_start += 1
                index += len(start)
            else:
                index += 1

        assert index < len(rest), \
            'cannot find end "{end}" while extracting substring '\
            'starting with "{start}"'.format(start=start, end=end)
        return rest[:index], rest[index + len(end):]
    return '', split[0]


def extract_td_op_info(op_def):
    """Extracts potentially manually specified sections in op's definition.

    Arguments:
      - A string containing the op's TableGen definition

    Returns:
      - A dict containing potential manually specified sections
    """
    # Get opname. Strip the leading 'def SPV_' (8 characters) and the
    # trailing 'Op'.
    opname = [o[8:-2] for o in re.findall(r'def SPV_\w+Op', op_def)]
    assert len(opname) == 1, 'more than one ops in the same section!'
    opname = opname[0]

    # Get instruction category. Strip the leading 'SPV_'.
    inst_category = [
        o[4:] for o in re.findall(r'SPV_\w+Op',
                                  op_def.split(':', 1)[1])
    ]
    assert len(inst_category) <= 1, 'more than one ops in the same section!'
    inst_category = inst_category[0] if len(inst_category) == 1 else 'Op'

    # Get category_args
    op_tmpl_params, _ = get_string_between_nested(op_def, '<', '>')
    _, rest = get_string_between(op_tmpl_params, '"', '"')
    category_args = rest.split('[', 1)[0]

    # Get traits
    traits, _ = get_string_between_nested(rest, '[', ']')

    # Get description
    description, rest = get_string_between(op_def, 'let description = [{\n',
                                           '}];\n')

    # Get arguments
    args, rest = get_string_between(rest, ' let arguments = (ins', ');\n')

    # Get results
    results, rest = get_string_between(rest, ' let results = (outs', ');\n')

    # Whatever follows the results is kept verbatim as extra content.
    extras = rest.strip(' }\n')
    if extras:
        extras = '\n {}\n'.format(extras)

    return {
        # Prefix with 'Op' to make it consistent with SPIR-V spec
        'opname': 'Op{}'.format(opname),
        'inst_category': inst_category,
        'category_args': category_args,
        'traits': traits,
        'description': description,
        'arguments': args,
        'results': results,
        'extras': extras
    }


def update_td_op_definitions(path, instructions, docs, filter_list,
                             inst_category, capability_mapping):
    """Updates SPIRVOps.td with newly generated op definition.

    The file at `path` is rewritten in place; nothing is returned.

    Arguments:
      - path: path to SPIRVOps.td
      - instructions: SPIR-V JSON grammar for all instructions
      - docs: SPIR-V HTML doc for all instructions
      - filter_list: a list containing new opnames to include
      - inst_category: the instruction category to use for ops that do not
        exist in the file yet
      - capability_mapping: mapping from duplicated capability symbols to the
        canonicalized symbol chosen for SPIRVBase.td.
    """
    with open(path, 'r') as f:
        content = f.read()

    # Split the file into chunks, each containing one op.
    ops = content.split(AUTOGEN_OP_DEF_SEPARATOR)
    # Without at least one separator the header and the footer would be the
    # same chunk and the file content would be written back twice.
    assert len(ops) >= 2, 'cannot find op definition separator in the file'
    header = ops[0]
    footer = ops[-1]
    ops = ops[1:-1]

    # For each existing op, extract the manually-written sections out to
    # retain them when re-generating the ops. Also add the existing ops to
    # the set of ops to generate.
    name_op_map = {}  # Map from opname to its existing ODS definition
    op_info_dict = {}
    wanted_opnames = set(filter_list)
    for op in ops:
        info_dict = extract_td_op_info(op)
        opname = info_dict['opname']
        name_op_map[opname] = op
        op_info_dict[opname] = info_dict
        wanted_opnames.add(opname)

    op_defs = []
    for opname in sorted(wanted_opnames):
        # Find the grammar spec for this op
        instruction = next(
            (inst for inst in instructions if inst['opname'] == opname), None)
        if instruction is None:
            # This is an op added by us; use the existing ODS definition.
            op_defs.append(name_op_map[opname])
            continue
        op_defs.append(
            get_op_definition(
                instruction, docs[opname],
                op_info_dict.get(opname, {'inst_category': inst_category}),
                capability_mapping))

    # Substitute the old op definitions
    op_defs = [header] + op_defs + [footer]
    content = AUTOGEN_OP_DEF_SEPARATOR.join(op_defs)

    with open(path, 'w') as f:
        f.write(content)


def main():
    """Parses the command line and runs the requested update actions."""
    import argparse

    cli_parser = argparse.ArgumentParser(
        description='Update SPIR-V dialect definitions using SPIR-V spec')

    cli_parser.add_argument(
        '--base-td-path',
        dest='base_td_path',
        type=str,
        default=None,
        help='Path to SPIRVBase.td')
    cli_parser.add_argument(
        '--op-td-path',
        dest='op_td_path',
        type=str,
        default=None,
        help='Path to SPIRVOps.td')

    cli_parser.add_argument(
        '--new-enum',
        dest='new_enum',
        type=str,
        default=None,
        help='SPIR-V enum to be added to SPIRVBase.td')
    cli_parser.add_argument(
        '--new-opcodes',
        dest='new_opcodes',
        type=str,
        default=None,
        nargs='*',
        help='update SPIR-V opcodes in SPIRVBase.td')
    cli_parser.add_argument(
        '--new-inst',
        dest='new_inst',
        type=str,
        default=None,
        nargs='*',
        help='SPIR-V instruction to be added to ops file')
    cli_parser.add_argument(
        '--inst-category',
        dest='inst_category',
        type=str,
        default='Op',
        help='SPIR-V instruction category used for choosing '\
             'the TableGen base class to define this op')
    cli_parser.add_argument(
        '--gen-inst-coverage', dest='gen_inst_coverage', action='store_true')
    cli_parser.set_defaults(gen_inst_coverage=False)

    args = cli_parser.parse_args()

    # Validate the paths before touching the network so that a missing
    # argument is reported right away as a usage error.
    needs_base_td = (args.new_enum is not None or
                     args.new_opcodes is not None or args.gen_inst_coverage)
    if needs_base_td and args.base_td_path is None:
        cli_parser.error('--base-td-path is required for --new-enum, '
                         '--new-opcodes, and --gen-inst-coverage')
    if args.new_inst is not None and args.op_td_path is None:
        cli_parser.error('--op-td-path is required for --new-inst')

    operand_kinds, instructions = get_spirv_grammar_from_json_spec()

    # Define new enum attr
    if args.new_enum is not None:
        filter_list = [args.new_enum] if args.new_enum else []
        update_td_enum_attrs(args.base_td_path, operand_kinds, filter_list)

    # Define new opcode
    if args.new_opcodes is not None:
        update_td_opcodes(args.base_td_path, instructions, args.new_opcodes)

    # Define new op
    if args.new_inst is not None:
        docs = get_spirv_doc_from_html_spec()
        capability_mapping = get_capability_mapping(operand_kinds)
        update_td_op_definitions(args.op_td_path, instructions, docs,
                                 args.new_inst, args.inst_category,
                                 capability_mapping)
        print('Done. Note that this script just generates a template; ', end='')
        print('please read the spec and update traits, arguments, and ', end='')
        print('results accordingly.')

    if args.gen_inst_coverage:
        gen_instr_coverage_report(args.base_td_path, instructions)


if __name__ == '__main__':
    main()