#!/usr/bin/env python3 # -*- coding: utf-8 -*- # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Script for updating SPIR-V dialect by scraping information from SPIR-V # HTML and JSON specs from the Internet. # # For example, to define the enum attribute for SPIR-V memory model: # # ./gen_spirv_dialect.py --base_td_path /path/to/SPIRVBase.td \ # --new-enum MemoryModel # # The 'operand_kinds' dict of spirv.core.grammar.json contains all supported # SPIR-V enum classes. import itertools import re import requests import textwrap import yaml SPIRV_HTML_SPEC_URL = 'https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html' SPIRV_JSON_SPEC_URL = 'https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/spirv.core.grammar.json' AUTOGEN_OP_DEF_SEPARATOR = '\n// -----\n\n' AUTOGEN_ENUM_SECTION_MARKER = 'enum section. Generated from SPIR-V spec; DO NOT MODIFY!' AUTOGEN_OPCODE_SECTION_MARKER = ( 'opcode section. Generated from SPIR-V spec; DO NOT MODIFY!') def get_spirv_doc_from_html_spec(): """Extracts instruction documentation from SPIR-V HTML spec. Returns: - A dict mapping from instruction opcode to documentation. """ response = requests.get(SPIRV_HTML_SPEC_URL) spec = response.content from bs4 import BeautifulSoup spirv = BeautifulSoup(spec, 'html.parser') section_anchor = spirv.find('h3', {'id': '_a_id_instructions_a_instructions'}) doc = {} for section in section_anchor.parent.find_all('div', {'class': 'sect3'}): for table in section.find_all('table'): inst_html = table.tbody.tr.td.p opname = inst_html.a['id'] # Ignore the first line, which is just the opname. doc[opname] = inst_html.text.split('\n', 1)[1].strip() return doc def get_spirv_grammar_from_json_spec(): """Extracts operand kind and instruction grammar from SPIR-V JSON spec. Returns: - A list containing all operand kinds' grammar - A list containing all instructions' grammar """ response = requests.get(SPIRV_JSON_SPEC_URL) spec = response.content import json spirv = json.loads(spec) return spirv['operand_kinds'], spirv['instructions'] def split_list_into_sublists(items): """Split the list of items into multiple sublists. This is to make sure the string composed from each sublist won't exceed 80 characters. Arguments: - items: a list of strings """ chuncks = [] chunk = [] chunk_len = 0 for item in items: chunk_len += len(item) + 2 if chunk_len > 80: chuncks.append(chunk) chunk = [] chunk_len = len(item) + 2 chunk.append(item) if len(chunk) != 0: chuncks.append(chunk) return chuncks def uniquify_enum_cases(lst): """Prunes duplicate enum cases from the list. Arguments: - lst: List whose elements are to be uniqued. Assumes each element is a (symbol, value) pair and elements already sorted according to value. Returns: - A list with all duplicates removed. The elements are sorted according to value and, for each value, uniqued according to symbol. original list, - A map from deduplicated cases to the uniqued case. """ cases = lst uniqued_cases = [] duplicated_cases = {} # First sort according to the value cases.sort(key=lambda x: x[1]) # Then group them according to the value for _, groups in itertools.groupby(cases, key=lambda x: x[1]): # For each value, sort according to the enumerant symbol. sorted_group = sorted(groups, key=lambda x: x[0]) # Keep the "smallest" case, which is typically the symbol without extension # suffix. But we have special cases that we want to fix. case = sorted_group[0] for i in range(1, len(sorted_group)): duplicated_cases[sorted_group[i][0]] = case[0] if case[0] == 'HlslSemanticGOOGLE': assert len(sorted_group) == 2, 'unexpected new variant for HlslSemantic' case = sorted_group[1] duplicated_cases[sorted_group[0][0]] = case[0] uniqued_cases.append(case) return uniqued_cases, duplicated_cases def toposort(dag, sort_fn): """Topologically sorts the given dag. Arguments: - dag: a dict mapping from a node to its incoming nodes. - sort_fn: a function for sorting nodes in the same batch. Returns: A list containing topologically sorted nodes. """ # Returns the next batch of nodes without incoming edges def get_next_batch(dag): while True: no_prev_nodes = set(node for node, prev in dag.items() if not prev) if not no_prev_nodes: break yield sorted(no_prev_nodes, key=sort_fn) dag = { node: (prev - no_prev_nodes) for node, prev in dag.items() if node not in no_prev_nodes } assert not dag, 'found cyclic dependency' sorted_nodes = [] for batch in get_next_batch(dag): sorted_nodes.extend(batch) return sorted_nodes def toposort_capabilities(all_cases, capability_mapping): """Returns topologically sorted capability (symbol, value) pairs. Arguments: - all_cases: all capability cases (containing symbol, value, and implied capabilities). - capability_mapping: mapping from duplicated capability symbols to the canonicalized symbol chosen for SPIRVBase.td. Returns: A list containing topologically sorted capability (symbol, value) pairs. """ dag = {} name_to_value = {} for case in all_cases: # Get the current capability. cur = case['enumerant'] name_to_value[cur] = case['value'] # Ignore duplicated symbols. if cur in capability_mapping: continue # Get capabilities implied by the current capability. prev = case.get('capabilities', []) uniqued_prev = set([capability_mapping.get(c, c) for c in prev]) dag[cur] = uniqued_prev sorted_caps = toposort(dag, lambda x: name_to_value[x]) # Attach the capability's value as the second component of the pair. return [(c, name_to_value[c]) for c in sorted_caps] def get_capability_mapping(operand_kinds): """Returns the capability mapping from duplicated cases to canonicalized ones. Arguments: - operand_kinds: all operand kinds' grammar spec Returns: - A map mapping from duplicated capability symbols to the canonicalized symbol chosen for SPIRVBase.td. """ # Find the operand kind for capability cap_kind = {} for kind in operand_kinds: if kind['kind'] == 'Capability': cap_kind = kind kind_cases = [ (case['enumerant'], case['value']) for case in cap_kind['enumerants'] ] _, capability_mapping = uniquify_enum_cases(kind_cases) return capability_mapping def get_availability_spec(enum_case, capability_mapping, for_op, for_cap): """Returns the availability specification string for the given enum case. Arguments: - enum_case: the enum case to generate availability spec for. It may contain 'version', 'lastVersion', 'extensions', or 'capabilities'. - capability_mapping: mapping from duplicated capability symbols to the canonicalized symbol chosen for SPIRVBase.td. - for_op: bool value indicating whether this is the availability spec for an op itself. - for_cap: bool value indicating whether this is the availability spec for capabilities themselves. Returns: - A `let availability = [...];` string if with availability spec or empty string if without availability spec """ assert not (for_op and for_cap), 'cannot set both for_op and for_cap' DEFAULT_MIN_VERSION = 'MinVersion' DEFAULT_MAX_VERSION = 'MaxVersion' DEFAULT_CAP = 'Capability<[]>' DEFAULT_EXT = 'Extension<[]>' min_version = enum_case.get('version', '') if min_version == 'None': min_version = '' elif min_version: min_version = 'MinVersion'.format(min_version.replace('.', '_')) # TODO: delete this once ODS can support dialect-specific content # and we can use omission to mean no requirements. if for_op and not min_version: min_version = DEFAULT_MIN_VERSION max_version = enum_case.get('lastVersion', '') if max_version: max_version = 'MaxVersion'.format(max_version.replace('.', '_')) # TODO: delete this once ODS can support dialect-specific content # and we can use omission to mean no requirements. if for_op and not max_version: max_version = DEFAULT_MAX_VERSION exts = enum_case.get('extensions', []) if exts: exts = 'Extension<[{}]>'.format(', '.join(sorted(set(exts)))) # We need to strip the minimal version requirement if this symbol is # available via an extension, which means *any* SPIR-V version can support # it as long as the extension is provided. The grammar's 'version' field # under such case should be interpreted as this symbol is introduced as # a core symbol since the given version, rather than a minimal version # requirement. min_version = DEFAULT_MIN_VERSION if for_op else '' # TODO: delete this once ODS can support dialect-specific content # and we can use omission to mean no requirements. if for_op and not exts: exts = DEFAULT_EXT caps = enum_case.get('capabilities', []) implies = '' if caps: canonicalized_caps = [] for c in caps: if c in capability_mapping: canonicalized_caps.append(capability_mapping[c]) else: canonicalized_caps.append(c) prefixed_caps = [ 'SPV_C_{}'.format(c) for c in sorted(set(canonicalized_caps)) ] if for_cap: # If this is generating the availability for capabilities, we need to # put the capability "requirements" in implies field because now # the "capabilities" field in the source grammar means so. caps = '' implies = 'list implies = [{}];'.format( ', '.join(prefixed_caps)) else: caps = 'Capability<[{}]>'.format(', '.join(prefixed_caps)) implies = '' # TODO: delete this once ODS can support dialect-specific content # and we can use omission to mean no requirements. if for_op and not caps: caps = DEFAULT_CAP avail = '' # Compose availability spec if any of the requirements is not empty. # For ops, because we have a default in SPV_Op class, omit if the spec # is the same. if (min_version or max_version or caps or exts) and not ( for_op and min_version == DEFAULT_MIN_VERSION and max_version == DEFAULT_MAX_VERSION and caps == DEFAULT_CAP and exts == DEFAULT_EXT): joined_spec = ',\n '.join( [e for e in [min_version, max_version, exts, caps] if e]) avail = '{} availability = [\n {}\n ];'.format( 'let' if for_op else 'list', joined_spec) return '{}{}{}'.format(implies, '\n ' if implies and avail else '', avail) def gen_operand_kind_enum_attr(operand_kind, capability_mapping): """Generates the TableGen EnumAttr definition for the given operand kind. Returns: - The operand kind's name - A string containing the TableGen EnumAttr definition """ if 'enumerants' not in operand_kind: return '', '' # Returns a symbol for the given case in the given kind. This function # handles Dim specially to avoid having numbers as the start of symbols, # which does not play well with C++ and the MLIR parser. def get_case_symbol(kind_name, case_name): if kind_name == 'Dim': if case_name == '1D' or case_name == '2D' or case_name == '3D': return 'Dim{}'.format(case_name) return case_name kind_name = operand_kind['kind'] is_bit_enum = operand_kind['category'] == 'BitEnum' kind_category = 'Bit' if is_bit_enum else 'I32' kind_acronym = ''.join([c for c in kind_name if c >= 'A' and c <= 'Z']) name_to_case_dict = {} for case in operand_kind['enumerants']: name_to_case_dict[case['enumerant']] = case if kind_name == 'Capability': # Special treatment for capability cases: we need to sort them topologically # because a capability can refer to another via the 'implies' field. kind_cases = toposort_capabilities(operand_kind['enumerants'], capability_mapping) else: kind_cases = [(case['enumerant'], case['value']) for case in operand_kind['enumerants']] kind_cases, _ = uniquify_enum_cases(kind_cases) max_len = max([len(symbol) for (symbol, _) in kind_cases]) # Generate the definition for each enum case fmt_str = 'def SPV_{acronym}_{case} {colon:>{offset}} '\ '{category}EnumAttrCase<"{symbol}", {value}>{avail}' case_defs = [] for case in kind_cases: avail = get_availability_spec(name_to_case_dict[case[0]], capability_mapping, False, kind_name == 'Capability') case_def = fmt_str.format( category=kind_category, acronym=kind_acronym, case=case[0], symbol=get_case_symbol(kind_name, case[0]), value=case[1], avail=' {{\n {}\n}}'.format(avail) if avail else ';', colon=':', offset=(max_len + 1 - len(case[0]))) case_defs.append(case_def) case_defs = '\n'.join(case_defs) # Generate the list of enum case names fmt_str = 'SPV_{acronym}_{symbol}'; case_names = [fmt_str.format(acronym=kind_acronym,symbol=case[0]) for case in kind_cases] # Split them into sublists and concatenate into multiple lines case_names = split_list_into_sublists(case_names) case_names = ['{:6}'.format('') + ', '.join(sublist) for sublist in case_names] case_names = ',\n'.join(case_names) # Generate the enum attribute definition enum_attr = '''def SPV_{name}Attr : SPV_{category}EnumAttr<"{name}", "valid SPIR-V {name}", [ {cases} ]>;'''.format( name=kind_name, category=kind_category, cases=case_names) return kind_name, case_defs + '\n\n' + enum_attr def gen_opcode(instructions): """ Generates the TableGen definition to map opname to opcode Returns: - A string containing the TableGen SPV_OpCode definition """ max_len = max([len(inst['opname']) for inst in instructions]) def_fmt_str = 'def SPV_OC_{name} {colon:>{offset}} '\ 'I32EnumAttrCase<"{name}", {value}>;' opcode_defs = [ def_fmt_str.format( name=inst['opname'], value=inst['opcode'], colon=':', offset=(max_len + 1 - len(inst['opname']))) for inst in instructions ] opcode_str = '\n'.join(opcode_defs) decl_fmt_str = 'SPV_OC_{name}' opcode_list = [ decl_fmt_str.format(name=inst['opname']) for inst in instructions ] opcode_list = split_list_into_sublists(opcode_list) opcode_list = [ '{:6}'.format('') + ', '.join(sublist) for sublist in opcode_list ] opcode_list = ',\n'.join(opcode_list) enum_attr = 'def SPV_OpcodeAttr :\n'\ ' SPV_I32EnumAttr<"{name}", "valid SPIR-V instructions", [\n'\ '{lst}\n'\ ' ]>;'.format(name='Opcode', lst=opcode_list) return opcode_str + '\n\n' + enum_attr def map_cap_to_opnames(instructions): """Maps capabilities to instructions enabled by those capabilities Arguments: - instructions: a list containing a subset of SPIR-V instructions' grammar Returns: - A map with keys representing capabilities and values of lists of instructions enabled by the corresponding key """ cap_to_inst = {} for inst in instructions: caps = inst['capabilities'] if 'capabilities' in inst else ['0_core_0'] for cap in caps: if cap not in cap_to_inst: cap_to_inst[cap] = [] cap_to_inst[cap].append(inst['opname']) return cap_to_inst def gen_instr_coverage_report(path, instructions): """Dumps to standard output a YAML report of current instruction coverage Arguments: - path: the path to SPIRBase.td - instructions: a list containing all SPIR-V instructions' grammar """ with open(path, 'r') as f: content = f.read() content = content.split(AUTOGEN_OPCODE_SECTION_MARKER) existing_opcodes = [k[11:] for k in re.findall('def SPV_OC_\w+', content[1])] existing_instructions = list( filter(lambda inst: (inst['opname'] in existing_opcodes), instructions)) instructions_opnames = [inst['opname'] for inst in instructions] remaining_opcodes = list(set(instructions_opnames) - set(existing_opcodes)) remaining_instructions = list( filter(lambda inst: (inst['opname'] in remaining_opcodes), instructions)) rem_cap_to_instr = map_cap_to_opnames(remaining_instructions) ex_cap_to_instr = map_cap_to_opnames(existing_instructions) rem_cap_to_cov = {} # Calculate coverage for each capability for cap in rem_cap_to_instr: if cap not in ex_cap_to_instr: rem_cap_to_cov[cap] = 0.0 else: rem_cap_to_cov[cap] = \ (len(ex_cap_to_instr[cap]) / (len(ex_cap_to_instr[cap]) \ + len(rem_cap_to_instr[cap]))) report = {} # Merge the 3 maps into one report for cap in rem_cap_to_instr: report[cap] = {} report[cap]['Supported Instructions'] = \ ex_cap_to_instr[cap] if cap in ex_cap_to_instr else [] report[cap]['Unsupported Instructions'] = rem_cap_to_instr[cap] report[cap]['Coverage'] = '{}%'.format(int(rem_cap_to_cov[cap] * 100)) print(yaml.dump(report)) def update_td_opcodes(path, instructions, filter_list): """Updates SPIRBase.td with new generated opcode cases. Arguments: - path: the path to SPIRBase.td - instructions: a list containing all SPIR-V instructions' grammar - filter_list: a list containing new opnames to add """ with open(path, 'r') as f: content = f.read() content = content.split(AUTOGEN_OPCODE_SECTION_MARKER) assert len(content) == 3 # Extend opcode list with existing list existing_opcodes = [k[11:] for k in re.findall('def SPV_OC_\w+', content[1])] filter_list.extend(existing_opcodes) filter_list = list(set(filter_list)) # Generate the opcode for all instructions in SPIR-V filter_instrs = list( filter(lambda inst: (inst['opname'] in filter_list), instructions)) # Sort instruction based on opcode filter_instrs.sort(key=lambda inst: inst['opcode']) opcode = gen_opcode(filter_instrs) # Substitute the opcode content = content[0] + AUTOGEN_OPCODE_SECTION_MARKER + '\n\n' + \ opcode + '\n\n// End ' + AUTOGEN_OPCODE_SECTION_MARKER \ + content[2] with open(path, 'w') as f: f.write(content) def update_td_enum_attrs(path, operand_kinds, filter_list): """Updates SPIRBase.td with new generated enum definitions. Arguments: - path: the path to SPIRBase.td - operand_kinds: a list containing all operand kinds' grammar - filter_list: a list containing new enums to add """ with open(path, 'r') as f: content = f.read() content = content.split(AUTOGEN_ENUM_SECTION_MARKER) assert len(content) == 3 # Extend filter list with existing enum definitions existing_kinds = [ k[8:-4] for k in re.findall('def SPV_\w+Attr', content[1])] filter_list.extend(existing_kinds) capability_mapping = get_capability_mapping(operand_kinds) # Generate definitions for all enums in filter list defs = [ gen_operand_kind_enum_attr(kind, capability_mapping) for kind in operand_kinds if kind['kind'] in filter_list ] # Sort alphabetically according to enum name defs.sort(key=lambda enum : enum[0]) # Only keep the definitions from now on # Put Capability's definition at the very beginning because capability cases # will be referenced later defs = [enum[1] for enum in defs if enum[0] == 'Capability' ] + [enum[1] for enum in defs if enum[0] != 'Capability'] # Substitute the old section content = content[0] + AUTOGEN_ENUM_SECTION_MARKER + '\n\n' + \ '\n\n'.join(defs) + "\n\n// End " + AUTOGEN_ENUM_SECTION_MARKER \ + content[2]; with open(path, 'w') as f: f.write(content) def snake_casify(name): """Turns the given name to follow snake_case convention.""" name = re.sub('\W+', '', name).split() name = [s.lower() for s in name] return '_'.join(name) def map_spec_operand_to_ods_argument(operand): """Maps an operand in SPIR-V JSON spec to an op argument in ODS. Arguments: - A dict containing the operand's kind, quantifier, and name Returns: - A string containing both the type and name for the argument """ kind = operand['kind'] quantifier = operand.get('quantifier', '') # These instruction "operands" are for encoding the results; they should # not be handled here. assert kind != 'IdResultType', 'unexpected to handle "IdResultType" kind' assert kind != 'IdResult', 'unexpected to handle "IdResult" kind' if kind == 'IdRef': if quantifier == '': arg_type = 'SPV_Type' elif quantifier == '?': arg_type = 'Optional' else: arg_type = 'Variadic' elif kind == 'IdMemorySemantics' or kind == 'IdScope': # TODO: Need to further constrain 'IdMemorySemantics' # and 'IdScope' given that they should be generated from OpConstant. assert quantifier == '', ('unexpected to have optional/variadic memory ' 'semantics or scope ') arg_type = 'SPV_' + kind[2:] + 'Attr' elif kind == 'LiteralInteger': if quantifier == '': arg_type = 'I32Attr' elif quantifier == '?': arg_type = 'OptionalAttr' else: arg_type = 'OptionalAttr' elif kind == 'LiteralString' or \ kind == 'LiteralContextDependentNumber' or \ kind == 'LiteralExtInstInteger' or \ kind == 'LiteralSpecConstantOpInteger' or \ kind == 'PairLiteralIntegerIdRef' or \ kind == 'PairIdRefLiteralInteger' or \ kind == 'PairIdRefIdRef': assert False, '"{}" kind unimplemented'.format(kind) else: # The rest are all enum operands that we represent with op attributes. assert quantifier != '*', 'unexpected to have variadic enum attribute' arg_type = 'SPV_{}Attr'.format(kind) if quantifier == '?': arg_type = 'OptionalAttr<{}>'.format(arg_type) name = operand.get('name', '') name = snake_casify(name) if name else kind.lower() return '{}:${}'.format(arg_type, name) def get_description(text, appendix): """Generates the description for the given SPIR-V instruction. Arguments: - text: Textual description of the operation as string. - appendix: Additional contents to attach in description as string, includking IR examples, and others. Returns: - A string that corresponds to the description of the Tablegen op. """ fmt_str = '{text}\n\n \n{appendix}\n ' return fmt_str.format(text=text, appendix=appendix) def get_op_definition(instruction, doc, existing_info, capability_mapping): """Generates the TableGen op definition for the given SPIR-V instruction. Arguments: - instruction: the instruction's SPIR-V JSON grammar - doc: the instruction's SPIR-V HTML doc - existing_info: a dict containing potential manually specified sections for this instruction - capability_mapping: mapping from duplicated capability symbols to the canonicalized symbol chosen for SPIRVBase.td Returns: - A string containing the TableGen op definition """ fmt_str = ('def SPV_{opname}Op : ' 'SPV_{inst_category}<"{opname}"{category_args}[{traits}]> ' '{{\n let summary = {summary};\n\n let description = ' '[{{\n{description}}}];{availability}\n') inst_category = existing_info.get('inst_category', 'Op') if inst_category == 'Op': fmt_str +='\n let arguments = (ins{args});\n\n'\ ' let results = (outs{results});\n' fmt_str +='{extras}'\ '}}\n' opname = instruction['opname'][2:] category_args = existing_info.get('category_args', '') if '\n' in doc: summary, text = doc.split('\n', 1) else: summary = doc text = '' wrapper = textwrap.TextWrapper( width=76, initial_indent=' ', subsequent_indent=' ') # Format summary. If the summary can fit in the same line, we print it out # as a "-quoted string; otherwise, wrap the lines using "[{...}]". summary = summary.strip(); if len(summary) + len(' let summary = "";') <= 80: summary = '"{}"'.format(summary) else: summary = '[{{\n{}\n }}]'.format(wrapper.fill(summary)) # Wrap text text = text.split('\n') text = [wrapper.fill(line) for line in text if line] text = '\n\n'.join(text) operands = instruction.get('operands', []) # Op availability avail = '' # We assume other instruction categories has a base availability spec, so # only add this if this is directly using SPV_Op as the base. if inst_category == 'Op': avail = get_availability_spec(instruction, capability_mapping, True, False) if avail: avail = '\n\n {0}'.format(avail) # Set op's result results = '' if len(operands) > 0 and operands[0]['kind'] == 'IdResultType': results = '\n SPV_Type:$result\n ' operands = operands[1:] if 'results' in existing_info: results = existing_info['results'] # Ignore the operand standing for the result if len(operands) > 0 and operands[0]['kind'] == 'IdResult': operands = operands[1:] # Set op' argument arguments = existing_info.get('arguments', None) if arguments is None: arguments = [map_spec_operand_to_ods_argument(o) for o in operands] arguments = ',\n '.join(arguments) if arguments: # Prepend and append whitespace for formatting arguments = '\n {}\n '.format(arguments) description = existing_info.get('description', None) if description is None: assembly = '\n ```\n'\ ' [TODO]\n'\ ' ```mlir\n\n'\ ' #### Example:\n\n'\ ' ```\n'\ ' [TODO]\n' \ ' ```' description = get_description(text, assembly) return fmt_str.format( opname=opname, category_args=category_args, inst_category=inst_category, traits=existing_info.get('traits', ''), summary=summary, description=description, availability=avail, args=arguments, results=results, extras=existing_info.get('extras', '')) def get_string_between(base, start, end): """Extracts a substring with a specified start and end from a string. Arguments: - base: string to extract from. - start: string to use as the start of the substring. - end: string to use as the end of the substring. Returns: - The substring if found - The part of the base after end of the substring. Is the base string itself if the substring wasnt found. """ split = base.split(start, 1) if len(split) == 2: rest = split[1].split(end, 1) assert len(rest) == 2, \ 'cannot find end "{end}" while extracting substring '\ 'starting with {start}'.format(start=start, end=end) return rest[0].rstrip(end), rest[1] return '', split[0] def get_string_between_nested(base, start, end): """Extracts a substring with a nested start and end from a string. Arguments: - base: string to extract from. - start: string to use as the start of the substring. - end: string to use as the end of the substring. Returns: - The substring if found - The part of the base after end of the substring. Is the base string itself if the substring wasn't found. """ split = base.split(start, 1) if len(split) == 2: # Handle nesting delimiters rest = split[1] unmatched_start = 1 index = 0 while unmatched_start > 0 and index < len(rest): if rest[index:].startswith(end): unmatched_start -= 1 if unmatched_start == 0: break index += len(end) elif rest[index:].startswith(start): unmatched_start += 1 index += len(start) else: index += 1 assert index < len(rest), \ 'cannot find end "{end}" while extracting substring '\ 'starting with "{start}"'.format(start=start, end=end) return rest[:index], rest[index + len(end):] return '', split[0] def extract_td_op_info(op_def): """Extracts potentially manually specified sections in op's definition. Arguments: - A string containing the op's TableGen definition Returns: - A dict containing potential manually specified sections """ # Get opname opname = [o[8:-2] for o in re.findall('def SPV_\w+Op', op_def)] assert len(opname) == 1, 'more than one ops in the same section!' opname = opname[0] # Get instruction category inst_category = [ o[4:] for o in re.findall('SPV_\w+Op', op_def.split(':', 1)[1]) ] assert len(inst_category) <= 1, 'more than one ops in the same section!' inst_category = inst_category[0] if len(inst_category) == 1 else 'Op' # Get category_args op_tmpl_params, _ = get_string_between_nested(op_def, '<', '>') opstringname, rest = get_string_between(op_tmpl_params, '"', '"') category_args = rest.split('[', 1)[0] # Get traits traits, _ = get_string_between_nested(rest, '[', ']') # Get description description, rest = get_string_between(op_def, 'let description = [{\n', '}];\n') # Get arguments args, rest = get_string_between(rest, ' let arguments = (ins', ');\n') # Get results results, rest = get_string_between(rest, ' let results = (outs', ');\n') extras = rest.strip(' }\n') if extras: extras = '\n {}\n'.format(extras) return { # Prefix with 'Op' to make it consistent with SPIR-V spec 'opname': 'Op{}'.format(opname), 'inst_category': inst_category, 'category_args': category_args, 'traits': traits, 'description': description, 'arguments': args, 'results': results, 'extras': extras } def update_td_op_definitions(path, instructions, docs, filter_list, inst_category, capability_mapping): """Updates SPIRVOps.td with newly generated op definition. Arguments: - path: path to SPIRVOps.td - instructions: SPIR-V JSON grammar for all instructions - docs: SPIR-V HTML doc for all instructions - filter_list: a list containing new opnames to include - capability_mapping: mapping from duplicated capability symbols to the canonicalized symbol chosen for SPIRVBase.td. Returns: - A string containing all the TableGen op definitions """ with open(path, 'r') as f: content = f.read() # Split the file into chunks, each containing one op. ops = content.split(AUTOGEN_OP_DEF_SEPARATOR) header = ops[0] footer = ops[-1] ops = ops[1:-1] # For each existing op, extract the manually-written sections out to retain # them when re-generating the ops. Also append the existing ops to filter # list. name_op_map = {} # Map from opname to its existing ODS definition op_info_dict = {} for op in ops: info_dict = extract_td_op_info(op) opname = info_dict['opname'] name_op_map[opname] = op op_info_dict[opname] = info_dict filter_list.append(opname) filter_list = sorted(list(set(filter_list))) op_defs = [] for opname in filter_list: # Find the grammar spec for this op try: instruction = next( inst for inst in instructions if inst['opname'] == opname) op_defs.append( get_op_definition( instruction, docs[opname], op_info_dict.get(opname, {'inst_category': inst_category}), capability_mapping)) except StopIteration: # This is an op added by us; use the existing ODS definition. op_defs.append(name_op_map[opname]) # Substitute the old op definitions op_defs = [header] + op_defs + [footer] content = AUTOGEN_OP_DEF_SEPARATOR.join(op_defs) with open(path, 'w') as f: f.write(content) if __name__ == '__main__': import argparse cli_parser = argparse.ArgumentParser( description='Update SPIR-V dialect definitions using SPIR-V spec') cli_parser.add_argument( '--base-td-path', dest='base_td_path', type=str, default=None, help='Path to SPIRVBase.td') cli_parser.add_argument( '--op-td-path', dest='op_td_path', type=str, default=None, help='Path to SPIRVOps.td') cli_parser.add_argument( '--new-enum', dest='new_enum', type=str, default=None, help='SPIR-V enum to be added to SPIRVBase.td') cli_parser.add_argument( '--new-opcodes', dest='new_opcodes', type=str, default=None, nargs='*', help='update SPIR-V opcodes in SPIRVBase.td') cli_parser.add_argument( '--new-inst', dest='new_inst', type=str, default=None, nargs='*', help='SPIR-V instruction to be added to ops file') cli_parser.add_argument( '--inst-category', dest='inst_category', type=str, default='Op', help='SPIR-V instruction category used for choosing '\ 'the TableGen base class to define this op') cli_parser.add_argument('--gen-inst-coverage', dest='gen_inst_coverage', action='store_true') cli_parser.set_defaults(gen_inst_coverage=False) args = cli_parser.parse_args() operand_kinds, instructions = get_spirv_grammar_from_json_spec() # Define new enum attr if args.new_enum is not None: assert args.base_td_path is not None filter_list = [args.new_enum] if args.new_enum else [] update_td_enum_attrs(args.base_td_path, operand_kinds, filter_list) # Define new opcode if args.new_opcodes is not None: assert args.base_td_path is not None update_td_opcodes(args.base_td_path, instructions, args.new_opcodes) # Define new op if args.new_inst is not None: assert args.op_td_path is not None docs = get_spirv_doc_from_html_spec() capability_mapping = get_capability_mapping(operand_kinds) update_td_op_definitions(args.op_td_path, instructions, docs, args.new_inst, args.inst_category, capability_mapping) print('Done. Note that this script just generates a template; ', end='') print('please read the spec and update traits, arguments, and ', end='') print('results accordingly.') if args.gen_inst_coverage: gen_instr_coverage_report(args.base_td_path, instructions)