//===- EnumPythonBindingGen.cpp - Generator of Python API for ODS enums ---===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // EnumPythonBindingGen uses ODS specification of MLIR enum attributes to // generate the corresponding Python binding classes. // //===----------------------------------------------------------------------===// #include "OpGenHelpers.h" #include "mlir/TableGen/AttrOrTypeDef.h" #include "mlir/TableGen/Attribute.h" #include "mlir/TableGen/Dialect.h" #include "mlir/TableGen/GenInfo.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/TableGen/Record.h" using namespace mlir; using namespace mlir::tblgen; /// File header and includes. constexpr const char *fileHeader = R"Py( # Autogenerated by mlir-tblgen; don't manually edit. from enum import IntEnum, auto, IntFlag from ._ods_common import _cext as _ods_cext from ..ir import register_attribute_builder _ods_ir = _ods_cext.ir )Py"; /// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE. static std::string makePythonEnumCaseName(StringRef name) { if (isPythonReserved(name.str())) return (name + "_").str(); return name.str(); } /// Emits the Python class for the given enum. static void emitEnumClass(EnumAttr enumAttr, raw_ostream &os) { os << llvm::formatv("class {0}({1}):\n", enumAttr.getEnumClassName(), enumAttr.isBitEnum() ? "IntFlag" : "IntEnum"); if (!enumAttr.getSummary().empty()) os << llvm::formatv(" \"\"\"{0}\"\"\"\n", enumAttr.getSummary()); os << "\n"; for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) { os << llvm::formatv( " {0} = {1}\n", makePythonEnumCaseName(enumCase.getSymbol()), enumCase.getValue() >= 0 ? std::to_string(enumCase.getValue()) : "auto()"); } os << "\n"; if (enumAttr.isBitEnum()) { os << llvm::formatv(" def __iter__(self):\n" " return iter([case for case in type(self) if " "(self & case) is case])\n"); os << llvm::formatv(" def __len__(self):\n" " return bin(self).count(\"1\")\n"); os << "\n"; } os << llvm::formatv(" def __str__(self):\n"); if (enumAttr.isBitEnum()) os << llvm::formatv(" if len(self) > 1:\n" " return \"{0}\".join(map(str, self))\n", enumAttr.getDef().getValueAsString("separator")); for (const EnumAttrCase &enumCase : enumAttr.getAllCases()) { os << llvm::formatv(" if self is {0}.{1}:\n", enumAttr.getEnumClassName(), makePythonEnumCaseName(enumCase.getSymbol())); os << llvm::formatv(" return \"{0}\"\n", enumCase.getStr()); } os << llvm::formatv( " raise ValueError(\"Unknown {0} enum entry.\")\n\n\n", enumAttr.getEnumClassName()); os << "\n"; } /// Attempts to extract the bitwidth B from string "uintB_t" describing the /// type. This bitwidth information is not readily available in ODS. Returns /// `false` on success, `true` on failure. static bool extractUIntBitwidth(StringRef uintType, int64_t &bitwidth) { if (!uintType.consume_front("uint")) return true; if (!uintType.consume_back("_t")) return true; return uintType.getAsInteger(/*Radix=*/10, bitwidth); } /// Emits an attribute builder for the given enum attribute to support automatic /// conversion between enum values and attributes in Python. Returns /// `false` on success, `true` on failure. static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) { int64_t bitwidth; if (extractUIntBitwidth(enumAttr.getUnderlyingType(), bitwidth)) { llvm::errs() << "failed to identify bitwidth of " << enumAttr.getUnderlyingType(); return true; } os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", enumAttr.getAttrDefName()); os << llvm::formatv("def _{0}(x, context):\n", enumAttr.getAttrDefName().lower()); os << llvm::formatv( " return " "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, " "context=context), int(x))\n\n", bitwidth); return false; } /// Emits an attribute builder for the given dialect enum attribute to support /// automatic conversion between enum values and attributes in Python. Returns /// `false` on success, `true` on failure. static bool emitDialectEnumAttributeBuilder(StringRef attrDefName, StringRef formatString, raw_ostream &os) { os << llvm::formatv("@register_attribute_builder(\"{0}\")\n", attrDefName); os << llvm::formatv("def _{0}(x, context):\n", attrDefName.lower()); os << llvm::formatv(" return " "_ods_ir.Attribute.parse(f'{0}', context=context)\n\n", formatString); return false; } /// Emits Python bindings for all enums in the record keeper. Returns /// `false` on success, `true` on failure. static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) { os << fileHeader; for (auto &it : recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo")) { EnumAttr enumAttr(*it); emitEnumClass(enumAttr, os); emitAttributeBuilder(enumAttr, os); } for (auto &it : recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttr")) { AttrOrTypeDef attr(&*it); if (!attr.getMnemonic()) { llvm::errs() << "enum case " << attr << " needs mnemonic for python enum bindings generation"; return true; } StringRef mnemonic = attr.getMnemonic().value(); std::optional assemblyFormat = attr.getAssemblyFormat(); StringRef dialect = attr.getDialect().getName(); if (assemblyFormat == "`<` $value `>`") { emitDialectEnumAttributeBuilder( attr.getName(), llvm::formatv("#{0}.{1}<{{str(x)}>", dialect, mnemonic).str(), os); } else if (assemblyFormat == "$value") { emitDialectEnumAttributeBuilder( attr.getName(), llvm::formatv("#{0}<{1} {{str(x)}>", dialect, mnemonic).str(), os); } else { llvm::errs() << "unsupported assembly format for python enum bindings generation"; return true; } } return false; } // Registers the enum utility generator to mlir-tblgen. static mlir::GenRegistration genPythonEnumBindings("gen-python-enum-bindings", "Generate Python bindings for enum attributes", &emitPythonEnums);