#!/usr/bin/env python3 # -*- coding: utf-8 -*- # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Script for updating SPIR-V dialect by scraping information from SPIR-V # HTML and JSON specs from the Internet. # # For example, to define the enum attribute for SPIR-V memory model: # # ./gen_spirv_dialect.py --base-td-path /path/to/SPIRVBase.td \ # --new-enum MemoryModel # # The 'operand_kinds' dict of spirv.core.grammar.json contains all supported # SPIR-V enum classes. import itertools import math import re import requests import textwrap import yaml SPIRV_HTML_SPEC_URL = ( "https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html" ) SPIRV_JSON_SPEC_URL = "https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/spirv.core.grammar.json" SPIRV_CL_EXT_HTML_SPEC_URL = "https://www.khronos.org/registry/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html" SPIRV_CL_EXT_JSON_SPEC_URL = "https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/extinst.opencl.std.100.grammar.json" AUTOGEN_OP_DEF_SEPARATOR = "\n// -----\n\n" AUTOGEN_ENUM_SECTION_MARKER = "enum section. Generated from SPIR-V spec; DO NOT MODIFY!" AUTOGEN_OPCODE_SECTION_MARKER = ( "opcode section. Generated from SPIR-V spec; DO NOT MODIFY!" ) def get_spirv_doc_from_html_spec(url, settings): """Extracts instruction documentation from SPIR-V HTML spec. Returns: - A dict mapping from instruction opcode to documentation. """ if url is None: url = SPIRV_HTML_SPEC_URL response = requests.get(url) spec = response.content from bs4 import BeautifulSoup spirv = BeautifulSoup(spec, "html.parser") doc = {} if settings.gen_cl_ops: section_anchor = spirv.find("h2", {"id": "_binary_form"}) for section in section_anchor.parent.find_all("div", {"class": "sect2"}): for table in section.find_all("table"): inst_html = table.tbody.tr.td opname = inst_html.a["id"] # Ignore the first line, which is just the opname. doc[opname] = inst_html.text.split("\n", 1)[1].strip() else: section_anchor = spirv.find("h3", {"id": "_instructions_3"}) for section in section_anchor.parent.find_all("div", {"class": "sect3"}): for table in section.find_all("table"): inst_html = table.tbody.tr.td.p opname = inst_html.a["id"] # Ignore the first line, which is just the opname. doc[opname] = inst_html.text.split("\n", 1)[1].strip() return doc def get_spirv_grammar_from_json_spec(url): """Extracts operand kind and instruction grammar from SPIR-V JSON spec. Returns: - A list containing all operand kinds' grammar - A list containing all instructions' grammar """ response = requests.get(SPIRV_JSON_SPEC_URL) spec = response.content import json spirv = json.loads(spec) if url is None: return spirv["operand_kinds"], spirv["instructions"] response_ext = requests.get(url) spec_ext = response_ext.content spirv_ext = json.loads(spec_ext) return spirv["operand_kinds"], spirv_ext["instructions"] def split_list_into_sublists(items): """Split the list of items into multiple sublists. This is to make sure the string composed from each sublist won't exceed 80 characters. Arguments: - items: a list of strings """ chuncks = [] chunk = [] chunk_len = 0 for item in items: chunk_len += len(item) + 2 if chunk_len > 80: chuncks.append(chunk) chunk = [] chunk_len = len(item) + 2 chunk.append(item) if len(chunk) != 0: chuncks.append(chunk) return chuncks def uniquify_enum_cases(lst): """Prunes duplicate enum cases from the list. Arguments: - lst: List whose elements are to be uniqued. Assumes each element is a (symbol, value) pair and elements already sorted according to value. Returns: - A list with all duplicates removed. The elements are sorted according to value and, for each value, uniqued according to symbol. original list, - A map from deduplicated cases to the uniqued case. """ cases = lst uniqued_cases = [] duplicated_cases = {} # First sort according to the value cases.sort(key=lambda x: x[1]) # Then group them according to the value for _, groups in itertools.groupby(cases, key=lambda x: x[1]): # For each value, sort according to the enumerant symbol. sorted_group = sorted(groups, key=lambda x: x[0]) # Keep the "smallest" case, which is typically the symbol without extension # suffix. But we have special cases that we want to fix. case = sorted_group[0] for i in range(1, len(sorted_group)): duplicated_cases[sorted_group[i][0]] = case[0] if case[0] == "HlslSemanticGOOGLE": assert len(sorted_group) == 2, "unexpected new variant for HlslSemantic" case = sorted_group[1] duplicated_cases[sorted_group[0][0]] = case[0] uniqued_cases.append(case) return uniqued_cases, duplicated_cases def toposort(dag, sort_fn): """Topologically sorts the given dag. Arguments: - dag: a dict mapping from a node to its incoming nodes. - sort_fn: a function for sorting nodes in the same batch. Returns: A list containing topologically sorted nodes. """ # Returns the next batch of nodes without incoming edges def get_next_batch(dag): while True: no_prev_nodes = set(node for node, prev in dag.items() if not prev) if not no_prev_nodes: break yield sorted(no_prev_nodes, key=sort_fn) dag = { node: (prev - no_prev_nodes) for node, prev in dag.items() if node not in no_prev_nodes } assert not dag, "found cyclic dependency" sorted_nodes = [] for batch in get_next_batch(dag): sorted_nodes.extend(batch) return sorted_nodes def toposort_capabilities(all_cases, capability_mapping): """Returns topologically sorted capability (symbol, value) pairs. Arguments: - all_cases: all capability cases (containing symbol, value, and implied capabilities). - capability_mapping: mapping from duplicated capability symbols to the canonicalized symbol chosen for SPIRVBase.td. Returns: A list containing topologically sorted capability (symbol, value) pairs. """ dag = {} name_to_value = {} for case in all_cases: # Get the current capability. cur = case["enumerant"] name_to_value[cur] = case["value"] # Ignore duplicated symbols. if cur in capability_mapping: continue # Get capabilities implied by the current capability. prev = case.get("capabilities", []) uniqued_prev = set([capability_mapping.get(c, c) for c in prev]) dag[cur] = uniqued_prev sorted_caps = toposort(dag, lambda x: name_to_value[x]) # Attach the capability's value as the second component of the pair. return [(c, name_to_value[c]) for c in sorted_caps] def get_capability_mapping(operand_kinds): """Returns the capability mapping from duplicated cases to canonicalized ones. Arguments: - operand_kinds: all operand kinds' grammar spec Returns: - A map mapping from duplicated capability symbols to the canonicalized symbol chosen for SPIRVBase.td. """ # Find the operand kind for capability cap_kind = {} for kind in operand_kinds: if kind["kind"] == "Capability": cap_kind = kind kind_cases = [(case["enumerant"], case["value"]) for case in cap_kind["enumerants"]] _, capability_mapping = uniquify_enum_cases(kind_cases) return capability_mapping def get_availability_spec(enum_case, capability_mapping, for_op, for_cap): """Returns the availability specification string for the given enum case. Arguments: - enum_case: the enum case to generate availability spec for. It may contain 'version', 'lastVersion', 'extensions', or 'capabilities'. - capability_mapping: mapping from duplicated capability symbols to the canonicalized symbol chosen for SPIRVBase.td. - for_op: bool value indicating whether this is the availability spec for an op itself. - for_cap: bool value indicating whether this is the availability spec for capabilities themselves. Returns: - A `let availability = [...];` string if with availability spec or empty string if without availability spec """ assert not (for_op and for_cap), "cannot set both for_op and for_cap" DEFAULT_MIN_VERSION = "MinVersion" DEFAULT_MAX_VERSION = "MaxVersion" DEFAULT_CAP = "Capability<[]>" DEFAULT_EXT = "Extension<[]>" min_version = enum_case.get("version", "") if min_version == "None": min_version = "" elif min_version: min_version = "MinVersion".format(min_version.replace(".", "_")) # TODO: delete this once ODS can support dialect-specific content # and we can use omission to mean no requirements. if for_op and not min_version: min_version = DEFAULT_MIN_VERSION max_version = enum_case.get("lastVersion", "") if max_version: max_version = "MaxVersion".format(max_version.replace(".", "_")) # TODO: delete this once ODS can support dialect-specific content # and we can use omission to mean no requirements. if for_op and not max_version: max_version = DEFAULT_MAX_VERSION exts = enum_case.get("extensions", []) if exts: exts = "Extension<[{}]>".format(", ".join(sorted(set(exts)))) # We need to strip the minimal version requirement if this symbol is # available via an extension, which means *any* SPIR-V version can support # it as long as the extension is provided. The grammar's 'version' field # under such case should be interpreted as this symbol is introduced as # a core symbol since the given version, rather than a minimal version # requirement. min_version = DEFAULT_MIN_VERSION if for_op else "" # TODO: delete this once ODS can support dialect-specific content # and we can use omission to mean no requirements. if for_op and not exts: exts = DEFAULT_EXT caps = enum_case.get("capabilities", []) implies = "" if caps: canonicalized_caps = [] for c in caps: if c in capability_mapping: canonicalized_caps.append(capability_mapping[c]) else: canonicalized_caps.append(c) prefixed_caps = [ "SPIRV_C_{}".format(c) for c in sorted(set(canonicalized_caps)) ] if for_cap: # If this is generating the availability for capabilities, we need to # put the capability "requirements" in implies field because now # the "capabilities" field in the source grammar means so. caps = "" implies = "list implies = [{}];".format( ", ".join(prefixed_caps) ) else: caps = "Capability<[{}]>".format(", ".join(prefixed_caps)) implies = "" # TODO: delete this once ODS can support dialect-specific content # and we can use omission to mean no requirements. if for_op and not caps: caps = DEFAULT_CAP avail = "" # Compose availability spec if any of the requirements is not empty. # For ops, because we have a default in SPIRV_Op class, omit if the spec # is the same. if (min_version or max_version or caps or exts) and not ( for_op and min_version == DEFAULT_MIN_VERSION and max_version == DEFAULT_MAX_VERSION and caps == DEFAULT_CAP and exts == DEFAULT_EXT ): joined_spec = ",\n ".join( [e for e in [min_version, max_version, exts, caps] if e] ) avail = "{} availability = [\n {}\n ];".format( "let" if for_op else "list", joined_spec ) return "{}{}{}".format(implies, "\n " if implies and avail else "", avail) def gen_operand_kind_enum_attr(operand_kind, capability_mapping): """Generates the TableGen EnumAttr definition for the given operand kind. Returns: - The operand kind's name - A string containing the TableGen EnumAttr definition """ if "enumerants" not in operand_kind: return "", "" # Returns a symbol for the given case in the given kind. This function # handles Dim specially to avoid having numbers as the start of symbols, # which does not play well with C++ and the MLIR parser. def get_case_symbol(kind_name, case_name): if kind_name == "Dim": if case_name == "1D" or case_name == "2D" or case_name == "3D": return "Dim{}".format(case_name) return case_name kind_name = operand_kind["kind"] is_bit_enum = operand_kind["category"] == "BitEnum" kind_acronym = "".join([c for c in kind_name if c >= "A" and c <= "Z"]) name_to_case_dict = {} for case in operand_kind["enumerants"]: name_to_case_dict[case["enumerant"]] = case if kind_name == "Capability": # Special treatment for capability cases: we need to sort them topologically # because a capability can refer to another via the 'implies' field. kind_cases = toposort_capabilities( operand_kind["enumerants"], capability_mapping ) else: kind_cases = [ (case["enumerant"], case["value"]) for case in operand_kind["enumerants"] ] kind_cases, _ = uniquify_enum_cases(kind_cases) max_len = max([len(symbol) for (symbol, _) in kind_cases]) # Generate the definition for each enum case case_category = "I32Bit" if is_bit_enum else "I32" fmt_str = ( "def SPIRV_{acronym}_{case_name} {colon:>{offset}} " '{category}EnumAttrCase{suffix}<"{symbol}"{case_value_part}>{avail}' ) case_defs = [] for case_pair in kind_cases: name = case_pair[0] if is_bit_enum: value = int(case_pair[1], base=16) else: value = int(case_pair[1]) avail = get_availability_spec( name_to_case_dict[name], capability_mapping, False, kind_name == "Capability", ) if is_bit_enum: if value == 0: suffix = "None" value = "" else: suffix = "Bit" value = ", {}".format(int(math.log2(value))) else: suffix = "" value = ", {}".format(value) case_def = fmt_str.format( category=case_category, suffix=suffix, acronym=kind_acronym, case_name=name, symbol=get_case_symbol(kind_name, name), case_value_part=value, avail=" {{\n {}\n}}".format(avail) if avail else ";", colon=":", offset=(max_len + 1 - len(name)), ) case_defs.append(case_def) case_defs = "\n".join(case_defs) # Generate the list of enum case names fmt_str = "SPIRV_{acronym}_{symbol}" case_names = [ fmt_str.format(acronym=kind_acronym, symbol=case[0]) for case in kind_cases ] # Split them into sublists and concatenate into multiple lines case_names = split_list_into_sublists(case_names) case_names = ["{:6}".format("") + ", ".join(sublist) for sublist in case_names] case_names = ",\n".join(case_names) # Generate the enum attribute definition kind_category = "Bit" if is_bit_enum else "I32" enum_attr = """def SPIRV_{name}Attr : SPIRV_{category}EnumAttr<"{name}", "valid SPIR-V {name}", "{snake_name}", [ {cases} ]>;""".format( name=kind_name, snake_name=snake_casify(kind_name), category=kind_category, cases=case_names, ) return kind_name, case_defs + "\n\n" + enum_attr def gen_opcode(instructions): """Generates the TableGen definition to map opname to opcode Returns: - A string containing the TableGen SPIRV_OpCode definition """ max_len = max([len(inst["opname"]) for inst in instructions]) def_fmt_str = ( "def SPIRV_OC_{name} {colon:>{offset}} " 'I32EnumAttrCase<"{name}", {value}>;' ) opcode_defs = [ def_fmt_str.format( name=inst["opname"], value=inst["opcode"], colon=":", offset=(max_len + 1 - len(inst["opname"])), ) for inst in instructions ] opcode_str = "\n".join(opcode_defs) decl_fmt_str = "SPIRV_OC_{name}" opcode_list = [decl_fmt_str.format(name=inst["opname"]) for inst in instructions] opcode_list = split_list_into_sublists(opcode_list) opcode_list = ["{:6}".format("") + ", ".join(sublist) for sublist in opcode_list] opcode_list = ",\n".join(opcode_list) enum_attr = ( "def SPIRV_OpcodeAttr :\n" ' SPIRV_I32EnumAttr<"{name}", "valid SPIR-V instructions", ' '"opcode", [\n' "{lst}\n" " ]>;".format(name="Opcode", lst=opcode_list) ) return opcode_str + "\n\n" + enum_attr def map_cap_to_opnames(instructions): """Maps capabilities to instructions enabled by those capabilities Arguments: - instructions: a list containing a subset of SPIR-V instructions' grammar Returns: - A map with keys representing capabilities and values of lists of instructions enabled by the corresponding key """ cap_to_inst = {} for inst in instructions: caps = inst["capabilities"] if "capabilities" in inst else ["0_core_0"] for cap in caps: if cap not in cap_to_inst: cap_to_inst[cap] = [] cap_to_inst[cap].append(inst["opname"]) return cap_to_inst def gen_instr_coverage_report(path, instructions): """Dumps to standard output a YAML report of current instruction coverage Arguments: - path: the path to SPIRBase.td - instructions: a list containing all SPIR-V instructions' grammar """ with open(path, "r") as f: content = f.read() content = content.split(AUTOGEN_OPCODE_SECTION_MARKER) existing_opcodes = [k[11:] for k in re.findall("def SPIRV_OC_\w+", content[1])] existing_instructions = list( filter(lambda inst: (inst["opname"] in existing_opcodes), instructions) ) instructions_opnames = [inst["opname"] for inst in instructions] remaining_opcodes = list(set(instructions_opnames) - set(existing_opcodes)) remaining_instructions = list( filter(lambda inst: (inst["opname"] in remaining_opcodes), instructions) ) rem_cap_to_instr = map_cap_to_opnames(remaining_instructions) ex_cap_to_instr = map_cap_to_opnames(existing_instructions) rem_cap_to_cov = {} # Calculate coverage for each capability for cap in rem_cap_to_instr: if cap not in ex_cap_to_instr: rem_cap_to_cov[cap] = 0.0 else: rem_cap_to_cov[cap] = len(ex_cap_to_instr[cap]) / ( len(ex_cap_to_instr[cap]) + len(rem_cap_to_instr[cap]) ) report = {} # Merge the 3 maps into one report for cap in rem_cap_to_instr: report[cap] = {} report[cap]["Supported Instructions"] = ( ex_cap_to_instr[cap] if cap in ex_cap_to_instr else [] ) report[cap]["Unsupported Instructions"] = rem_cap_to_instr[cap] report[cap]["Coverage"] = "{}%".format(int(rem_cap_to_cov[cap] * 100)) print(yaml.dump(report)) def update_td_opcodes(path, instructions, filter_list): """Updates SPIRBase.td with new generated opcode cases. Arguments: - path: the path to SPIRBase.td - instructions: a list containing all SPIR-V instructions' grammar - filter_list: a list containing new opnames to add """ with open(path, "r") as f: content = f.read() content = content.split(AUTOGEN_OPCODE_SECTION_MARKER) assert len(content) == 3 # Extend opcode list with existing list prefix = "def SPIRV_OC_" existing_opcodes = [ k[len(prefix) :] for k in re.findall(prefix + "\w+", content[1]) ] filter_list.extend(existing_opcodes) filter_list = list(set(filter_list)) # Generate the opcode for all instructions in SPIR-V filter_instrs = list( filter(lambda inst: (inst["opname"] in filter_list), instructions) ) # Sort instruction based on opcode filter_instrs.sort(key=lambda inst: inst["opcode"]) opcode = gen_opcode(filter_instrs) # Substitute the opcode content = ( content[0] + AUTOGEN_OPCODE_SECTION_MARKER + "\n\n" + opcode + "\n\n// End " + AUTOGEN_OPCODE_SECTION_MARKER + content[2] ) with open(path, "w") as f: f.write(content) def update_td_enum_attrs(path, operand_kinds, filter_list): """Updates SPIRBase.td with new generated enum definitions. Arguments: - path: the path to SPIRBase.td - operand_kinds: a list containing all operand kinds' grammar - filter_list: a list containing new enums to add """ with open(path, "r") as f: content = f.read() content = content.split(AUTOGEN_ENUM_SECTION_MARKER) assert len(content) == 3 # Extend filter list with existing enum definitions existing_kinds = [k[8:-4] for k in re.findall("def SPIRV_\w+Attr", content[1])] filter_list.extend(existing_kinds) capability_mapping = get_capability_mapping(operand_kinds) # Generate definitions for all enums in filter list defs = [ gen_operand_kind_enum_attr(kind, capability_mapping) for kind in operand_kinds if kind["kind"] in filter_list ] # Sort alphabetically according to enum name defs.sort(key=lambda enum: enum[0]) # Only keep the definitions from now on # Put Capability's definition at the very beginning because capability cases # will be referenced later defs = [enum[1] for enum in defs if enum[0] == "Capability"] + [ enum[1] for enum in defs if enum[0] != "Capability" ] # Substitute the old section content = ( content[0] + AUTOGEN_ENUM_SECTION_MARKER + "\n\n" + "\n\n".join(defs) + "\n\n// End " + AUTOGEN_ENUM_SECTION_MARKER + content[2] ) with open(path, "w") as f: f.write(content) def snake_casify(name): """Turns the given name to follow snake_case convention.""" return re.sub(r"(?" else: arg_type = "Variadic" elif kind == "IdMemorySemantics" or kind == "IdScope": # TODO: Need to further constrain 'IdMemorySemantics' # and 'IdScope' given that they should be generated from OpConstant. assert quantifier == "", ( "unexpected to have optional/variadic memory " "semantics or scope " ) arg_type = "SPIRV_" + kind[2:] + "Attr" elif kind == "LiteralInteger": if quantifier == "": arg_type = "I32Attr" elif quantifier == "?": arg_type = "OptionalAttr" else: arg_type = "OptionalAttr" elif ( kind == "LiteralString" or kind == "LiteralContextDependentNumber" or kind == "LiteralExtInstInteger" or kind == "LiteralSpecConstantOpInteger" or kind == "PairLiteralIntegerIdRef" or kind == "PairIdRefLiteralInteger" or kind == "PairIdRefIdRef" ): assert False, '"{}" kind unimplemented'.format(kind) else: # The rest are all enum operands that we represent with op attributes. assert quantifier != "*", "unexpected to have variadic enum attribute" arg_type = "SPIRV_{}Attr".format(kind) if quantifier == "?": arg_type = "OptionalAttr<{}>".format(arg_type) name = operand.get("name", "") name = snake_casify(name) if name else kind.lower() return "{}:${}".format(arg_type, name) def get_description(text, appendix): """Generates the description for the given SPIR-V instruction. Arguments: - text: Textual description of the operation as string. - appendix: Additional contents to attach in description as string, includking IR examples, and others. Returns: - A string that corresponds to the description of the Tablegen op. """ fmt_str = "{text}\n\n \n{appendix}\n " return fmt_str.format(text=text, appendix=appendix) def get_op_definition( instruction, opname, doc, existing_info, capability_mapping, settings ): """Generates the TableGen op definition for the given SPIR-V instruction. Arguments: - instruction: the instruction's SPIR-V JSON grammar - doc: the instruction's SPIR-V HTML doc - existing_info: a dict containing potential manually specified sections for this instruction - capability_mapping: mapping from duplicated capability symbols to the canonicalized symbol chosen for SPIRVBase.td Returns: - A string containing the TableGen op definition """ if settings.gen_cl_ops: fmt_str = ( "def SPIRV_{opname}Op : " 'SPIRV_{inst_category}<"{opname_src}", {opcode}, <> > ' "{{\n let summary = {summary};\n\n let description = " "[{{\n{description}}}];{availability}\n" ) else: fmt_str = ( "def SPIRV_{vendor_name}{opname_src}Op : " 'SPIRV_{inst_category}<"{opname_src}"{category_args}, [{traits}]> ' "{{\n let summary = {summary};\n\n let description = " "[{{\n{description}}}];{availability}\n" ) vendor_name = "" inst_category = existing_info.get("inst_category", "Op") if inst_category == "Op": fmt_str += ( "\n let arguments = (ins{args});\n\n" " let results = (outs{results});\n" ) elif inst_category.endswith("VendorOp"): vendor_name = inst_category.split("VendorOp")[0].upper() assert len(vendor_name) != 0, "Invalid instruction category" fmt_str += "{extras}" "}}\n" opname_src = instruction["opname"] if opname.startswith("Op"): opname_src = opname_src[2:] if len(vendor_name) > 0: assert opname_src.endswith( vendor_name ), "op name does not match the instruction category" opname_src = opname_src[: -len(vendor_name)] category_args = existing_info.get("category_args", "") if "\n" in doc: summary, text = doc.split("\n", 1) else: summary = doc text = "" wrapper = textwrap.TextWrapper( width=76, initial_indent=" ", subsequent_indent=" " ) # Format summary. If the summary can fit in the same line, we print it out # as a "-quoted string; otherwise, wrap the lines using "[{...}]". summary = summary.strip() if len(summary) + len(' let summary = "";') <= 80: summary = '"{}"'.format(summary) else: summary = "[{{\n{}\n }}]".format(wrapper.fill(summary)) # Wrap text text = text.split("\n") text = [wrapper.fill(line) for line in text if line] text = "\n\n".join(text) operands = instruction.get("operands", []) # Op availability avail = get_availability_spec(instruction, capability_mapping, True, False) if avail: avail = "\n\n {0}".format(avail) # Set op's result results = "" if len(operands) > 0 and operands[0]["kind"] == "IdResultType": results = "\n SPIRV_Type:$result\n " operands = operands[1:] if "results" in existing_info: results = existing_info["results"] # Ignore the operand standing for the result if len(operands) > 0 and operands[0]["kind"] == "IdResult": operands = operands[1:] # Set op' argument arguments = existing_info.get("arguments", None) if arguments is None: arguments = [map_spec_operand_to_ods_argument(o) for o in operands] arguments = ",\n ".join(arguments) if arguments: # Prepend and append whitespace for formatting arguments = "\n {}\n ".format(arguments) description = existing_info.get("description", None) if description is None: assembly = ( "\n ```\n" " [TODO]\n" " ```\n\n" " #### Example:\n\n" " ```mlir\n" " [TODO]\n" " ```" ) description = get_description(text, assembly) return fmt_str.format( opname=opname, opname_src=opname_src, opcode=instruction["opcode"], category_args=category_args, inst_category=inst_category, vendor_name=vendor_name, traits=existing_info.get("traits", ""), summary=summary, description=description, availability=avail, args=arguments, results=results, extras=existing_info.get("extras", ""), ) def get_string_between(base, start, end): """Extracts a substring with a specified start and end from a string. Arguments: - base: string to extract from. - start: string to use as the start of the substring. - end: string to use as the end of the substring. Returns: - The substring if found - The part of the base after end of the substring. Is the base string itself if the substring wasnt found. """ split = base.split(start, 1) if len(split) == 2: rest = split[1].split(end, 1) assert len(rest) == 2, ( 'cannot find end "{end}" while extracting substring ' "starting with {start}".format(start=start, end=end) ) return rest[0].rstrip(end), rest[1] return "", split[0] def get_string_between_nested(base, start, end): """Extracts a substring with a nested start and end from a string. Arguments: - base: string to extract from. - start: string to use as the start of the substring. - end: string to use as the end of the substring. Returns: - The substring if found - The part of the base after end of the substring. Is the base string itself if the substring wasn't found. """ split = base.split(start, 1) if len(split) == 2: # Handle nesting delimiters rest = split[1] unmatched_start = 1 index = 0 while unmatched_start > 0 and index < len(rest): if rest[index:].startswith(end): unmatched_start -= 1 if unmatched_start == 0: break index += len(end) elif rest[index:].startswith(start): unmatched_start += 1 index += len(start) else: index += 1 assert index < len(rest), ( 'cannot find end "{end}" while extracting substring ' 'starting with "{start}"'.format(start=start, end=end) ) return rest[:index], rest[index + len(end) :] return "", split[0] def extract_td_op_info(op_def): """Extracts potentially manually specified sections in op's definition. Arguments: - A string containing the op's TableGen definition Returns: - A dict containing potential manually specified sections """ # Get opname opname = [o[8:-2] for o in re.findall("def SPIRV_\w+Op", op_def)] assert len(opname) == 1, "more than one ops in the same section!" opname = opname[0] # Get instruction category inst_category = [o[4:] for o in re.findall("SPIRV_\w+Op", op_def.split(":", 1)[1])] assert len(inst_category) <= 1, "more than one ops in the same section!" inst_category = inst_category[0] if len(inst_category) == 1 else "Op" # Get category_args op_tmpl_params, _ = get_string_between_nested(op_def, "<", ">") opstringname, rest = get_string_between(op_tmpl_params, '"', '"') category_args = rest.split("[", 1)[0] # Get traits traits, _ = get_string_between_nested(rest, "[", "]") # Get description description, rest = get_string_between(op_def, "let description = [{\n", "}];\n") # Get arguments args, rest = get_string_between(rest, " let arguments = (ins", ");\n") # Get results results, rest = get_string_between(rest, " let results = (outs", ");\n") extras = rest.strip(" }\n") if extras: extras = "\n {}\n".format(extras) return { # Prefix with 'Op' to make it consistent with SPIR-V spec "opname": "Op{}".format(opname), "inst_category": inst_category, "category_args": category_args, "traits": traits, "description": description, "arguments": args, "results": results, "extras": extras, } def update_td_op_definitions( path, instructions, docs, filter_list, inst_category, capability_mapping, settings ): """Updates SPIRVOps.td with newly generated op definition. Arguments: - path: path to SPIRVOps.td - instructions: SPIR-V JSON grammar for all instructions - docs: SPIR-V HTML doc for all instructions - filter_list: a list containing new opnames to include - capability_mapping: mapping from duplicated capability symbols to the canonicalized symbol chosen for SPIRVBase.td. Returns: - A string containing all the TableGen op definitions """ with open(path, "r") as f: content = f.read() # Split the file into chunks, each containing one op. ops = content.split(AUTOGEN_OP_DEF_SEPARATOR) header = ops[0] footer = ops[-1] ops = ops[1:-1] # For each existing op, extract the manually-written sections out to retain # them when re-generating the ops. Also append the existing ops to filter # list. name_op_map = {} # Map from opname to its existing ODS definition op_info_dict = {} for op in ops: info_dict = extract_td_op_info(op) opname = info_dict["opname"] name_op_map[opname] = op op_info_dict[opname] = info_dict filter_list.append(opname) filter_list = sorted(list(set(filter_list))) op_defs = [] if settings.gen_cl_ops: fix_opname = lambda src: src.replace("CL", "").lower() else: fix_opname = lambda src: src for opname in filter_list: # Find the grammar spec for this op try: fixed_opname = fix_opname(opname) instruction = next( inst for inst in instructions if inst["opname"] == fixed_opname ) op_defs.append( get_op_definition( instruction, opname, docs[fixed_opname], op_info_dict.get(opname, {"inst_category": inst_category}), capability_mapping, settings, ) ) except StopIteration: # This is an op added by us; use the existing ODS definition. op_defs.append(name_op_map[opname]) # Substitute the old op definitions op_defs = [header] + op_defs + [footer] content = AUTOGEN_OP_DEF_SEPARATOR.join(op_defs) with open(path, "w") as f: f.write(content) if __name__ == "__main__": import argparse cli_parser = argparse.ArgumentParser( description="Update SPIR-V dialect definitions using SPIR-V spec" ) cli_parser.add_argument( "--base-td-path", dest="base_td_path", type=str, default=None, help="Path to SPIRVBase.td", ) cli_parser.add_argument( "--op-td-path", dest="op_td_path", type=str, default=None, help="Path to SPIRVOps.td", ) cli_parser.add_argument( "--new-enum", dest="new_enum", type=str, default=None, help="SPIR-V enum to be added to SPIRVBase.td", ) cli_parser.add_argument( "--new-opcodes", dest="new_opcodes", type=str, default=None, nargs="*", help="update SPIR-V opcodes in SPIRVBase.td", ) cli_parser.add_argument( "--new-inst", dest="new_inst", type=str, default=None, nargs="*", help="SPIR-V instruction to be added to ops file", ) cli_parser.add_argument( "--inst-category", dest="inst_category", type=str, default="Op", help="SPIR-V instruction category used for choosing " "the TableGen base class to define this op", ) cli_parser.add_argument( "--gen-cl-ops", dest="gen_cl_ops", help="Generate OpenCL Extended Instruction Set op", action="store_true", ) cli_parser.set_defaults(gen_cl_ops=False) cli_parser.add_argument( "--gen-inst-coverage", dest="gen_inst_coverage", action="store_true" ) cli_parser.set_defaults(gen_inst_coverage=False) args = cli_parser.parse_args() if args.gen_cl_ops: ext_html_url = SPIRV_CL_EXT_HTML_SPEC_URL ext_json_url = SPIRV_CL_EXT_JSON_SPEC_URL else: ext_html_url = None ext_json_url = None operand_kinds, instructions = get_spirv_grammar_from_json_spec(ext_json_url) # Define new enum attr if args.new_enum is not None: assert args.base_td_path is not None filter_list = [args.new_enum] if args.new_enum else [] update_td_enum_attrs(args.base_td_path, operand_kinds, filter_list) # Define new opcode if args.new_opcodes is not None: assert args.base_td_path is not None update_td_opcodes(args.base_td_path, instructions, args.new_opcodes) # Define new op if args.new_inst is not None: assert args.op_td_path is not None docs = get_spirv_doc_from_html_spec(ext_html_url, args) capability_mapping = get_capability_mapping(operand_kinds) update_td_op_definitions( args.op_td_path, instructions, docs, args.new_inst, args.inst_category, capability_mapping, args, ) print("Done. Note that this script just generates a template; ", end="") print("please read the spec and update traits, arguments, and ", end="") print("results accordingly.") if args.gen_inst_coverage: gen_instr_coverage_report(args.base_td_path, instructions)