//===- OpDefinitionsGen.cpp - MLIR op definitions generator ---------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // OpDefinitionsGen uses the description of operations to generate C++ // definitions for ops. // //===----------------------------------------------------------------------===// #include "OpClass.h" #include "OpFormatGen.h" #include "OpGenHelpers.h" #include "mlir/TableGen/Argument.h" #include "mlir/TableGen/Attribute.h" #include "mlir/TableGen/Class.h" #include "mlir/TableGen/CodeGenHelpers.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Interfaces.h" #include "mlir/TableGen/Operator.h" #include "mlir/TableGen/Property.h" #include "mlir/TableGen/SideEffects.h" #include "mlir/TableGen/Trait.h" #include "llvm/ADT/BitVector.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/Signals.h" #include "llvm/Support/raw_ostream.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" #define DEBUG_TYPE "mlir-tblgen-opdefgen" using namespace llvm; using namespace mlir; using namespace mlir::tblgen; static const char *const tblgenNamePrefix = "tblgen_"; static const char *const generatedArgName = "odsArg"; static const char *const odsBuilder = "odsBuilder"; static const char *const builderOpState = "odsState"; static const char *const propertyStorage = "propStorage"; static const char *const propertyValue = "propValue"; static const char *const propertyAttr = "propAttr"; static const char *const 
propertyDiag = "emitError"; /// The names of the implicit attributes that contain variadic operand and /// result segment sizes. static const char *const operandSegmentAttrName = "operandSegmentSizes"; static const char *const resultSegmentAttrName = "resultSegmentSizes"; /// Code for an Op to lookup an attribute. Uses cached identifiers and subrange /// lookup. /// /// {0}: Code snippet to get the attribute's name or identifier. /// {1}: The lower bound on the sorted subrange. /// {2}: The upper bound on the sorted subrange. /// {3}: Code snippet to get the array of named attributes. /// {4}: "Named" to get the named attribute. static const char *const subrangeGetAttr = "::mlir::impl::get{4}AttrFromSortedRange({3}.begin() + {1}, {3}.end() - " "{2}, {0})"; /// The logic to calculate the actual value range for a declared operand/result /// of an op with variadic operands/results. Note that this logic is not for /// general use; it assumes all variadic operands/results must have the same /// number of values. /// /// {0}: The list of whether each declared operand/result is variadic. /// {1}: The total number of non-variadic operands/results. /// {2}: The total number of variadic operands/results. /// {3}: The total number of actual values. /// {4}: "operand" or "result". static const char *const sameVariadicSizeValueRangeCalcCode = R"( bool isVariadic[] = {{{0}}; int prevVariadicCount = 0; for (unsigned i = 0; i < index; ++i) if (isVariadic[i]) ++prevVariadicCount; // Calculate how many dynamic values a static variadic {4} corresponds to. // This assumes all static variadic {4}s have the same dynamic value count. int variadicSize = ({3} - {1}) / {2}; // `index` passed in as the parameter is the static index which counts each // {4} (variadic or not) as size 1. So here for each previous static variadic // {4}, we need to offset by (variadicSize - 1) to get where the dynamic // value pack for this static {4} starts. 
int start = index + (variadicSize - 1) * prevVariadicCount; int size = isVariadic[index] ? variadicSize : 1; return {{start, size}; )"; /// The logic to calculate the actual value range for a declared operand/result /// of an op with variadic operands/results. Note that this logic is assumes /// the op has an attribute specifying the size of each operand/result segment /// (variadic or not). static const char *const attrSizedSegmentValueRangeCalcCode = R"( unsigned start = 0; for (unsigned i = 0; i < index; ++i) start += sizeAttr[i]; return {start, sizeAttr[index]}; )"; /// The code snippet to initialize the sizes for the value range calculation. /// /// {0}: The code to get the attribute. static const char *const adapterSegmentSizeAttrInitCode = R"( assert({0} && "missing segment size attribute for op"); auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>({0}); )"; static const char *const adapterSegmentSizeAttrInitCodeProperties = R"( ::llvm::ArrayRef sizeAttr = {0}; )"; /// The code snippet to initialize the sizes for the value range calculation. /// /// {0}: The code to get the attribute. static const char *const opSegmentSizeAttrInitCode = R"( auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>({0}); )"; /// The logic to calculate the actual value range for a declared operand /// of an op with variadic of variadic operands within the OpAdaptor. /// /// {0}: The name of the segment attribute. /// {1}: The index of the main operand. /// {2}: The range type of adaptor. static const char *const variadicOfVariadicAdaptorCalcCode = R"( auto tblgenTmpOperands = getODSOperands({1}); auto sizes = {0}(); ::llvm::SmallVector<{2}> tblgenTmpOperandGroups; for (int i = 0, e = sizes.size(); i < e; ++i) {{ tblgenTmpOperandGroups.push_back(tblgenTmpOperands.take_front(sizes[i])); tblgenTmpOperands = tblgenTmpOperands.drop_front(sizes[i]); } return tblgenTmpOperandGroups; )"; /// The logic to build a range of either operand or result values. 
///
/// {0}: The begin iterator of the actual values.
/// {1}: The call to generate the start and length of the value range.
static const char *const valueRangeReturnCode = R"(
  auto valueRange = {1};
  return {{std::next({0}, valueRange.first),
           std::next({0}, valueRange.first + valueRange.second)};
)";

/// Read operand/result segment_size from bytecode (native encoding, used for
/// bytecode version 6 and newer).
static const char *const readBytecodeSegmentSizeNative = R"(
  if ($_reader.getBytecodeVersion() >= /*kNativePropertiesODSSegmentSize=*/6)
    return $_reader.readSparseArray(::llvm::MutableArrayRef($_storage));
)";

/// Read operand/result segment_size from bytecode older than version 6, where
/// it was stored as a DenseI32ArrayAttr.
///
/// The explicit template arguments are required. `static_cast` needs a target
/// type. The attribute must be converted to `ArrayRef<int32_t>` explicitly,
/// because class template argument deduction does not look through
/// `DenseI32ArrayAttr`'s conversion operator.
static const char *const readBytecodeSegmentSizeLegacy = R"(
  if ($_reader.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) {
    auto &$_storage = prop.$_propName;
    ::mlir::DenseI32ArrayAttr attr;
    if (::mlir::failed($_reader.readAttribute(attr))) return ::mlir::failure();
    if (attr.size() > static_cast<int64_t>(sizeof($_storage) / sizeof(int32_t))) {
      $_reader.emitError("size mismatch for operand/result_segment_size");
      return ::mlir::failure();
    }
    ::llvm::copy(::llvm::ArrayRef<int32_t>(attr), $_storage.begin());
  }
)";

/// Write operand/result segment_size to bytecode (native encoding, used for
/// bytecode version 6 and newer).
static const char *const writeBytecodeSegmentSizeNative = R"(
  if ($_writer.getBytecodeVersion() >= /*kNativePropertiesODSSegmentSize=*/6)
    $_writer.writeSparseArray(::llvm::ArrayRef($_storage));
)";

/// Write operand/result segment_size to bytecode older than version 6, as a
/// DenseI32ArrayAttr.
static const char *const writeBytecodeSegmentSizeLegacy = R"(
  if ($_writer.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) {
    auto &$_storage = prop.$_propName;
    $_writer.writeAttribute(::mlir::DenseI32ArrayAttr::get($_ctxt, $_storage));
  }
)";

/// A header for indicating code sections.
///
/// {0}: Some text, or a class name.
/// {1}: Some text.
static const char *const opCommentHeader = R"( //===----------------------------------------------------------------------===// // {0} {1} //===----------------------------------------------------------------------===// )"; //===----------------------------------------------------------------------===// // Utility structs and functions //===----------------------------------------------------------------------===// // Replaces all occurrences of `match` in `str` with `substitute`. static std::string replaceAllSubstrs(std::string str, const std::string &match, const std::string &substitute) { std::string::size_type scanLoc = 0, matchLoc = std::string::npos; while ((matchLoc = str.find(match, scanLoc)) != std::string::npos) { str = str.replace(matchLoc, match.size(), substitute); scanLoc = matchLoc + substitute.size(); } return str; } // Returns whether the record has a value of the given name that can be returned // via getValueAsString. static inline bool hasStringAttribute(const Record &record, StringRef fieldName) { auto *valueInit = record.getValueInit(fieldName); return isa(valueInit); } static std::string getArgumentName(const Operator &op, int index) { const auto &operand = op.getOperand(index); if (!operand.name.empty()) return std::string(operand.name); return std::string(formatv("{0}_{1}", generatedArgName, index)); } // Returns true if we can use unwrapped value for the given `attr` in builders. static bool canUseUnwrappedRawValue(const tblgen::Attribute &attr) { return attr.getReturnType() != attr.getStorageType() && // We need to wrap the raw value into an attribute in the builder impl // so we need to make sure that the attribute specifies how to do that. !attr.getConstBuilderTemplate().empty(); } /// Build an attribute from a parameter value using the constant builder. 
static std::string constBuildAttrFromParam(const tblgen::Attribute &attr, FmtContext &fctx, StringRef paramName) { std::string builderTemplate = attr.getConstBuilderTemplate().str(); // For StringAttr, its constant builder call will wrap the input in // quotes, which is correct for normal string literals, but incorrect // here given we use function arguments. So we need to strip the // wrapping quotes. if (StringRef(builderTemplate).contains("\"$0\"")) builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0"); return tgfmt(builderTemplate, &fctx, paramName).str(); } namespace { /// Metadata on a registered attribute. Given that attributes are stored in /// sorted order on operations, we can use information from ODS to deduce the /// number of required attributes less and and greater than each attribute, /// allowing us to search only a subrange of the attributes in ODS-generated /// getters. struct AttributeMetadata { /// The attribute name. StringRef attrName; /// Whether the attribute is required. bool isRequired; /// The ODS attribute constraint. Not present for implicit attributes. std::optional constraint; /// The number of required attributes less than this attribute. unsigned lowerBound = 0; /// The number of required attributes greater than this attribute. unsigned upperBound = 0; }; /// Helper class to select between OpAdaptor and Op code templates. class OpOrAdaptorHelper { public: OpOrAdaptorHelper(const Operator &op, bool emitForOp) : op(op), emitForOp(emitForOp) { computeAttrMetadata(); } /// Object that wraps a functor in a stream operator for interop with /// llvm::formatv. 
class Formatter { public: template Formatter(Functor &&func) : func(std::forward(func)) {} std::string str() const { std::string result; llvm::raw_string_ostream os(result); os << *this; return os.str(); } private: std::function func; friend raw_ostream &operator<<(raw_ostream &os, const Formatter &fmt) { return fmt.func(os); } }; // Generate code for getting an attribute. Formatter getAttr(StringRef attrName, bool isNamed = false) const { assert(attrMetadata.count(attrName) && "expected attribute metadata"); return [this, attrName, isNamed](raw_ostream &os) -> raw_ostream & { const AttributeMetadata &attr = attrMetadata.find(attrName)->second; if (hasProperties()) { assert(!isNamed); return os << "getProperties()." << attrName; } return os << formatv(subrangeGetAttr, getAttrName(attrName), attr.lowerBound, attr.upperBound, getAttrRange(), isNamed ? "Named" : ""); }; } // Generate code for getting the name of an attribute. Formatter getAttrName(StringRef attrName) const { return [this, attrName](raw_ostream &os) -> raw_ostream & { if (emitForOp) return os << op.getGetterName(attrName) << "AttrName()"; return os << formatv("{0}::{1}AttrName(*odsOpName)", op.getCppClassName(), op.getGetterName(attrName)); }; } // Get the code snippet for getting the named attribute range. StringRef getAttrRange() const { return emitForOp ? "(*this)->getAttrs()" : "odsAttrs"; } // Get the prefix code for emitting an error. Formatter emitErrorPrefix() const { return [this](raw_ostream &os) -> raw_ostream & { if (emitForOp) return os << "emitOpError("; return os << formatv("emitError(loc, \"'{0}' op \"", op.getOperationName()); }; } // Get the call to get an operand or segment of operands. Formatter getOperand(unsigned index) const { return [this, index](raw_ostream &os) -> raw_ostream & { return os << formatv(op.getOperand(index).isVariadic() ? "this->getODSOperands({0})" : "(*this->getODSOperands({0}).begin())", index); }; } // Get the call to get a result of segment of results. 
Formatter getResult(unsigned index) const { return [this, index](raw_ostream &os) -> raw_ostream & { if (!emitForOp) return os << ""; return os << formatv(op.getResult(index).isVariadic() ? "this->getODSResults({0})" : "(*this->getODSResults({0}).begin())", index); }; } // Return whether an op instance is available. bool isEmittingForOp() const { return emitForOp; } // Return the ODS operation wrapper. const Operator &getOp() const { return op; } // Get the attribute metadata sorted by name. const llvm::MapVector &getAttrMetadata() const { return attrMetadata; } /// Returns whether to emit a `Properties` struct for this operation or not. bool hasProperties() const { if (!op.getProperties().empty()) return true; if (!op.getDialect().usePropertiesForAttributes()) return false; if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments") || op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) return true; return llvm::any_of(getAttrMetadata(), [](const std::pair &it) { return !it.second.constraint || !it.second.constraint->isDerivedAttr(); }); } std::optional &getOperandSegmentsSize() { return operandSegmentsSize; } std::optional &getResultSegmentsSize() { return resultSegmentsSize; } uint32_t getOperandSegmentSizesLegacyIndex() { return operandSegmentSizesLegacyIndex; } uint32_t getResultSegmentSizesLegacyIndex() { return resultSegmentSizesLegacyIndex; } private: // Compute the attribute metadata. void computeAttrMetadata(); // The operation ODS wrapper. const Operator &op; // True if code is being generate for an op. False for an adaptor. const bool emitForOp; // The attribute metadata, mapped by name. 
llvm::MapVector attrMetadata; // Property std::optional operandSegmentsSize; std::string operandSegmentsSizeStorage; std::optional resultSegmentsSize; std::string resultSegmentsSizeStorage; // Indices to store the position in the emission order of the operand/result // segment sizes attribute if emitted as part of the properties for legacy // bytecode encodings, i.e. versions less than 6. uint32_t operandSegmentSizesLegacyIndex = 0; uint32_t resultSegmentSizesLegacyIndex = 0; // The number of required attributes. unsigned numRequired; }; } // namespace void OpOrAdaptorHelper::computeAttrMetadata() { // Enumerate the attribute names of this op, ensuring the attribute names are // unique in case implicit attributes are explicitly registered. for (const NamedAttribute &namedAttr : op.getAttributes()) { Attribute attr = namedAttr.attr; bool isOptional = attr.hasDefaultValue() || attr.isOptional() || attr.isDerivedAttr(); attrMetadata.insert( {namedAttr.name, AttributeMetadata{namedAttr.name, !isOptional, attr}}); } auto makeProperty = [&](StringRef storageType) { return Property( /*storageType=*/storageType, /*interfaceType=*/"::llvm::ArrayRef", /*convertFromStorageCall=*/"$_storage", /*assignToStorageCall=*/ "::llvm::copy($_value, $_storage.begin())", /*convertToAttributeCall=*/ "::mlir::DenseI32ArrayAttr::get($_ctxt, $_storage)", /*convertFromAttributeCall=*/ "return convertFromAttribute($_storage, $_attr, $_diag);", /*readFromMlirBytecodeCall=*/readBytecodeSegmentSizeNative, /*writeToMlirBytecodeCall=*/writeBytecodeSegmentSizeNative, /*hashPropertyCall=*/ "::llvm::hash_combine_range(std::begin($_storage), " "std::end($_storage));", /*StringRef defaultValue=*/""); }; // Include key attributes from several traits as implicitly registered. 
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { if (op.getDialect().usePropertiesForAttributes()) { operandSegmentsSizeStorage = llvm::formatv("std::array", op.getNumOperands()); operandSegmentsSize = {"operandSegmentSizes", makeProperty(operandSegmentsSizeStorage)}; } else { attrMetadata.insert( {operandSegmentAttrName, AttributeMetadata{operandSegmentAttrName, /*isRequired=*/true, /*attr=*/std::nullopt}}); } } if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { if (op.getDialect().usePropertiesForAttributes()) { resultSegmentsSizeStorage = llvm::formatv("std::array", op.getNumResults()); resultSegmentsSize = {"resultSegmentSizes", makeProperty(resultSegmentsSizeStorage)}; } else { attrMetadata.insert( {resultSegmentAttrName, AttributeMetadata{resultSegmentAttrName, /*isRequired=*/true, /*attr=*/std::nullopt}}); } } // Store the metadata in sorted order. SmallVector sortedAttrMetadata = llvm::to_vector(llvm::make_second_range(attrMetadata.takeVector())); llvm::sort(sortedAttrMetadata, [](const AttributeMetadata &lhs, const AttributeMetadata &rhs) { return lhs.attrName < rhs.attrName; }); // Store the position of the legacy operand_segment_sizes / // result_segment_sizes so we can emit a backward compatible property readers // and writers. StringRef legacyOperandSegmentSizeName = StringLiteral("operand_segment_sizes"); StringRef legacyResultSegmentSizeName = StringLiteral("result_segment_sizes"); operandSegmentSizesLegacyIndex = 0; resultSegmentSizesLegacyIndex = 0; for (auto item : sortedAttrMetadata) { if (item.attrName < legacyOperandSegmentSizeName) ++operandSegmentSizesLegacyIndex; if (item.attrName < legacyResultSegmentSizeName) ++resultSegmentSizesLegacyIndex; } // Compute the subrange bounds for each attribute. 
numRequired = 0; for (AttributeMetadata &attr : sortedAttrMetadata) { attr.lowerBound = numRequired; numRequired += attr.isRequired; }; for (AttributeMetadata &attr : sortedAttrMetadata) attr.upperBound = numRequired - attr.lowerBound - attr.isRequired; // Store the results back into the map. for (const AttributeMetadata &attr : sortedAttrMetadata) attrMetadata.insert({attr.attrName, attr}); } //===----------------------------------------------------------------------===// // Op emitter //===----------------------------------------------------------------------===// namespace { // Helper class to emit a record into the given output stream. class OpEmitter { using ConstArgument = llvm::PointerUnion; public: static void emitDecl(const Operator &op, raw_ostream &os, const StaticVerifierFunctionEmitter &staticVerifierEmitter); static void emitDef(const Operator &op, raw_ostream &os, const StaticVerifierFunctionEmitter &staticVerifierEmitter); private: OpEmitter(const Operator &op, const StaticVerifierFunctionEmitter &staticVerifierEmitter); void emitDecl(raw_ostream &os); void emitDef(raw_ostream &os); // Generate methods for accessing the attribute names of this operation. void genAttrNameGetters(); // Generates the OpAsmOpInterface for this operation if possible. void genOpAsmInterface(); // Generates the `getOperationName` method for this op. void genOpNameGetter(); // Generates code to manage the properties, if any! void genPropertiesSupport(); // Generates code to manage the encoding of properties to bytecode. void genPropertiesSupportForBytecode(ArrayRef attrOrProperties); // Generates getters for the attributes. void genAttrGetters(); // Generates setter for the attributes. void genAttrSetters(); // Generates removers for optional attributes. void genOptionalAttrRemovers(); // Generates getters for named operands. void genNamedOperandGetters(); // Generates setters for named operands. void genNamedOperandSetters(); // Generates getters for named results. 
void genNamedResultGetters(); // Generates getters for named regions. void genNamedRegionGetters(); // Generates getters for named successors. void genNamedSuccessorGetters(); // Generates the method to populate default attributes. void genPopulateDefaultAttributes(); // Generates builder methods for the operation. void genBuilder(); // Generates the build() method that takes each operand/attribute // as a stand-alone parameter. void genSeparateArgParamBuilder(); // Generates the build() method that takes each operand/attribute as a // stand-alone parameter. The generated build() method uses first operand's // type as all results' types. void genUseOperandAsResultTypeSeparateParamBuilder(); // Generates the build() method that takes all operands/attributes // collectively as one parameter. The generated build() method uses first // operand's type as all results' types. void genUseOperandAsResultTypeCollectiveParamBuilder(); // Generates the build() method that takes aggregate operands/attributes // parameters. This build() method uses inferred types as result types. // Requires: The type needs to be inferable via InferTypeOpInterface. void genInferredTypeCollectiveParamBuilder(); // Generates the build() method that takes each operand/attribute as a // stand-alone parameter. The generated build() method uses first attribute's // type as all result's types. void genUseAttrAsResultTypeBuilder(); // Generates the build() method that takes all result types collectively as // one parameter. Similarly for operands and attributes. void genCollectiveParamBuilder(); // The kind of parameter to generate for result types in builders. enum class TypeParamKind { None, // No result type in parameter list. Separate, // A separate parameter for each result type. Collective, // An ArrayRef for all result types. }; // The kind of parameter to generate for attributes in builders. enum class AttrParamKind { WrappedAttr, // A wrapped MLIR Attribute instance. 
UnwrappedValue, // A raw value without MLIR Attribute wrapper. }; // Builds the parameter list for build() method of this op. This method writes // to `paramList` the comma-separated parameter list and updates // `resultTypeNames` with the names for parameters for specifying result // types. `inferredAttributes` is populated with any attributes that are // elided from the build list. The given `typeParamKind` and `attrParamKind` // controls how result types and attributes are placed in the parameter list. void buildParamList(SmallVectorImpl ¶mList, llvm::StringSet<> &inferredAttributes, SmallVectorImpl &resultTypeNames, TypeParamKind typeParamKind, AttrParamKind attrParamKind = AttrParamKind::WrappedAttr); // Adds op arguments and regions into operation state for build() methods. void genCodeForAddingArgAndRegionForBuilder(MethodBody &body, llvm::StringSet<> &inferredAttributes, bool isRawValueAttr = false); // Generates canonicalizer declaration for the operation. void genCanonicalizerDecls(); // Generates the folder declaration for the operation. void genFolderDecls(); // Generates the parser for the operation. void genParser(); // Generates the printer for the operation. void genPrinter(); // Generates verify method for the operation. void genVerifier(); // Generates custom verify methods for the operation. void genCustomVerifier(); // Generates verify statements for operands and results in the operation. // The generated code will be attached to `body`. void genOperandResultVerifier(MethodBody &body, Operator::const_value_range values, StringRef valueKind); // Generates verify statements for regions in the operation. // The generated code will be attached to `body`. void genRegionVerifier(MethodBody &body); // Generates verify statements for successors in the operation. // The generated code will be attached to `body`. void genSuccessorVerifier(MethodBody &body); // Generates the traits used by the object. 
void genTraits(); // Generate the OpInterface methods for all interfaces. void genOpInterfaceMethods(); // Generate op interface methods for the given interface. void genOpInterfaceMethods(const tblgen::InterfaceTrait *trait); // Generate op interface method for the given interface method. If // 'declaration' is true, generates a declaration, else a definition. Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method, bool declaration = true); // Generate the side effect interface methods. void genSideEffectInterfaceMethods(); // Generate the type inference interface methods. void genTypeInterfaceMethods(); private: // The TableGen record for this op. // TODO: OpEmitter should not have a Record directly, // it should rather go through the Operator for better abstraction. const Record &def; // The wrapper operator class for querying information from this op. const Operator &op; // The C++ code builder for this op OpClass opClass; // The format context for verification code generation. FmtContext verifyCtx; // The emitter containing all of the locally emitted verification functions. const StaticVerifierFunctionEmitter &staticVerifierEmitter; // Helper for emitting op code. OpOrAdaptorHelper emitHelper; }; } // namespace // Populate the format context `ctx` with substitutions of attributes, operands // and results. static void populateSubstitutions(const OpOrAdaptorHelper &emitHelper, FmtContext &ctx) { // Populate substitutions for attributes. auto &op = emitHelper.getOp(); for (const auto &namedAttr : op.getAttributes()) ctx.addSubst(namedAttr.name, emitHelper.getOp().getGetterName(namedAttr.name) + "()"); // Populate substitutions for named operands. for (int i = 0, e = op.getNumOperands(); i < e; ++i) { auto &value = op.getOperand(i); if (!value.name.empty()) ctx.addSubst(value.name, emitHelper.getOperand(i).str()); } // Populate substitutions for results. 
for (int i = 0, e = op.getNumResults(); i < e; ++i) { auto &value = op.getResult(i); if (!value.name.empty()) ctx.addSubst(value.name, emitHelper.getResult(i).str()); } } /// Generate verification on native traits requiring attributes. static void genNativeTraitAttrVerifier(MethodBody &body, const OpOrAdaptorHelper &emitHelper) { // Check that the variadic segment sizes attribute exists and contains the // expected number of elements. // // {0}: Attribute name. // {1}: Expected number of elements. // {2}: "operand" or "result". // {3}: Emit error prefix. const char *const checkAttrSizedValueSegmentsCode = R"( { auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>(tblgen_{0}); auto numElements = sizeAttr.asArrayRef().size(); if (numElements != {1}) return {3}"'{0}' attribute for specifying {2} segments must have {1} " "elements, but got ") << numElements; } )"; // Verify a few traits first so that we can use getODSOperands() and // getODSResults() in the rest of the verifier. auto &op = emitHelper.getOp(); if (!op.getDialect().usePropertiesForAttributes()) { if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { body << formatv(checkAttrSizedValueSegmentsCode, operandSegmentAttrName, op.getNumOperands(), "operand", emitHelper.emitErrorPrefix()); } if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { body << formatv(checkAttrSizedValueSegmentsCode, resultSegmentAttrName, op.getNumResults(), "result", emitHelper.emitErrorPrefix()); } } } // Return true if a verifier can be emitted for the attribute: it is not a // derived attribute, it has a predicate, its condition is not empty, and, for // adaptors, the condition does not reference the op. 
static bool canEmitAttrVerifier(Attribute attr, bool isEmittingForOp) { if (attr.isDerivedAttr()) return false; Pred pred = attr.getPredicate(); if (pred.isNull()) return false; std::string condition = pred.getCondition(); return !condition.empty() && (!StringRef(condition).contains("$_op") || isEmittingForOp); } // Generate attribute verification. If an op instance is not available, then // attribute checks that require one will not be emitted. // // Attribute verification is performed as follows: // // 1. Verify that all required attributes are present in sorted order. This // ensures that we can use subrange lookup even with potentially missing // attributes. // 2. Verify native trait attributes so that other attributes may call methods // that depend on the validity of these attributes, e.g. segment size attributes // and operand or result getters. // 3. Verify the constraints on all present attributes. static void genAttributeVerifier(const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, MethodBody &body, const StaticVerifierFunctionEmitter &staticVerifierEmitter, bool useProperties) { if (emitHelper.getAttrMetadata().empty()) return; // Verify the attribute if it is present. This assumes that default values // are valid. This code snippet pastes the condition inline. // // TODO: verify the default value is valid (perhaps in debug mode only). // // {0}: Attribute variable name. // {1}: Attribute condition code. // {2}: Emit error prefix. // {3}: Attribute name. // {4}: Attribute/constraint description. const char *const verifyAttrInline = R"( if ({0} && !({1})) return {2}"attribute '{3}' failed to satisfy constraint: {4}"); )"; // Verify the attribute using a uniqued constraint. Can only be used within // the context of an op. // // {0}: Unique constraint name. // {1}: Attribute variable name. // {2}: Attribute name. 
const char *const verifyAttrUnique = R"( if (::mlir::failed({0}(*this, {1}, "{2}"))) return ::mlir::failure(); )"; // Traverse the array until the required attribute is found. Return an error // if the traversal reached the end. // // {0}: Code to get the name of the attribute. // {1}: The emit error prefix. // {2}: The name of the attribute. const char *const findRequiredAttr = R"( while (true) {{ if (namedAttrIt == namedAttrRange.end()) return {1}"requires attribute '{2}'"); if (namedAttrIt->getName() == {0}) {{ tblgen_{2} = namedAttrIt->getValue(); break; })"; // Emit a check to see if the iteration has encountered an optional attribute. // // {0}: Code to get the name of the attribute. // {1}: The name of the attribute. const char *const checkOptionalAttr = R"( else if (namedAttrIt->getName() == {0}) {{ tblgen_{1} = namedAttrIt->getValue(); })"; // Emit the start of the loop for checking trailing attributes. const char *const checkTrailingAttrs = R"(while (true) { if (namedAttrIt == namedAttrRange.end()) { break; })"; // Emit the verifier for the attribute. const auto emitVerifier = [&](Attribute attr, StringRef attrName, StringRef varName) { std::string condition = attr.getPredicate().getCondition(); std::optional constraintFn; if (emitHelper.isEmittingForOp() && (constraintFn = staticVerifierEmitter.getAttrConstraintFn(attr))) { body << formatv(verifyAttrUnique, *constraintFn, varName, attrName); } else { body << formatv(verifyAttrInline, varName, tgfmt(condition, &ctx.withSelf(varName)), emitHelper.emitErrorPrefix(), attrName, escapeString(attr.getSummary())); } }; // Prefix variables with `tblgen_` to avoid hiding the attribute accessor. 
const auto getVarName = [&](StringRef attrName) { return (tblgenNamePrefix + attrName).str(); }; body.indent(); if (useProperties) { for (const std::pair &it : emitHelper.getAttrMetadata()) { const AttributeMetadata &metadata = it.second; if (metadata.constraint && metadata.constraint->isDerivedAttr()) continue; body << formatv( "auto tblgen_{0} = getProperties().{0}; (void)tblgen_{0};\n", it.first); if (metadata.isRequired) body << formatv( "if (!tblgen_{0}) return {1}\"requires attribute '{0}'\");\n", it.first, emitHelper.emitErrorPrefix()); } } else { body << formatv("auto namedAttrRange = {0};\n", emitHelper.getAttrRange()); body << "auto namedAttrIt = namedAttrRange.begin();\n"; // Iterate over the attributes in sorted order. Keep track of the optional // attributes that may be encountered along the way. SmallVector optionalAttrs; for (const std::pair &it : emitHelper.getAttrMetadata()) { const AttributeMetadata &metadata = it.second; if (!metadata.isRequired) { optionalAttrs.push_back(&metadata); continue; } body << formatv("::mlir::Attribute {0};\n", getVarName(it.first)); for (const AttributeMetadata *optional : optionalAttrs) { body << formatv("::mlir::Attribute {0};\n", getVarName(optional->attrName)); } body << formatv(findRequiredAttr, emitHelper.getAttrName(it.first), emitHelper.emitErrorPrefix(), it.first); for (const AttributeMetadata *optional : optionalAttrs) { body << formatv(checkOptionalAttr, emitHelper.getAttrName(optional->attrName), optional->attrName); } body << "\n ++namedAttrIt;\n}\n"; optionalAttrs.clear(); } // Get trailing optional attributes. 
if (!optionalAttrs.empty()) { for (const AttributeMetadata *optional : optionalAttrs) { body << formatv("::mlir::Attribute {0};\n", getVarName(optional->attrName)); } body << checkTrailingAttrs; for (const AttributeMetadata *optional : optionalAttrs) { body << formatv(checkOptionalAttr, emitHelper.getAttrName(optional->attrName), optional->attrName); } body << "\n ++namedAttrIt;\n}\n"; } } body.unindent(); // Emit the checks for segment attributes first so that the other // constraints can call operand and result getters. genNativeTraitAttrVerifier(body, emitHelper); bool isEmittingForOp = emitHelper.isEmittingForOp(); for (const auto &namedAttr : emitHelper.getOp().getAttributes()) if (canEmitAttrVerifier(namedAttr.attr, isEmittingForOp)) emitVerifier(namedAttr.attr, namedAttr.name, getVarName(namedAttr.name)); } /// Include declarations specified on NativeTrait static std::string formatExtraDeclarations(const Operator &op) { SmallVector extraDeclarations; // Include extra class declarations from NativeTrait for (const auto &trait : op.getTraits()) { if (auto *opTrait = dyn_cast(&trait)) { StringRef value = opTrait->getExtraConcreteClassDeclaration(); if (value.empty()) continue; extraDeclarations.push_back(value); } } extraDeclarations.push_back(op.getExtraClassDeclaration()); return llvm::join(extraDeclarations, "\n"); } /// Op extra class definitions have a `$cppClass` substitution that is to be /// replaced by the C++ class name. 
/// Include declarations specified on NativeTrait static std::string formatExtraDefinitions(const Operator &op) { SmallVector extraDefinitions; // Include extra class definitions from NativeTrait for (const auto &trait : op.getTraits()) { if (auto *opTrait = dyn_cast(&trait)) { StringRef value = opTrait->getExtraConcreteClassDefinition(); if (value.empty()) continue; extraDefinitions.push_back(value); } } extraDefinitions.push_back(op.getExtraClassDefinition()); FmtContext ctx = FmtContext().addSubst("cppClass", op.getCppClassName()); return tgfmt(llvm::join(extraDefinitions, "\n"), &ctx).str(); } OpEmitter::OpEmitter(const Operator &op, const StaticVerifierFunctionEmitter &staticVerifierEmitter) : def(op.getDef()), op(op), opClass(op.getCppClassName(), formatExtraDeclarations(op), formatExtraDefinitions(op)), staticVerifierEmitter(staticVerifierEmitter), emitHelper(op, /*emitForOp=*/true) { verifyCtx.addSubst("_op", "(*this->getOperation())"); verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()"); genTraits(); // Generate C++ code for various op methods. The order here determines the // methods in the generated file. 
genAttrNameGetters(); genOpAsmInterface(); genOpNameGetter(); genNamedOperandGetters(); genNamedOperandSetters(); genNamedResultGetters(); genNamedRegionGetters(); genNamedSuccessorGetters(); genPropertiesSupport(); genAttrGetters(); genAttrSetters(); genOptionalAttrRemovers(); genBuilder(); genPopulateDefaultAttributes(); genParser(); genPrinter(); genVerifier(); genCustomVerifier(); genCanonicalizerDecls(); genFolderDecls(); genTypeInterfaceMethods(); genOpInterfaceMethods(); generateOpFormat(op, opClass); genSideEffectInterfaceMethods(); } void OpEmitter::emitDecl( const Operator &op, raw_ostream &os, const StaticVerifierFunctionEmitter &staticVerifierEmitter) { OpEmitter(op, staticVerifierEmitter).emitDecl(os); } void OpEmitter::emitDef( const Operator &op, raw_ostream &os, const StaticVerifierFunctionEmitter &staticVerifierEmitter) { OpEmitter(op, staticVerifierEmitter).emitDef(os); } void OpEmitter::emitDecl(raw_ostream &os) { opClass.finalize(); opClass.writeDeclTo(os); } void OpEmitter::emitDef(raw_ostream &os) { opClass.finalize(); opClass.writeDefTo(os); } static void errorIfPruned(size_t line, Method *m, const Twine &methodName, const Operator &op) { if (m) return; PrintFatalError(op.getLoc(), "Unexpected overlap when generating `" + methodName + "` for " + op.getOperationName() + " (from line " + Twine(line) + ")"); } #define ERROR_IF_PRUNED(M, N, O) errorIfPruned(__LINE__, M, N, O) void OpEmitter::genAttrNameGetters() { const llvm::MapVector &attributes = emitHelper.getAttrMetadata(); bool hasOperandSegmentsSize = op.getDialect().usePropertiesForAttributes() && op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"); // Emit the getAttributeNames method. 
{ auto *method = opClass.addStaticInlineMethod( "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames"); ERROR_IF_PRUNED(method, "getAttributeNames", op); auto &body = method->body(); if (!hasOperandSegmentsSize && attributes.empty()) { body << " return {};"; // Nothing else to do if there are no registered attributes. Exit early. return; } body << " static ::llvm::StringRef attrNames[] = {"; llvm::interleaveComma(llvm::make_first_range(attributes), body, [&](StringRef attrName) { body << "::llvm::StringRef(\"" << attrName << "\")"; }); if (hasOperandSegmentsSize) { if (!attributes.empty()) body << ", "; body << "::llvm::StringRef(\"" << operandSegmentAttrName << "\")"; } body << "};\n return ::llvm::ArrayRef(attrNames);"; } // Emit the getAttributeNameForIndex methods. { auto *method = opClass.addInlineMethod( "::mlir::StringAttr", "getAttributeNameForIndex", MethodParameter("unsigned", "index")); ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op); method->body() << " return getAttributeNameForIndex((*this)->getName(), index);"; } { auto *method = opClass.addStaticInlineMethod( "::mlir::StringAttr", "getAttributeNameForIndex", MethodParameter("::mlir::OperationName", "name"), MethodParameter("unsigned", "index")); ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op); if (attributes.empty()) { method->body() << " return {};"; } else { const char *const getAttrName = R"( assert(index < {0} && "invalid attribute index"); assert(name.getStringRef() == getOperationName() && "invalid operation name"); assert(name.isRegistered() && "Operation isn't registered, missing a " "dependent dialect loading?"); return name.getAttributeNames()[index]; )"; method->body() << formatv(getAttrName, attributes.size()); } } // Generate the AttrName methods, that expose the attribute names to // users. 
const char *attrNameMethodBody = " return getAttributeNameForIndex({0});"; for (auto [index, attr] : llvm::enumerate(llvm::make_first_range(attributes))) { std::string name = op.getGetterName(attr); std::string methodName = name + "AttrName"; // Generate the non-static variant. { auto *method = opClass.addInlineMethod("::mlir::StringAttr", methodName); ERROR_IF_PRUNED(method, methodName, op); method->body() << llvm::formatv(attrNameMethodBody, index); } // Generate the static variant. { auto *method = opClass.addStaticInlineMethod( "::mlir::StringAttr", methodName, MethodParameter("::mlir::OperationName", "name")); ERROR_IF_PRUNED(method, methodName, op); method->body() << llvm::formatv(attrNameMethodBody, "name, " + Twine(index)); } } if (hasOperandSegmentsSize) { std::string name = op.getGetterName(operandSegmentAttrName); std::string methodName = name + "AttrName"; // Generate the non-static variant. { auto *method = opClass.addInlineMethod("::mlir::StringAttr", methodName); ERROR_IF_PRUNED(method, methodName, op); method->body() << " return (*this)->getName().getAttributeNames().back();"; } // Generate the static variant. { auto *method = opClass.addStaticInlineMethod( "::mlir::StringAttr", methodName, MethodParameter("::mlir::OperationName", "name")); ERROR_IF_PRUNED(method, methodName, op); method->body() << " return name.getAttributeNames().back();"; } } } // Emit the getter for an attribute with the return type specified. // It is templated to be shared between the Op and the adaptor class. template static void emitAttrGetterWithReturnType(FmtContext &fctx, OpClassOrAdaptor &opClass, const Operator &op, StringRef name, Attribute attr) { auto *method = opClass.addMethod(attr.getReturnType(), name); ERROR_IF_PRUNED(method, name, op); auto &body = method->body(); body << " auto attr = " << name << "Attr();\n"; if (attr.hasDefaultValue() && attr.isOptional()) { // Returns the default value if not set. 
// TODO: this is inefficient, we are recreating the attribute for every // call. This should be set instead. if (!attr.isConstBuildable()) { PrintFatalError("DefaultValuedAttr of type " + attr.getAttrDefName() + " must have a constBuilder"); } std::string defaultValue = std::string( tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); body << " if (!attr)\n return " << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf(defaultValue)) << ";\n"; } body << " return " << tgfmt(attr.getConvertFromStorageCall(), &fctx.withSelf("attr")) << ";\n"; } void OpEmitter::genPropertiesSupport() { if (!emitHelper.hasProperties()) return; SmallVector attrOrProperties; for (const std::pair &it : emitHelper.getAttrMetadata()) { if (!it.second.constraint || !it.second.constraint->isDerivedAttr()) attrOrProperties.push_back(&it.second); } for (const NamedProperty &prop : op.getProperties()) attrOrProperties.push_back(&prop); if (emitHelper.getOperandSegmentsSize()) attrOrProperties.push_back(&emitHelper.getOperandSegmentsSize().value()); if (emitHelper.getResultSegmentsSize()) attrOrProperties.push_back(&emitHelper.getResultSegmentsSize().value()); if (attrOrProperties.empty()) return; auto &setPropMethod = opClass .addStaticMethod( "::mlir::LogicalResult", "setPropertiesFromAttr", MethodParameter("Properties &", "prop"), MethodParameter("::mlir::Attribute", "attr"), MethodParameter( "::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError")) ->body(); auto &getPropMethod = opClass .addStaticMethod("::mlir::Attribute", "getPropertiesAsAttr", MethodParameter("::mlir::MLIRContext *", "ctx"), MethodParameter("const Properties &", "prop")) ->body(); auto &hashMethod = opClass .addStaticMethod("llvm::hash_code", "computePropertiesHash", MethodParameter("const Properties &", "prop")) ->body(); auto &getInherentAttrMethod = opClass .addStaticMethod("std::optional", "getInherentAttr", MethodParameter("::mlir::MLIRContext *", "ctx"), MethodParameter("const 
Properties &", "prop"), MethodParameter("llvm::StringRef", "name")) ->body(); auto &setInherentAttrMethod = opClass .addStaticMethod("void", "setInherentAttr", MethodParameter("Properties &", "prop"), MethodParameter("llvm::StringRef", "name"), MethodParameter("mlir::Attribute", "value")) ->body(); auto &populateInherentAttrsMethod = opClass .addStaticMethod("void", "populateInherentAttrs", MethodParameter("::mlir::MLIRContext *", "ctx"), MethodParameter("const Properties &", "prop"), MethodParameter("::mlir::NamedAttrList &", "attrs")) ->body(); auto &verifyInherentAttrsMethod = opClass .addStaticMethod( "::mlir::LogicalResult", "verifyInherentAttrs", MethodParameter("::mlir::OperationName", "opName"), MethodParameter("::mlir::NamedAttrList &", "attrs"), MethodParameter( "llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError")) ->body(); opClass.declare("Properties", "FoldAdaptor::Properties"); // Convert the property to the attribute form. setPropMethod << R"decl( ::mlir::DictionaryAttr dict = ::llvm::dyn_cast<::mlir::DictionaryAttr>(attr); if (!dict) { emitError() << "expected DictionaryAttr to set properties"; return ::mlir::failure(); } )decl"; // TODO: properties might be optional as well. 
const char *propFromAttrFmt = R"decl(; {{ auto setFromAttr = [] (auto &propStorage, ::mlir::Attribute propAttr, ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError) {{ {0}; }; {2}; if (!attr) {{ emitError() << "expected key entry for {1} in DictionaryAttr to set " "Properties."; return ::mlir::failure(); } if (::mlir::failed(setFromAttr(prop.{1}, attr, emitError))) return ::mlir::failure(); } )decl"; for (const auto &attrOrProp : attrOrProperties) { if (const auto *namedProperty = llvm::dyn_cast_if_present(attrOrProp)) { StringRef name = namedProperty->name; auto &prop = namedProperty->prop; FmtContext fctx; std::string getAttr; llvm::raw_string_ostream os(getAttr); os << " auto attr = dict.get(\"" << name << "\");"; if (name == operandSegmentAttrName) { // Backward compat for now, TODO: Remove at some point. os << " if (!attr) attr = dict.get(\"operand_segment_sizes\");"; } if (name == resultSegmentAttrName) { // Backward compat for now, TODO: Remove at some point. os << " if (!attr) attr = dict.get(\"result_segment_sizes\");"; } os.flush(); setPropMethod << formatv(propFromAttrFmt, tgfmt(prop.getConvertFromAttributeCall(), &fctx.addSubst("_attr", propertyAttr) .addSubst("_storage", propertyStorage) .addSubst("_diag", propertyDiag)), name, getAttr); } else { const auto *namedAttr = llvm::dyn_cast_if_present(attrOrProp); StringRef name = namedAttr->attrName; std::string getAttr; llvm::raw_string_ostream os(getAttr); os << " auto attr = dict.get(\"" << name << "\");"; if (name == operandSegmentAttrName) { // Backward compat for now os << " if (!attr) attr = dict.get(\"operand_segment_sizes\");"; } if (name == resultSegmentAttrName) { // Backward compat for now os << " if (!attr) attr = dict.get(\"result_segment_sizes\");"; } os.flush(); setPropMethod << formatv(R"decl( {{ auto &propStorage = prop.{0}; {2} if (attr || /*isRequired=*/{1}) {{ if (!attr) {{ emitError() << "expected key entry for {0} in DictionaryAttr to set " "Properties."; return 
::mlir::failure(); } auto convertedAttr = ::llvm::dyn_cast>(attr); if (convertedAttr) {{ propStorage = convertedAttr; } else {{ emitError() << "Invalid attribute `{0}` in property conversion: " << attr; return ::mlir::failure(); } } } )decl", name, namedAttr->isRequired, getAttr); } } setPropMethod << " return ::mlir::success();\n"; // Convert the attribute form to the property. getPropMethod << " ::mlir::SmallVector<::mlir::NamedAttribute> attrs;\n" << " ::mlir::Builder odsBuilder{ctx};\n"; const char *propToAttrFmt = R"decl( { const auto &propStorage = prop.{0}; attrs.push_back(odsBuilder.getNamedAttr("{0}", {1})); } )decl"; for (const auto &attrOrProp : attrOrProperties) { if (const auto *namedProperty = llvm::dyn_cast_if_present(attrOrProp)) { StringRef name = namedProperty->name; auto &prop = namedProperty->prop; FmtContext fctx; getPropMethod << formatv( propToAttrFmt, name, tgfmt(prop.getConvertToAttributeCall(), &fctx.addSubst("_ctxt", "ctx") .addSubst("_storage", propertyStorage))); continue; } const auto *namedAttr = llvm::dyn_cast_if_present(attrOrProp); StringRef name = namedAttr->attrName; getPropMethod << formatv(R"decl( {{ const auto &propStorage = prop.{0}; if (propStorage) attrs.push_back(odsBuilder.getNamedAttr("{0}", propStorage)); } )decl", name); } getPropMethod << R"decl( if (!attrs.empty()) return odsBuilder.getDictionaryAttr(attrs); return {}; )decl"; // Hashing for the property const char *propHashFmt = R"decl( auto hash_{0} = [] (const auto &propStorage) -> llvm::hash_code { return {1}; }; )decl"; for (const auto &attrOrProp : attrOrProperties) { if (const auto *namedProperty = llvm::dyn_cast_if_present(attrOrProp)) { StringRef name = namedProperty->name; auto &prop = namedProperty->prop; FmtContext fctx; hashMethod << formatv(propHashFmt, name, tgfmt(prop.getHashPropertyCall(), &fctx.addSubst("_storage", propertyStorage))); } } hashMethod << " return llvm::hash_combine("; llvm::interleaveComma( attrOrProperties, hashMethod, [&](const 
ConstArgument &attrOrProp) { if (const auto *namedProperty = llvm::dyn_cast_if_present(attrOrProp)) { hashMethod << "\n hash_" << namedProperty->name << "(prop." << namedProperty->name << ")"; return; } const auto *namedAttr = llvm::dyn_cast_if_present(attrOrProp); StringRef name = namedAttr->attrName; hashMethod << "\n llvm::hash_value(prop." << name << ".getAsOpaquePointer())"; }); hashMethod << ");\n"; const char *getInherentAttrMethodFmt = R"decl( if (name == "{0}") return prop.{0}; )decl"; const char *setInherentAttrMethodFmt = R"decl( if (name == "{0}") {{ prop.{0} = ::llvm::dyn_cast_or_null>(value); return; } )decl"; const char *populateInherentAttrsMethodFmt = R"decl( if (prop.{0}) attrs.append("{0}", prop.{0}); )decl"; for (const auto &attrOrProp : attrOrProperties) { if (const auto *namedAttr = llvm::dyn_cast_if_present(attrOrProp)) { StringRef name = namedAttr->attrName; getInherentAttrMethod << formatv(getInherentAttrMethodFmt, name); setInherentAttrMethod << formatv(setInherentAttrMethodFmt, name); populateInherentAttrsMethod << formatv(populateInherentAttrsMethodFmt, name); continue; } // The ODS segment size property is "special": we expose it as an attribute // even though it is a native property. 
const auto *namedProperty = cast(attrOrProp); StringRef name = namedProperty->name; if (name != operandSegmentAttrName && name != resultSegmentAttrName) continue; auto &prop = namedProperty->prop; FmtContext fctx; fctx.addSubst("_ctxt", "ctx"); fctx.addSubst("_storage", Twine("prop.") + name); if (name == operandSegmentAttrName) { getInherentAttrMethod << formatv(" if (name == \"operand_segment_sizes\" || name == " "\"{0}\") return ", operandSegmentAttrName); } else { getInherentAttrMethod << formatv(" if (name == \"result_segment_sizes\" || name == " "\"{0}\") return ", resultSegmentAttrName); } getInherentAttrMethod << tgfmt(prop.getConvertToAttributeCall(), &fctx) << ";\n"; if (name == operandSegmentAttrName) { setInherentAttrMethod << formatv(" if (name == \"operand_segment_sizes\" || name == " "\"{0}\") {{", operandSegmentAttrName); } else { setInherentAttrMethod << formatv(" if (name == \"result_segment_sizes\" || name == " "\"{0}\") {{", resultSegmentAttrName); } setInherentAttrMethod << formatv(R"decl( auto arrAttr = ::llvm::dyn_cast_or_null<::mlir::DenseI32ArrayAttr>(value); if (!arrAttr) return; if (arrAttr.size() != sizeof(prop.{0}) / sizeof(int32_t)) return; llvm::copy(arrAttr.asArrayRef(), prop.{0}.begin()); return; } )decl", name); if (name == operandSegmentAttrName) { populateInherentAttrsMethod << formatv(" attrs.append(\"{0}\", {1});\n", operandSegmentAttrName, tgfmt(prop.getConvertToAttributeCall(), &fctx)); } else { populateInherentAttrsMethod << formatv(" attrs.append(\"{0}\", {1});\n", resultSegmentAttrName, tgfmt(prop.getConvertToAttributeCall(), &fctx)); } } getInherentAttrMethod << " return std::nullopt;\n"; // Emit the verifiers method for backward compatibility with the generic // syntax. This method verifies the constraint on the properties attributes // before they are set, since dyn_cast<> will silently omit failures. 
for (const auto &attrOrProp : attrOrProperties) { const auto *namedAttr = llvm::dyn_cast_if_present(attrOrProp); if (!namedAttr || !namedAttr->constraint) continue; Attribute attr = *namedAttr->constraint; std::optional constraintFn = staticVerifierEmitter.getAttrConstraintFn(attr); if (!constraintFn) continue; if (canEmitAttrVerifier(attr, /*isEmittingForOp=*/false)) { std::string name = op.getGetterName(namedAttr->attrName); verifyInherentAttrsMethod << formatv(R"( {{ ::mlir::Attribute attr = attrs.get({0}AttrName(opName)); if (attr && ::mlir::failed({1}(attr, "{2}", emitError))) return ::mlir::failure(); } )", name, constraintFn, namedAttr->attrName); } } verifyInherentAttrsMethod << " return ::mlir::success();"; // Generate methods to interact with bytecode. genPropertiesSupportForBytecode(attrOrProperties); } void OpEmitter::genPropertiesSupportForBytecode( ArrayRef attrOrProperties) { if (op.useCustomPropertiesEncoding()) { opClass.declareStaticMethod( "::mlir::LogicalResult", "readProperties", MethodParameter("::mlir::DialectBytecodeReader &", "reader"), MethodParameter("::mlir::OperationState &", "state")); opClass.declareMethod( "void", "writeProperties", MethodParameter("::mlir::DialectBytecodeWriter &", "writer")); return; } auto &readPropertiesMethod = opClass .addStaticMethod( "::mlir::LogicalResult", "readProperties", MethodParameter("::mlir::DialectBytecodeReader &", "reader"), MethodParameter("::mlir::OperationState &", "state")) ->body(); auto &writePropertiesMethod = opClass .addMethod( "void", "writeProperties", MethodParameter("::mlir::DialectBytecodeWriter &", "writer")) ->body(); // Populate bytecode serialization logic. 
readPropertiesMethod << " auto &prop = state.getOrAddProperties(); (void)prop;"; writePropertiesMethod << " auto &prop = getProperties(); (void)prop;\n"; for (const auto &item : llvm::enumerate(attrOrProperties)) { auto &attrOrProp = item.value(); FmtContext fctx; fctx.addSubst("_reader", "reader") .addSubst("_writer", "writer") .addSubst("_storage", propertyStorage) .addSubst("_ctxt", "this->getContext()"); // If the op emits operand/result segment sizes as a property, emit the // legacy reader/writer in the appropriate order to allow backward // compatibility and back deployment. if (emitHelper.getOperandSegmentsSize().has_value() && item.index() == emitHelper.getOperandSegmentSizesLegacyIndex()) { FmtContext fmtCtxt(fctx); fmtCtxt.addSubst("_propName", operandSegmentAttrName); readPropertiesMethod << tgfmt(readBytecodeSegmentSizeLegacy, &fmtCtxt); writePropertiesMethod << tgfmt(writeBytecodeSegmentSizeLegacy, &fmtCtxt); } if (emitHelper.getResultSegmentsSize().has_value() && item.index() == emitHelper.getResultSegmentSizesLegacyIndex()) { FmtContext fmtCtxt(fctx); fmtCtxt.addSubst("_propName", resultSegmentAttrName); readPropertiesMethod << tgfmt(readBytecodeSegmentSizeLegacy, &fmtCtxt); writePropertiesMethod << tgfmt(writeBytecodeSegmentSizeLegacy, &fmtCtxt); } if (const auto *namedProperty = attrOrProp.dyn_cast()) { StringRef name = namedProperty->name; readPropertiesMethod << formatv( R"( {{ auto &propStorage = prop.{0}; auto readProp = [&]() { {1}; return ::mlir::success(); }; if (::mlir::failed(readProp())) return ::mlir::failure(); } )", name, tgfmt(namedProperty->prop.getReadFromMlirBytecodeCall(), &fctx)); writePropertiesMethod << formatv( R"( {{ auto &propStorage = prop.{0}; {1}; } )", name, tgfmt(namedProperty->prop.getWriteToMlirBytecodeCall(), &fctx)); continue; } const auto *namedAttr = attrOrProp.dyn_cast(); StringRef name = namedAttr->attrName; if (namedAttr->isRequired) { readPropertiesMethod << formatv(R"( if 
(::mlir::failed(reader.readAttribute(prop.{0}))) return ::mlir::failure(); )", name); writePropertiesMethod << formatv(" writer.writeAttribute(prop.{0});\n", name); } else { readPropertiesMethod << formatv(R"( if (::mlir::failed(reader.readOptionalAttribute(prop.{0}))) return ::mlir::failure(); )", name); writePropertiesMethod << formatv(R"( writer.writeOptionalAttribute(prop.{0}); )", name); } } readPropertiesMethod << " return ::mlir::success();"; } void OpEmitter::genAttrGetters() { FmtContext fctx; fctx.withBuilder("::mlir::Builder((*this)->getContext())"); // Emit the derived attribute body. auto emitDerivedAttr = [&](StringRef name, Attribute attr) { if (auto *method = opClass.addMethod(attr.getReturnType(), name)) method->body() << " " << attr.getDerivedCodeBody() << "\n"; }; // Generate named accessor with Attribute return type. This is a wrapper // class that allows referring to the attributes via accessors instead of // having to use the string interface for better compile time verification. auto emitAttrWithStorageType = [&](StringRef name, StringRef attrName, Attribute attr) { auto *method = opClass.addMethod(attr.getStorageType(), name + "Attr"); if (!method) return; method->body() << formatv( " return ::llvm::{1}<{2}>({0});", emitHelper.getAttr(attrName), attr.isOptional() || attr.hasDefaultValue() ? 
"dyn_cast_or_null" : "cast", attr.getStorageType()); }; for (const NamedAttribute &namedAttr : op.getAttributes()) { std::string name = op.getGetterName(namedAttr.name); if (namedAttr.attr.isDerivedAttr()) { emitDerivedAttr(name, namedAttr.attr); } else { emitAttrWithStorageType(name, namedAttr.name, namedAttr.attr); emitAttrGetterWithReturnType(fctx, opClass, op, name, namedAttr.attr); } } auto derivedAttrs = make_filter_range(op.getAttributes(), [](const NamedAttribute &namedAttr) { return namedAttr.attr.isDerivedAttr(); }); if (derivedAttrs.empty()) return; opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait"); // Generate helper method to query whether a named attribute is a derived // attribute. This enables, for example, avoiding adding an attribute that // overlaps with a derived attribute. { auto *method = opClass.addStaticMethod("bool", "isDerivedAttribute", MethodParameter("::llvm::StringRef", "name")); ERROR_IF_PRUNED(method, "isDerivedAttribute", op); auto &body = method->body(); for (auto namedAttr : derivedAttrs) body << " if (name == \"" << namedAttr.name << "\") return true;\n"; body << " return false;"; } // Generate method to materialize derived attributes as a DictionaryAttr. 
{ auto *method = opClass.addMethod("::mlir::DictionaryAttr", "materializeDerivedAttributes"); ERROR_IF_PRUNED(method, "materializeDerivedAttributes", op); auto &body = method->body(); auto nonMaterializable = make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) { return namedAttr.attr.getConvertFromStorageCall().empty(); }); if (!nonMaterializable.empty()) { std::string attrs; llvm::raw_string_ostream os(attrs); interleaveComma(nonMaterializable, os, [&](const NamedAttribute &attr) { os << op.getGetterName(attr.name); }); PrintWarning( op.getLoc(), formatv( "op has non-materializable derived attributes '{0}', skipping", os.str())); body << formatv(" emitOpError(\"op has non-materializable derived " "attributes '{0}'\");\n", attrs); body << " return nullptr;"; return; } body << " ::mlir::MLIRContext* ctx = getContext();\n"; body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n"; body << " return ::mlir::DictionaryAttr::get("; body << " ctx, {\n"; interleave( derivedAttrs, body, [&](const NamedAttribute &namedAttr) { auto tmpl = namedAttr.attr.getConvertFromStorageCall(); std::string name = op.getGetterName(namedAttr.name); body << " {" << name << "AttrName(),\n" << tgfmt(tmpl, &fctx.withSelf(name + "()") .withBuilder("odsBuilder") .addSubst("_ctxt", "ctx") .addSubst("_storage", "ctx")) << "}"; }, ",\n"); body << "});"; } } void OpEmitter::genAttrSetters() { // Generate raw named setter type. This is a wrapper class that allows setting // to the attributes via setters instead of having to use the string interface // for better compile time verification. 
auto emitAttrWithStorageType = [&](StringRef setterName, StringRef getterName, Attribute attr) { auto *method = opClass.addMethod("void", setterName + "Attr", MethodParameter(attr.getStorageType(), "attr")); if (method) method->body() << formatv(" (*this)->setAttr({0}AttrName(), attr);", getterName); }; // Generate a setter that accepts the underlying C++ type as opposed to the // attribute type. auto emitAttrWithReturnType = [&](StringRef setterName, StringRef getterName, Attribute attr) { Attribute baseAttr = attr.getBaseAttr(); if (!canUseUnwrappedRawValue(baseAttr)) return; FmtContext fctx; fctx.withBuilder("::mlir::Builder((*this)->getContext())"); bool isUnitAttr = attr.getAttrDefName() == "UnitAttr"; bool isOptional = attr.isOptional(); auto createMethod = [&](const Twine ¶mType) { return opClass.addMethod("void", setterName, MethodParameter(paramType.str(), "attrValue")); }; // Build the method using the correct parameter type depending on // optionality. Method *method = nullptr; if (isUnitAttr) method = createMethod("bool"); else if (isOptional) method = createMethod("::std::optional<" + baseAttr.getReturnType() + ">"); else method = createMethod(attr.getReturnType()); if (!method) return; // If the value isn't optional, just set it directly. if (!isOptional) { method->body() << formatv( " (*this)->setAttr({0}AttrName(), {1});", getterName, constBuildAttrFromParam(attr, fctx, "attrValue")); return; } // Otherwise, we only set if the provided value is valid. If it isn't, we // remove the attribute. // TODO: Handle unit attr parameters specially, given that it is treated as // optional but not in the same way as the others (i.e. it uses bool over // std::optional<>). StringRef paramStr = isUnitAttr ? 
"attrValue" : "*attrValue"; const char *optionalCodeBody = R"( if (attrValue) return (*this)->setAttr({0}AttrName(), {1}); (*this)->removeAttr({0}AttrName());)"; method->body() << formatv( optionalCodeBody, getterName, constBuildAttrFromParam(baseAttr, fctx, paramStr)); }; for (const NamedAttribute &namedAttr : op.getAttributes()) { if (namedAttr.attr.isDerivedAttr()) continue; std::string setterName = op.getSetterName(namedAttr.name); std::string getterName = op.getGetterName(namedAttr.name); emitAttrWithStorageType(setterName, getterName, namedAttr.attr); emitAttrWithReturnType(setterName, getterName, namedAttr.attr); } } void OpEmitter::genOptionalAttrRemovers() { // Generate methods for removing optional attributes, instead of having to // use the string interface. Enables better compile time verification. auto emitRemoveAttr = [&](StringRef name, bool useProperties) { auto upperInitial = name.take_front().upper(); auto *method = opClass.addMethod("::mlir::Attribute", op.getRemoverName(name) + "Attr"); if (!method) return; if (useProperties) { method->body() << formatv(R"( auto &attr = getProperties().{0}; attr = {{}; return attr; )", name); return; } method->body() << formatv("return (*this)->removeAttr({0}AttrName());", op.getGetterName(name)); }; for (const NamedAttribute &namedAttr : op.getAttributes()) if (namedAttr.attr.isOptional()) emitRemoveAttr(namedAttr.name, op.getDialect().usePropertiesForAttributes()); } // Generates the code to compute the start and end index of an operand or result // range. 
template static void generateValueRangeStartAndEnd( Class &opClass, bool isGenericAdaptorBase, StringRef methodName, int numVariadic, int numNonVariadic, StringRef rangeSizeCall, bool hasAttrSegmentSize, StringRef sizeAttrInit, RangeT &&odsValues) { SmallVector parameters{MethodParameter("unsigned", "index")}; if (isGenericAdaptorBase) { parameters.emplace_back("unsigned", "odsOperandsSize"); // The range size is passed per parameter for generic adaptor bases as // using the rangeSizeCall would require the operands, which are not // accessible in the base class. rangeSizeCall = "odsOperandsSize"; } auto *method = opClass.addMethod("std::pair", methodName, parameters); if (!method) return; auto &body = method->body(); if (numVariadic == 0) { body << " return {index, 1};\n"; } else if (hasAttrSegmentSize) { body << sizeAttrInit << attrSizedSegmentValueRangeCalcCode; } else { // Because the op can have arbitrarily interleaved variadic and non-variadic // operands, we need to embed a list in the "sink" getter method for // calculation at run-time. SmallVector isVariadic; isVariadic.reserve(llvm::size(odsValues)); for (auto &it : odsValues) isVariadic.push_back(it.isVariableLength() ? "true" : "false"); std::string isVariadicList = llvm::join(isVariadic, ", "); body << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList, numNonVariadic, numVariadic, rangeSizeCall, "operand"); } } static std::string generateTypeForGetter(const NamedTypeConstraint &value) { std::string str = "::mlir::Value"; /// If the CPPClassName is not a fully qualified type. Uses of types /// across Dialect fail because they are not in the correct namespace. So we /// dont generate TypedValue unless the type is fully qualified. /// getCPPClassName doesn't return the fully qualified path for /// `mlir::pdl::OperationType` see /// https://github.com/llvm/llvm-project/issues/57279. 
/// Adaptor will have values that are not from the type of their operation and /// this is expected, so we dont generate TypedValue for Adaptor if (value.constraint.getCPPClassName() != "::mlir::Type" && StringRef(value.constraint.getCPPClassName()).starts_with("::")) str = llvm::formatv("::mlir::TypedValue<{0}>", value.constraint.getCPPClassName()) .str(); return str; } // Generates the named operand getter methods for the given Operator `op` and // puts them in `opClass`. Uses `rangeType` as the return type of getters that // return a range of operands (individual operands are `Value ` and each // element in the range must also be `Value `); use `rangeBeginCall` to get // an iterator to the beginning of the operand range; use `rangeSizeCall` to // obtain the number of operands. `getOperandCallPattern` contains the code // necessary to obtain a single operand whose position will be substituted // instead of // "{0}" marker in the pattern. Note that the pattern should work for any kind // of ops, in particular for one-operand ops that may not have the // `getOperand(unsigned)` method. 
static void generateNamedOperandGetters(const Operator &op, Class &opClass, Class *genericAdaptorBase, StringRef sizeAttrInit, StringRef rangeType, StringRef rangeElementType, StringRef rangeBeginCall, StringRef rangeSizeCall, StringRef getOperandCallPattern) { const int numOperands = op.getNumOperands(); const int numVariadicOperands = op.getNumVariableLengthOperands(); const int numNormalOperands = numOperands - numVariadicOperands; const auto *sameVariadicSize = op.getTrait("::mlir::OpTrait::SameVariadicOperandSize"); const auto *attrSizedOperands = op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"); if (numVariadicOperands > 1 && !sameVariadicSize && !attrSizedOperands) { PrintFatalError(op.getLoc(), "op has multiple variadic operands but no " "specification over their sizes"); } if (numVariadicOperands < 2 && attrSizedOperands) { PrintFatalError(op.getLoc(), "op must have at least two variadic operands " "to use 'AttrSizedOperandSegments' trait"); } if (attrSizedOperands && sameVariadicSize) { PrintFatalError(op.getLoc(), "op cannot have both 'AttrSizedOperandSegments' and " "'SameVariadicOperandSize' traits"); } // First emit a few "sink" getter methods upon which we layer all nicer named // getter methods. // If generating for an adaptor, the method is put into the non-templated // generic base class, to not require being defined in the header. // Since the operand size can't be determined from the base class however, // it has to be passed as an additional argument. The trampoline below // generates the function with the same signature as the Op in the generic // adaptor. bool isGenericAdaptorBase = genericAdaptorBase != nullptr; generateValueRangeStartAndEnd( /*opClass=*/isGenericAdaptorBase ? 
*genericAdaptorBase : opClass, isGenericAdaptorBase, /*methodName=*/"getODSOperandIndexAndLength", numVariadicOperands, numNormalOperands, rangeSizeCall, attrSizedOperands, sizeAttrInit, const_cast(op).getOperands()); if (isGenericAdaptorBase) { // Generate trampoline for calling 'getODSOperandIndexAndLength' with just // the index. This just calls the implementation in the base class but // passes the operand size as parameter. Method *method = opClass.addMethod("std::pair", "getODSOperandIndexAndLength", MethodParameter("unsigned", "index")); ERROR_IF_PRUNED(method, "getODSOperandIndexAndLength", op); MethodBody &body = method->body(); body.indent() << formatv( "return Base::getODSOperandIndexAndLength(index, {0});", rangeSizeCall); } auto *m = opClass.addMethod(rangeType, "getODSOperands", MethodParameter("unsigned", "index")); ERROR_IF_PRUNED(m, "getODSOperands", op); auto &body = m->body(); body << formatv(valueRangeReturnCode, rangeBeginCall, "getODSOperandIndexAndLength(index)"); // Then we emit nicer named getter methods by redirecting to the "sink" getter // method. for (int i = 0; i != numOperands; ++i) { const auto &operand = op.getOperand(i); if (operand.name.empty()) continue; std::string name = op.getGetterName(operand.name); if (operand.isOptional()) { m = opClass.addMethod(isGenericAdaptorBase ? rangeElementType : generateTypeForGetter(operand), name); ERROR_IF_PRUNED(m, name, op); m->body().indent() << formatv("auto operands = getODSOperands({0});\n" "return operands.empty() ? 
{1}{{} : ", i, m->getReturnType()); if (!isGenericAdaptorBase) m->body() << llvm::formatv("::llvm::cast<{0}>", m->getReturnType()); m->body() << "(*operands.begin());"; } else if (operand.isVariadicOfVariadic()) { std::string segmentAttr = op.getGetterName( operand.constraint.getVariadicOfVariadicSegmentSizeAttr()); if (genericAdaptorBase) { m = opClass.addMethod("::llvm::SmallVector<" + rangeType + ">", name); ERROR_IF_PRUNED(m, name, op); m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode, segmentAttr, i, rangeType); continue; } m = opClass.addMethod("::mlir::OperandRangeRange", name); ERROR_IF_PRUNED(m, name, op); m->body() << " return getODSOperands(" << i << ").split(" << segmentAttr << "Attr());"; } else if (operand.isVariadic()) { m = opClass.addMethod(rangeType, name); ERROR_IF_PRUNED(m, name, op); m->body() << " return getODSOperands(" << i << ");"; } else { m = opClass.addMethod(isGenericAdaptorBase ? rangeElementType : generateTypeForGetter(operand), name); ERROR_IF_PRUNED(m, name, op); m->body().indent() << "return "; if (!isGenericAdaptorBase) m->body() << llvm::formatv("::llvm::cast<{0}>", m->getReturnType()); m->body() << llvm::formatv("(*getODSOperands({0}).begin());", i); } } } void OpEmitter::genNamedOperandGetters() { // Build the code snippet used for initializing the operand_segment_size)s // array. 
std::string attrSizeInitCode; if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { if (op.getDialect().usePropertiesForAttributes()) attrSizeInitCode = formatv(adapterSegmentSizeAttrInitCodeProperties, "getProperties().operandSegmentSizes"); else attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, emitHelper.getAttr(operandSegmentAttrName)); } generateNamedOperandGetters( op, opClass, /*genericAdaptorBase=*/nullptr, /*sizeAttrInit=*/attrSizeInitCode, /*rangeType=*/"::mlir::Operation::operand_range", /*rangeElementType=*/"::mlir::Value", /*rangeBeginCall=*/"getOperation()->operand_begin()", /*rangeSizeCall=*/"getOperation()->getNumOperands()", /*getOperandCallPattern=*/"getOperation()->getOperand({0})"); } void OpEmitter::genNamedOperandSetters() { auto *attrSizedOperands = op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"); for (int i = 0, e = op.getNumOperands(); i != e; ++i) { const auto &operand = op.getOperand(i); if (operand.name.empty()) continue; std::string name = op.getGetterName(operand.name); StringRef returnType; if (operand.isVariadicOfVariadic()) { returnType = "::mlir::MutableOperandRangeRange"; } else if (operand.isVariableLength()) { returnType = "::mlir::MutableOperandRange"; } else { returnType = "::mlir::OpOperand &"; } auto *m = opClass.addMethod(returnType, name + "Mutable"); ERROR_IF_PRUNED(m, name, op); auto &body = m->body(); body << " auto range = getODSOperandIndexAndLength(" << i << ");\n"; if (!operand.isVariadicOfVariadic() && !operand.isVariableLength()) { // In case of a single operand, return a single OpOperand. 
body << " return getOperation()->getOpOperand(range.first);\n"; continue; } body << " auto mutableRange = " "::mlir::MutableOperandRange(getOperation(), " "range.first, range.second"; if (attrSizedOperands) { if (emitHelper.hasProperties()) body << formatv(", ::mlir::MutableOperandRange::OperandSegment({0}u, " "{{getOperandSegmentSizesAttrName(), " "::mlir::DenseI32ArrayAttr::get(getContext(), " "getProperties().operandSegmentSizes)})", i); else body << formatv( ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i, emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true)); } body << ");\n"; // If this operand is a nested variadic, we split the range into a // MutableOperandRangeRange that provides a range over all of the // sub-ranges. if (operand.isVariadicOfVariadic()) { body << " return " "mutableRange.split(*(*this)->getAttrDictionary().getNamed(" << op.getGetterName( operand.constraint.getVariadicOfVariadicSegmentSizeAttr()) << "AttrName()));\n"; } else { // Otherwise, we use the full range directly. body << " return mutableRange;\n"; } } } void OpEmitter::genNamedResultGetters() { const int numResults = op.getNumResults(); const int numVariadicResults = op.getNumVariableLengthResults(); const int numNormalResults = numResults - numVariadicResults; // If we have more than one variadic results, we need more complicated logic // to calculate the value range for each result. 
const auto *sameVariadicSize = op.getTrait("::mlir::OpTrait::SameVariadicResultSize"); const auto *attrSizedResults = op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"); if (numVariadicResults > 1 && !sameVariadicSize && !attrSizedResults) { PrintFatalError(op.getLoc(), "op has multiple variadic results but no " "specification over their sizes"); } if (numVariadicResults < 2 && attrSizedResults) { PrintFatalError(op.getLoc(), "op must have at least two variadic results " "to use 'AttrSizedResultSegments' trait"); } if (attrSizedResults && sameVariadicSize) { PrintFatalError(op.getLoc(), "op cannot have both 'AttrSizedResultSegments' and " "'SameVariadicResultSize' traits"); } // Build the initializer string for the result segment size attribute. std::string attrSizeInitCode; if (attrSizedResults) { if (op.getDialect().usePropertiesForAttributes()) attrSizeInitCode = formatv(adapterSegmentSizeAttrInitCodeProperties, "getProperties().resultSegmentSizes"); else attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, emitHelper.getAttr(resultSegmentAttrName)); } generateValueRangeStartAndEnd( opClass, /*isGenericAdaptorBase=*/false, "getODSResultIndexAndLength", numVariadicResults, numNormalResults, "getOperation()->getNumResults()", attrSizedResults, attrSizeInitCode, op.getResults()); auto *m = opClass.addMethod("::mlir::Operation::result_range", "getODSResults", MethodParameter("unsigned", "index")); ERROR_IF_PRUNED(m, "getODSResults", op); m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()", "getODSResultIndexAndLength(index)"); for (int i = 0; i != numResults; ++i) { const auto &result = op.getResult(i); if (result.name.empty()) continue; std::string name = op.getGetterName(result.name); if (result.isOptional()) { m = opClass.addMethod(generateTypeForGetter(result), name); ERROR_IF_PRUNED(m, name, op); m->body() << " auto results = getODSResults(" << i << ");\n" << llvm::formatv(" return results.empty()" " ? 
{0}()" " : ::llvm::cast<{0}>(*results.begin());", m->getReturnType()); } else if (result.isVariadic()) { m = opClass.addMethod("::mlir::Operation::result_range", name); ERROR_IF_PRUNED(m, name, op); m->body() << " return getODSResults(" << i << ");"; } else { m = opClass.addMethod(generateTypeForGetter(result), name); ERROR_IF_PRUNED(m, name, op); m->body() << llvm::formatv( " return ::llvm::cast<{0}>(*getODSResults({1}).begin());", m->getReturnType(), i); } } } void OpEmitter::genNamedRegionGetters() { unsigned numRegions = op.getNumRegions(); for (unsigned i = 0; i < numRegions; ++i) { const auto ®ion = op.getRegion(i); if (region.name.empty()) continue; std::string name = op.getGetterName(region.name); // Generate the accessors for a variadic region. if (region.isVariadic()) { auto *m = opClass.addMethod("::mlir::MutableArrayRef<::mlir::Region>", name); ERROR_IF_PRUNED(m, name, op); m->body() << formatv(" return (*this)->getRegions().drop_front({0});", i); continue; } auto *m = opClass.addMethod("::mlir::Region &", name); ERROR_IF_PRUNED(m, name, op); m->body() << formatv(" return (*this)->getRegion({0});", i); } } void OpEmitter::genNamedSuccessorGetters() { unsigned numSuccessors = op.getNumSuccessors(); for (unsigned i = 0; i < numSuccessors; ++i) { const NamedSuccessor &successor = op.getSuccessor(i); if (successor.name.empty()) continue; std::string name = op.getGetterName(successor.name); // Generate the accessors for a variadic successor list. 
if (successor.isVariadic()) { auto *m = opClass.addMethod("::mlir::SuccessorRange", name); ERROR_IF_PRUNED(m, name, op); m->body() << formatv( " return {std::next((*this)->successor_begin(), {0}), " "(*this)->successor_end()};", i); continue; } auto *m = opClass.addMethod("::mlir::Block *", name); ERROR_IF_PRUNED(m, name, op); m->body() << formatv(" return (*this)->getSuccessor({0});", i); } } static bool canGenerateUnwrappedBuilder(const Operator &op) { // If this op does not have native attributes at all, return directly to avoid // redefining builders. if (op.getNumNativeAttributes() == 0) return false; bool canGenerate = false; // We are generating builders that take raw values for attributes. We need to // make sure the native attributes have a meaningful "unwrapped" value type // different from the wrapped mlir::Attribute type to avoid redefining // builders. This checks for the op has at least one such native attribute. for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) { const NamedAttribute &namedAttr = op.getAttribute(i); if (canUseUnwrappedRawValue(namedAttr.attr)) { canGenerate = true; break; } } return canGenerate; } static bool canInferType(const Operator &op) { return op.getTrait("::mlir::InferTypeOpInterface::Trait"); } void OpEmitter::genSeparateArgParamBuilder() { SmallVector attrBuilderType; attrBuilderType.push_back(AttrParamKind::WrappedAttr); if (canGenerateUnwrappedBuilder(op)) attrBuilderType.push_back(AttrParamKind::UnwrappedValue); // Emit with separate builders with or without unwrapped attributes and/or // inferring result type. auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind, bool inferType) { SmallVector paramList; SmallVector resultNames; llvm::StringSet<> inferredAttributes; buildParamList(paramList, inferredAttributes, resultNames, paramKind, attrType); auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); // If the builder is redundant, skip generating the method. 
if (!m) return; auto &body = m->body(); genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue); // Push all result types to the operation state if (inferType) { // Generate builder that infers type too. // TODO: Subsume this with general checking if type can be // inferred automatically. body << formatv(R"( ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes; if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(), {1}.location, {1}.operands, {1}.attributes.getDictionary({1}.getContext()), {1}.getRawProperties(), {1}.regions, inferredReturnTypes))) {1}.addTypes(inferredReturnTypes); else ::llvm::report_fatal_error("Failed to infer result type(s).");)", opClass.getClassName(), builderOpState); return; } switch (paramKind) { case TypeParamKind::None: return; case TypeParamKind::Separate: for (int i = 0, e = op.getNumResults(); i < e; ++i) { if (op.getResult(i).isOptional()) body << " if (" << resultNames[i] << ")\n "; body << " " << builderOpState << ".addTypes(" << resultNames[i] << ");\n"; } // Automatically create the 'resultSegmentSizes' attribute using // the length of the type ranges. if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { if (op.getDialect().usePropertiesForAttributes()) { body << " ::llvm::copy(::llvm::ArrayRef({"; } else { std::string getterName = op.getGetterName(resultSegmentAttrName); body << " " << builderOpState << ".addAttribute(" << getterName << "AttrName(" << builderOpState << ".name), " << "odsBuilder.getDenseI32ArrayAttr({"; } interleaveComma( llvm::seq(0, op.getNumResults()), body, [&](int i) { const NamedTypeConstraint &result = op.getResult(i); if (!result.isVariableLength()) { body << "1"; } else if (result.isOptional()) { body << "(" << resultNames[i] << " ? 1 : 0)"; } else { // VariadicOfVariadic of results are currently unsupported in // MLIR, hence it can only be a simple variadic. 
// TODO: Add implementation for VariadicOfVariadic results here // once supported. assert(result.isVariadic()); body << "static_cast(" << resultNames[i] << ".size())"; } }); if (op.getDialect().usePropertiesForAttributes()) { body << "}), " << builderOpState << ".getOrAddProperties()." "resultSegmentSizes.begin());\n"; } else { body << "}));\n"; } } return; case TypeParamKind::Collective: { int numResults = op.getNumResults(); int numVariadicResults = op.getNumVariableLengthResults(); int numNonVariadicResults = numResults - numVariadicResults; bool hasVariadicResult = numVariadicResults != 0; // Avoid emitting "resultTypes.size() >= 0u" which is always true. if (!hasVariadicResult || numNonVariadicResults != 0) body << " " << "assert(resultTypes.size() " << (hasVariadicResult ? ">=" : "==") << " " << numNonVariadicResults << "u && \"mismatched number of results\");\n"; body << " " << builderOpState << ".addTypes(resultTypes);\n"; } return; } llvm_unreachable("unhandled TypeParamKind"); }; // Some of the build methods generated here may be ambiguous, but TableGen's // ambiguous function detection will elide those ones. for (auto attrType : attrBuilderType) { emit(attrType, TypeParamKind::Separate, /*inferType=*/false); if (canInferType(op)) emit(attrType, TypeParamKind::None, /*inferType=*/true); emit(attrType, TypeParamKind::Collective, /*inferType=*/false); } } void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() { int numResults = op.getNumResults(); // Signature SmallVector paramList; paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); paramList.emplace_back("::mlir::OperationState &", builderOpState); paramList.emplace_back("::mlir::ValueRange", "operands"); // Provide default value for `attributes` when its the last parameter StringRef attributesDefaultValue = op.getNumVariadicRegions() ? 
"" : "{}"; paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", "attributes", attributesDefaultValue); if (op.getNumVariadicRegions()) paramList.emplace_back("unsigned", "numRegions"); auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); // If the builder is redundant, skip generating the method if (!m) return; auto &body = m->body(); // Operands body << " " << builderOpState << ".addOperands(operands);\n"; // Attributes body << " " << builderOpState << ".addAttributes(attributes);\n"; // Create the correct number of regions if (int numRegions = op.getNumRegions()) { body << llvm::formatv( " for (unsigned i = 0; i != {0}; ++i)\n", (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions))); body << " (void)" << builderOpState << ".addRegion();\n"; } // Result types SmallVector resultTypes(numResults, "operands[0].getType()"); body << " " << builderOpState << ".addTypes({" << llvm::join(resultTypes, ", ") << "});\n\n"; } void OpEmitter::genPopulateDefaultAttributes() { // All done if no attributes, except optional ones, have default values. if (llvm::all_of(op.getAttributes(), [](const NamedAttribute &named) { return !named.attr.hasDefaultValue() || named.attr.isOptional(); })) return; if (emitHelper.hasProperties()) { SmallVector paramList; paramList.emplace_back("::mlir::OperationName", "opName"); paramList.emplace_back("Properties &", "properties"); auto *m = opClass.addStaticMethod("void", "populateDefaultProperties", paramList); ERROR_IF_PRUNED(m, "populateDefaultProperties", op); auto &body = m->body(); body.indent(); body << "::mlir::Builder " << odsBuilder << "(opName.getContext());\n"; for (const NamedAttribute &namedAttr : op.getAttributes()) { auto &attr = namedAttr.attr; if (!attr.hasDefaultValue() || attr.isOptional()) continue; StringRef name = namedAttr.name; FmtContext fctx; fctx.withBuilder(odsBuilder); body << "if (!properties." << name << ")\n" << " properties." 
<< name << " = " << std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx, tgfmt(attr.getDefaultValue(), &fctx))) << ";\n"; } return; } SmallVector paramList; paramList.emplace_back("const ::mlir::OperationName &", "opName"); paramList.emplace_back("::mlir::NamedAttrList &", "attributes"); auto *m = opClass.addStaticMethod("void", "populateDefaultAttrs", paramList); ERROR_IF_PRUNED(m, "populateDefaultAttrs", op); auto &body = m->body(); body.indent(); // Set default attributes that are unset. body << "auto attrNames = opName.getAttributeNames();\n"; body << "::mlir::Builder " << odsBuilder << "(attrNames.front().getContext());\n"; StringMap attrIndex; for (const auto &it : llvm::enumerate(emitHelper.getAttrMetadata())) { attrIndex[it.value().first] = it.index(); } for (const NamedAttribute &namedAttr : op.getAttributes()) { auto &attr = namedAttr.attr; if (!attr.hasDefaultValue() || attr.isOptional()) continue; auto index = attrIndex[namedAttr.name]; body << "if (!attributes.get(attrNames[" << index << "])) {\n"; FmtContext fctx; fctx.withBuilder(odsBuilder); std::string defaultValue = std::string(tgfmt(attr.getConstBuilderTemplate(), &fctx, tgfmt(attr.getDefaultValue(), &fctx))); body.indent() << formatv("attributes.append(attrNames[{0}], {1});\n", index, defaultValue); body.unindent() << "}\n"; } } void OpEmitter::genInferredTypeCollectiveParamBuilder() { SmallVector paramList; paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); paramList.emplace_back("::mlir::OperationState &", builderOpState); paramList.emplace_back("::mlir::ValueRange", "operands"); StringRef attributesDefaultValue = op.getNumVariadicRegions() ? 
"" : "{}"; paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", "attributes", attributesDefaultValue); if (op.getNumVariadicRegions()) paramList.emplace_back("unsigned", "numRegions"); auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); // If the builder is redundant, skip generating the method if (!m) return; auto &body = m->body(); int numResults = op.getNumResults(); int numVariadicResults = op.getNumVariableLengthResults(); int numNonVariadicResults = numResults - numVariadicResults; int numOperands = op.getNumOperands(); int numVariadicOperands = op.getNumVariableLengthOperands(); int numNonVariadicOperands = numOperands - numVariadicOperands; // Operands if (numVariadicOperands == 0 || numNonVariadicOperands != 0) body << " assert(operands.size()" << (numVariadicOperands != 0 ? " >= " : " == ") << numNonVariadicOperands << "u && \"mismatched number of parameters\");\n"; body << " " << builderOpState << ".addOperands(operands);\n"; body << " " << builderOpState << ".addAttributes(attributes);\n"; // Create the correct number of regions if (int numRegions = op.getNumRegions()) { body << llvm::formatv( " for (unsigned i = 0; i != {0}; ++i)\n", (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions))); body << " (void)" << builderOpState << ".addRegion();\n"; } // Result types body << formatv(R"( ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes; if (::mlir::succeeded({0}::inferReturnTypes(odsBuilder.getContext(), {1}.location, operands, {1}.attributes.getDictionary({1}.getContext()), {1}.getRawProperties(), {1}.regions, inferredReturnTypes))) {{)", opClass.getClassName(), builderOpState); if (numVariadicResults == 0 || numNonVariadicResults != 0) body << "\n assert(inferredReturnTypes.size()" << (numVariadicResults != 0 ? 
" >= " : " == ") << numNonVariadicResults << "u && \"mismatched number of return types\");"; body << "\n " << builderOpState << ".addTypes(inferredReturnTypes);"; body << formatv(R"( } else {{ ::llvm::report_fatal_error("Failed to infer result type(s)."); })", opClass.getClassName(), builderOpState); } void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() { auto emit = [&](AttrParamKind attrType) { SmallVector paramList; SmallVector resultNames; llvm::StringSet<> inferredAttributes; buildParamList(paramList, inferredAttributes, resultNames, TypeParamKind::None, attrType); auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); // If the builder is redundant, skip generating the method if (!m) return; auto &body = m->body(); genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue); auto numResults = op.getNumResults(); if (numResults == 0) return; // Push all result types to the operation state const char *index = op.getOperand(0).isVariadic() ? 
".front()" : ""; std::string resultType = formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str(); body << " " << builderOpState << ".addTypes({" << resultType; for (int i = 1; i != numResults; ++i) body << ", " << resultType; body << "});\n\n"; }; emit(AttrParamKind::WrappedAttr); // Generate additional builder(s) if any attributes can be "unwrapped" if (canGenerateUnwrappedBuilder(op)) emit(AttrParamKind::UnwrappedValue); } void OpEmitter::genUseAttrAsResultTypeBuilder() { SmallVector paramList; paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); paramList.emplace_back("::mlir::OperationState &", builderOpState); paramList.emplace_back("::mlir::ValueRange", "operands"); paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", "attributes", "{}"); auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); // If the builder is redundant, skip generating the method if (!m) return; auto &body = m->body(); // Push all result types to the operation state std::string resultType; const auto &namedAttr = op.getAttribute(0); body << " auto attrName = " << op.getGetterName(namedAttr.name) << "AttrName(" << builderOpState << ".name);\n" " for (auto attr : attributes) {\n" " if (attr.getName() != attrName) continue;\n"; if (namedAttr.attr.isTypeAttr()) { resultType = "::llvm::cast<::mlir::TypeAttr>(attr.getValue()).getValue()"; } else { resultType = "::llvm::cast<::mlir::TypedAttr>(attr.getValue()).getType()"; } // Operands body << " " << builderOpState << ".addOperands(operands);\n"; // Attributes body << " " << builderOpState << ".addAttributes(attributes);\n"; // Result types SmallVector resultTypes(op.getNumResults(), resultType); body << " " << builderOpState << ".addTypes({" << llvm::join(resultTypes, ", ") << "});\n"; body << " }\n"; } /// Returns a signature of the builder. Updates the context `fctx` to enable /// replacement of $_builder and $_state in the body. 
static SmallVector getBuilderSignature(const Builder &builder) { ArrayRef params(builder.getParameters()); // Inject builder and state arguments. SmallVector arguments; arguments.reserve(params.size() + 2); arguments.emplace_back("::mlir::OpBuilder &", odsBuilder); arguments.emplace_back("::mlir::OperationState &", builderOpState); for (unsigned i = 0, e = params.size(); i < e; ++i) { // If no name is provided, generate one. std::optional paramName = params[i].getName(); std::string name = paramName ? paramName->str() : "odsArg" + std::to_string(i); StringRef defaultValue; if (std::optional defaultParamValue = params[i].getDefaultValue()) defaultValue = *defaultParamValue; arguments.emplace_back(params[i].getCppType(), std::move(name), defaultValue); } return arguments; } void OpEmitter::genBuilder() { // Handle custom builders if provided. for (const Builder &builder : op.getBuilders()) { SmallVector arguments = getBuilderSignature(builder); std::optional body = builder.getBody(); auto properties = body ? Method::Static : Method::StaticDeclaration; auto *method = opClass.addMethod("void", "build", properties, std::move(arguments)); if (body) ERROR_IF_PRUNED(method, "build", op); if (method) method->setDeprecated(builder.getDeprecatedMessage()); FmtContext fctx; fctx.withBuilder(odsBuilder); fctx.addSubst("_state", builderOpState); if (body) method->body() << tgfmt(*body, &fctx); } // Generate default builders that requires all result type, operands, and // attributes as parameters. if (op.skipDefaultBuilders()) return; // We generate three classes of builders here: // 1. one having a stand-alone parameter for each operand / attribute, and genSeparateArgParamBuilder(); // 2. one having an aggregated parameter for all result types / operands / // attributes, and genCollectiveParamBuilder(); // 3. 
one having a stand-alone parameter for each operand and attribute, // use the first operand or attribute's type as all result types // to facilitate different call patterns. if (op.getNumVariableLengthResults() == 0) { if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) { genUseOperandAsResultTypeSeparateParamBuilder(); genUseOperandAsResultTypeCollectiveParamBuilder(); } if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType")) genUseAttrAsResultTypeBuilder(); } } void OpEmitter::genCollectiveParamBuilder() { int numResults = op.getNumResults(); int numVariadicResults = op.getNumVariableLengthResults(); int numNonVariadicResults = numResults - numVariadicResults; int numOperands = op.getNumOperands(); int numVariadicOperands = op.getNumVariableLengthOperands(); int numNonVariadicOperands = numOperands - numVariadicOperands; SmallVector paramList; paramList.emplace_back("::mlir::OpBuilder &", ""); paramList.emplace_back("::mlir::OperationState &", builderOpState); paramList.emplace_back("::mlir::TypeRange", "resultTypes"); paramList.emplace_back("::mlir::ValueRange", "operands"); // Provide default value for `attributes` when its the last parameter StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}"; paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", "attributes", attributesDefaultValue); if (op.getNumVariadicRegions()) paramList.emplace_back("unsigned", "numRegions"); auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); // If the builder is redundant, skip generating the method if (!m) return; auto &body = m->body(); // Operands if (numVariadicOperands == 0 || numNonVariadicOperands != 0) body << " assert(operands.size()" << (numVariadicOperands != 0 ? 
" >= " : " == ") << numNonVariadicOperands << "u && \"mismatched number of parameters\");\n"; body << " " << builderOpState << ".addOperands(operands);\n"; // Attributes body << " " << builderOpState << ".addAttributes(attributes);\n"; // Create the correct number of regions if (int numRegions = op.getNumRegions()) { body << llvm::formatv( " for (unsigned i = 0; i != {0}; ++i)\n", (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions))); body << " (void)" << builderOpState << ".addRegion();\n"; } // Result types if (numVariadicResults == 0 || numNonVariadicResults != 0) body << " assert(resultTypes.size()" << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults << "u && \"mismatched number of return types\");\n"; body << " " << builderOpState << ".addTypes(resultTypes);\n"; // Generate builder that infers type too. // TODO: Expand to handle successors. if (canInferType(op) && op.getNumSuccessors() == 0) genInferredTypeCollectiveParamBuilder(); } void OpEmitter::buildParamList(SmallVectorImpl ¶mList, llvm::StringSet<> &inferredAttributes, SmallVectorImpl &resultTypeNames, TypeParamKind typeParamKind, AttrParamKind attrParamKind) { resultTypeNames.clear(); auto numResults = op.getNumResults(); resultTypeNames.reserve(numResults); paramList.emplace_back("::mlir::OpBuilder &", odsBuilder); paramList.emplace_back("::mlir::OperationState &", builderOpState); switch (typeParamKind) { case TypeParamKind::None: break; case TypeParamKind::Separate: { // Add parameters for all return types for (int i = 0; i < numResults; ++i) { const auto &result = op.getResult(i); std::string resultName = std::string(result.name); if (resultName.empty()) resultName = std::string(formatv("resultType{0}", i)); StringRef type = result.isVariadic() ? 
"::mlir::TypeRange" : "::mlir::Type"; paramList.emplace_back(type, resultName, result.isOptional()); resultTypeNames.emplace_back(std::move(resultName)); } } break; case TypeParamKind::Collective: { paramList.emplace_back("::mlir::TypeRange", "resultTypes"); resultTypeNames.push_back("resultTypes"); } break; } // Add parameters for all arguments (operands and attributes). int defaultValuedAttrStartIndex = op.getNumArgs(); // Successors and variadic regions go at the end of the parameter list, so no // default arguments are possible. bool hasTrailingParams = op.getNumSuccessors() || op.getNumVariadicRegions(); if (!hasTrailingParams) { // Calculate the start index from which we can attach default values in the // builder declaration. for (int i = op.getNumArgs() - 1; i >= 0; --i) { auto *namedAttr = llvm::dyn_cast_if_present(op.getArg(i)); if (!namedAttr) break; Attribute attr = namedAttr->attr; // TODO: Currently we can't differentiate between optional meaning do not // verify/not always error if missing or optional meaning need not be // specified in builder. Expand isOptional once we can differentiate. if (!attr.hasDefaultValue() && !attr.isDerivedAttr()) break; // Creating an APInt requires us to provide bitwidth, value, and // signedness, which is complicated compared to others. Similarly // for APFloat. // TODO: Adjust the 'returnType' field of such attributes // to support them. StringRef retType = namedAttr->attr.getReturnType(); if (retType == "::llvm::APInt" || retType == "::llvm::APFloat") break; defaultValuedAttrStartIndex = i; } } // Avoid generating build methods that are ambiguous due to default values by // requiring at least one attribute. if (defaultValuedAttrStartIndex < op.getNumArgs()) { // TODO: This should have been possible as a cast but // required template instantiations is not yet defined for the tblgen helper // classes. 
auto *namedAttr = cast(op.getArg(defaultValuedAttrStartIndex)); Attribute attr = namedAttr->attr; if ((attrParamKind == AttrParamKind::WrappedAttr && canUseUnwrappedRawValue(attr)) || (attrParamKind == AttrParamKind::UnwrappedValue && !canUseUnwrappedRawValue(attr))) ++defaultValuedAttrStartIndex; } /// Collect any inferred attributes. for (const NamedTypeConstraint &operand : op.getOperands()) { if (operand.isVariadicOfVariadic()) { inferredAttributes.insert( operand.constraint.getVariadicOfVariadicSegmentSizeAttr()); } } for (int i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) { Argument arg = op.getArg(i); if (const auto *operand = llvm::dyn_cast_if_present(arg)) { StringRef type; if (operand->isVariadicOfVariadic()) type = "::llvm::ArrayRef<::mlir::ValueRange>"; else if (operand->isVariadic()) type = "::mlir::ValueRange"; else type = "::mlir::Value"; paramList.emplace_back(type, getArgumentName(op, numOperands++), operand->isOptional()); continue; } if ([[maybe_unused]] const auto *operand = llvm::dyn_cast_if_present(arg)) { // TODO continue; } const NamedAttribute &namedAttr = *arg.get(); const Attribute &attr = namedAttr.attr; // Inferred attributes don't need to be added to the param list. if (inferredAttributes.contains(namedAttr.name)) continue; StringRef type; switch (attrParamKind) { case AttrParamKind::WrappedAttr: type = attr.getStorageType(); break; case AttrParamKind::UnwrappedValue: if (canUseUnwrappedRawValue(attr)) type = attr.getReturnType(); else type = attr.getStorageType(); break; } // Attach default value if requested and possible. std::string defaultValue; if (i >= defaultValuedAttrStartIndex) { if (attrParamKind == AttrParamKind::UnwrappedValue && canUseUnwrappedRawValue(attr)) defaultValue += attr.getDefaultValue(); else defaultValue += "nullptr"; } paramList.emplace_back(type, namedAttr.name, StringRef(defaultValue), attr.isOptional()); } /// Insert parameters for each successor. 
for (const NamedSuccessor &succ : op.getSuccessors()) { StringRef type = succ.isVariadic() ? "::mlir::BlockRange" : "::mlir::Block *"; paramList.emplace_back(type, succ.name); } /// Insert parameters for variadic regions. for (const NamedRegion ®ion : op.getRegions()) if (region.isVariadic()) paramList.emplace_back("unsigned", llvm::formatv("{0}Count", region.name).str()); } void OpEmitter::genCodeForAddingArgAndRegionForBuilder( MethodBody &body, llvm::StringSet<> &inferredAttributes, bool isRawValueAttr) { // Push all operands to the result. for (int i = 0, e = op.getNumOperands(); i < e; ++i) { std::string argName = getArgumentName(op, i); const NamedTypeConstraint &operand = op.getOperand(i); if (operand.constraint.isVariadicOfVariadic()) { body << " for (::mlir::ValueRange range : " << argName << ")\n " << builderOpState << ".addOperands(range);\n"; // Add the segment attribute. body << " {\n" << " ::llvm::SmallVector rangeSegments;\n" << " for (::mlir::ValueRange range : " << argName << ")\n" << " rangeSegments.push_back(range.size());\n" << " auto rangeAttr = " << odsBuilder << ".getDenseI32ArrayAttr(rangeSegments);\n"; if (op.getDialect().usePropertiesForAttributes()) { body << " " << builderOpState << ".getOrAddProperties()." << operand.constraint.getVariadicOfVariadicSegmentSizeAttr() << " = rangeAttr;"; } else { body << " " << builderOpState << ".addAttribute(" << op.getGetterName( operand.constraint.getVariadicOfVariadicSegmentSizeAttr()) << "AttrName(" << builderOpState << ".name), rangeAttr);"; } body << " }\n"; continue; } if (operand.isOptional()) body << " if (" << argName << ")\n "; body << " " << builderOpState << ".addOperands(" << argName << ");\n"; } // If the operation has the operand segment size attribute, add it here. 
auto emitSegment = [&]() { interleaveComma(llvm::seq(0, op.getNumOperands()), body, [&](int i) { const NamedTypeConstraint &operand = op.getOperand(i); if (!operand.isVariableLength()) { body << "1"; return; } std::string operandName = getArgumentName(op, i); if (operand.isOptional()) { body << "(" << operandName << " ? 1 : 0)"; } else if (operand.isVariadicOfVariadic()) { body << llvm::formatv( "static_cast(std::accumulate({0}.begin(), {0}.end(), 0, " "[](int32_t curSum, ::mlir::ValueRange range) {{ return curSum + " "static_cast(range.size()); }))", operandName); } else { body << "static_cast(" << getArgumentName(op, i) << ".size())"; } }); }; if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { std::string sizes = op.getGetterName(operandSegmentAttrName); if (op.getDialect().usePropertiesForAttributes()) { body << " ::llvm::copy(::llvm::ArrayRef({"; emitSegment(); body << "}), " << builderOpState << ".getOrAddProperties()." "operandSegmentSizes.begin());\n"; } else { body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName(" << builderOpState << ".name), " << "odsBuilder.getDenseI32ArrayAttr({"; emitSegment(); body << "}));\n"; } } // Push all attributes to the result. for (const auto &namedAttr : op.getAttributes()) { auto &attr = namedAttr.attr; if (attr.isDerivedAttr() || inferredAttributes.contains(namedAttr.name)) continue; // TODO: The wrapping of optional is different for default or not, so don't // unwrap for default ones that would fail below. bool emitNotNullCheck = (attr.isOptional() && !attr.hasDefaultValue()) || (attr.hasDefaultValue() && !isRawValueAttr) || // TODO: UnitAttr is optional, not wrapped, but needs to be guarded as // the constant materialization is only for true case. 
(isRawValueAttr && attr.getAttrDefName() == "UnitAttr"); if (emitNotNullCheck) body.indent() << formatv("if ({0}) ", namedAttr.name) << "{\n"; if (isRawValueAttr && canUseUnwrappedRawValue(attr)) { // If this is a raw value, then we need to wrap it in an Attribute // instance. FmtContext fctx; fctx.withBuilder("odsBuilder"); if (op.getDialect().usePropertiesForAttributes()) { body << formatv(" {0}.getOrAddProperties().{1} = {2};\n", builderOpState, namedAttr.name, constBuildAttrFromParam(attr, fctx, namedAttr.name)); } else { body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n", builderOpState, op.getGetterName(namedAttr.name), constBuildAttrFromParam(attr, fctx, namedAttr.name)); } } else { if (op.getDialect().usePropertiesForAttributes()) { body << formatv(" {0}.getOrAddProperties().{1} = {1};\n", builderOpState, namedAttr.name); } else { body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n", builderOpState, op.getGetterName(namedAttr.name), namedAttr.name); } } if (emitNotNullCheck) body.unindent() << " }\n"; } // Create the correct number of regions. for (const NamedRegion ®ion : op.getRegions()) { if (region.isVariadic()) body << formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ", region.name); body << " (void)" << builderOpState << ".addRegion();\n"; } // Push all successors to the result. 
for (const NamedSuccessor &namedSuccessor : op.getSuccessors()) { body << formatv(" {0}.addSuccessors({1});\n", builderOpState, namedSuccessor.name); } } void OpEmitter::genCanonicalizerDecls() { bool hasCanonicalizeMethod = def.getValueAsBit("hasCanonicalizeMethod"); if (hasCanonicalizeMethod) { // static LogicResult FooOp:: // canonicalize(FooOp op, PatternRewriter &rewriter); SmallVector paramList; paramList.emplace_back(op.getCppClassName(), "op"); paramList.emplace_back("::mlir::PatternRewriter &", "rewriter"); auto *m = opClass.declareStaticMethod("::mlir::LogicalResult", "canonicalize", std::move(paramList)); ERROR_IF_PRUNED(m, "canonicalize", op); } // We get a prototype for 'getCanonicalizationPatterns' if requested directly // or if using a 'canonicalize' method. bool hasCanonicalizer = def.getValueAsBit("hasCanonicalizer"); if (!hasCanonicalizeMethod && !hasCanonicalizer) return; // We get a body for 'getCanonicalizationPatterns' when using a 'canonicalize' // method, but not implementing 'getCanonicalizationPatterns' manually. bool hasBody = hasCanonicalizeMethod && !hasCanonicalizer; // Add a signature for getCanonicalizationPatterns if implemented by the // dialect or if synthesized to call 'canonicalize'. SmallVector paramList; paramList.emplace_back("::mlir::RewritePatternSet &", "results"); paramList.emplace_back("::mlir::MLIRContext *", "context"); auto kind = hasBody ? Method::Static : Method::StaticDeclaration; auto *method = opClass.addMethod("void", "getCanonicalizationPatterns", kind, std::move(paramList)); // If synthesizing the method, fill it. 
if (hasBody) { ERROR_IF_PRUNED(method, "getCanonicalizationPatterns", op); method->body() << " results.add(canonicalize);\n"; } } void OpEmitter::genFolderDecls() { if (!op.hasFolder()) return; SmallVector paramList; paramList.emplace_back("FoldAdaptor", "adaptor"); StringRef retType; bool hasSingleResult = op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0; if (hasSingleResult) { retType = "::mlir::OpFoldResult"; } else { paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &", "results"); retType = "::mlir::LogicalResult"; } auto *m = opClass.declareMethod(retType, "fold", std::move(paramList)); ERROR_IF_PRUNED(m, "fold", op); } void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) { Interface interface = opTrait->getInterface(); // Get the set of methods that should always be declared. auto alwaysDeclaredMethodsVec = opTrait->getAlwaysDeclaredMethods(); llvm::StringSet<> alwaysDeclaredMethods; alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(), alwaysDeclaredMethodsVec.end()); for (const InterfaceMethod &method : interface.getMethods()) { // Don't declare if the method has a body. if (method.getBody()) continue; // Don't declare if the method has a default implementation and the op // didn't request that it always be declared. if (method.getDefaultImplementation() && !alwaysDeclaredMethods.count(method.getName())) continue; // Interface methods are allowed to overlap with existing methods, so don't // check if pruned. (void)genOpInterfaceMethod(method); } } Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method, bool declaration) { SmallVector paramList; for (const InterfaceMethod::Argument &arg : method.getArguments()) paramList.emplace_back(arg.type, arg.name); auto props = (method.isStatic() ? Method::Static : Method::None) | (declaration ? 
Method::Declaration : Method::None); return opClass.addMethod(method.getReturnType(), method.getName(), props, std::move(paramList)); } void OpEmitter::genOpInterfaceMethods() { for (const auto &trait : op.getTraits()) { if (const auto *opTrait = dyn_cast(&trait)) if (opTrait->shouldDeclareMethods()) genOpInterfaceMethods(opTrait); } } void OpEmitter::genSideEffectInterfaceMethods() { enum EffectKind { Operand, Result, Symbol, Static }; struct EffectLocation { /// The effect applied. SideEffect effect; /// The index if the kind is not static. unsigned index; /// The kind of the location. unsigned kind; }; StringMap> interfaceEffects; auto resolveDecorators = [&](Operator::var_decorator_range decorators, unsigned index, unsigned kind) { for (auto decorator : decorators) if (SideEffect *effect = dyn_cast(&decorator)) { opClass.addTrait(effect->getInterfaceTrait()); interfaceEffects[effect->getBaseEffectName()].push_back( EffectLocation{*effect, index, kind}); } }; // Collect effects that were specified via: /// Traits. for (const auto &trait : op.getTraits()) { const auto *opTrait = dyn_cast(&trait); if (!opTrait) continue; auto &effects = interfaceEffects[opTrait->getBaseEffectName()]; for (auto decorator : opTrait->getEffects()) effects.push_back(EffectLocation{cast(decorator), /*index=*/0, EffectKind::Static}); } /// Attributes and Operands. for (unsigned i = 0, operandIt = 0, e = op.getNumArgs(); i != e; ++i) { Argument arg = op.getArg(i); if (arg.is()) { resolveDecorators(op.getArgDecorators(i), operandIt, EffectKind::Operand); ++operandIt; continue; } if (arg.is()) continue; const NamedAttribute *attr = arg.get(); if (attr->attr.getBaseAttr().isSymbolRefAttr()) resolveDecorators(op.getArgDecorators(i), i, EffectKind::Symbol); } /// Results. for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result); // The code used to add an effect instance. // {0}: The effect class. 
// {1}: Optional value or symbol reference. // {2}: The side effect stage. // {3}: Does this side effect act on every single value of resource. // {4}: The resource class. const char *addEffectCode = " effects.emplace_back({0}::get(), {1}{2}, {3}, {4}::get());\n"; for (auto &it : interfaceEffects) { // Generate the 'getEffects' method. std::string type = llvm::formatv("::llvm::SmallVectorImpl<::mlir::" "SideEffects::EffectInstance<{0}>> &", it.first()) .str(); auto *getEffects = opClass.addMethod("void", "getEffects", MethodParameter(type, "effects")); ERROR_IF_PRUNED(getEffects, "getEffects", op); auto &body = getEffects->body(); // Add effect instances for each of the locations marked on the operation. for (auto &location : it.second) { StringRef effect = location.effect.getName(); StringRef resource = location.effect.getResource(); int stage = (int)location.effect.getStage(); bool effectOnFullRegion = (int)location.effect.getEffectOnfullRegion(); if (location.kind == EffectKind::Static) { // A static instance has no attached value. body << llvm::formatv(addEffectCode, effect, "", stage, effectOnFullRegion, resource) .str(); } else if (location.kind == EffectKind::Symbol) { // A symbol reference requires adding the proper attribute. const auto *attr = op.getArg(location.index).get(); std::string argName = op.getGetterName(attr->name); if (attr->attr.isOptional()) { body << " if (auto symbolRef = " << argName << "Attr())\n " << llvm::formatv(addEffectCode, effect, "symbolRef, ", stage, effectOnFullRegion, resource) .str(); } else { body << llvm::formatv(addEffectCode, effect, argName + "Attr(), ", stage, effectOnFullRegion, resource) .str(); } } else { // Otherwise this is an operand/result, so we need to attach the Value. body << " for (::mlir::Value value : getODS" << (location.kind == EffectKind::Operand ? 
"Operands" : "Results") << "(" << location.index << "))\n " << llvm::formatv(addEffectCode, effect, "value, ", stage, effectOnFullRegion, resource) .str(); } } } } void OpEmitter::genTypeInterfaceMethods() { if (!op.allResultTypesKnown()) return; // Generate 'inferReturnTypes' method declaration using the interface method // declared in 'InferTypeOpInterface' op interface. const auto *trait = cast(op.getTrait("::mlir::InferTypeOpInterface::Trait")); Interface interface = trait->getInterface(); Method *method = [&]() -> Method * { for (const InterfaceMethod &interfaceMethod : interface.getMethods()) { if (interfaceMethod.getName() == "inferReturnTypes") { return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false); } } assert(0 && "unable to find inferReturnTypes interface method"); return nullptr; }(); ERROR_IF_PRUNED(method, "inferReturnTypes", op); auto &body = method->body(); body << " inferredReturnTypes.resize(" << op.getNumResults() << ");\n"; FmtContext fctx; fctx.withBuilder("odsBuilder"); fctx.addSubst("_ctxt", "context"); body << " ::mlir::Builder odsBuilder(context);\n"; // Process the type inference graph in topological order, starting from types // that are always fully-inferred: operands and results with constructible // types. The type inference graph here will always be a DAG, so this gives // us the correct order for generating the types. -1 is a placeholder to // indicate the type for a result has not been generated. SmallVector constructedIndices(op.getNumResults(), -1); int inferredTypeIdx = 0; for (int numResults = op.getNumResults(); inferredTypeIdx != numResults;) { for (int i = 0, e = op.getNumResults(); i != e; ++i) { if (constructedIndices[i] >= 0) continue; const InferredResultType &infer = op.getInferredResultType(i); std::string typeStr; if (infer.isArg()) { // If this is an operand, just index into operand list to access the // type. 
auto arg = op.getArgToOperandOrAttribute(infer.getIndex()); if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) { typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) + "].getType()") .str(); // If this is an attribute, index into the attribute dictionary. } else { auto *attr = op.getArg(arg.operandOrAttributeIndex()).get(); body << " ::mlir::TypedAttr odsInferredTypeAttr" << inferredTypeIdx << " = "; if (op.getDialect().usePropertiesForAttributes()) { body << "(properties ? properties.as()->" << attr->name << " : " "::llvm::dyn_cast_or_null<::mlir::TypedAttr>(attributes." "get(\"" + attr->name + "\")));\n"; } else { body << "::llvm::dyn_cast_or_null<::mlir::TypedAttr>(attributes." "get(\"" + attr->name + "\"));\n"; } body << " if (!odsInferredTypeAttr" << inferredTypeIdx << ") return ::mlir::failure();\n"; typeStr = ("odsInferredTypeAttr" + Twine(inferredTypeIdx) + ".getType()") .str(); } } else if (std::optional builder = op.getResult(infer.getResultIndex()) .constraint.getBuilderCall()) { typeStr = tgfmt(*builder, &fctx).str(); } else if (int index = constructedIndices[infer.getResultIndex()]; index >= 0) { typeStr = ("odsInferredType" + Twine(index)).str(); } else { continue; } body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = " << tgfmt(infer.getTransformer(), &fctx.withSelf(typeStr)) << ";\n"; constructedIndices[i] = inferredTypeIdx - 1; } } for (auto [i, index] : llvm::enumerate(constructedIndices)) body << " inferredReturnTypes[" << i << "] = odsInferredType" << index << ";\n"; body << " return ::mlir::success();"; } void OpEmitter::genParser() { if (hasStringAttribute(def, "assemblyFormat")) return; if (!def.getValueAsBit("hasCustomAssemblyFormat")) return; SmallVector paramList; paramList.emplace_back("::mlir::OpAsmParser &", "parser"); paramList.emplace_back("::mlir::OperationState &", "result"); auto *method = opClass.declareStaticMethod("::mlir::ParseResult", "parse", std::move(paramList)); ERROR_IF_PRUNED(method, 
"parse", op); } void OpEmitter::genPrinter() { if (hasStringAttribute(def, "assemblyFormat")) return; // Check to see if this op uses a c++ format. if (!def.getValueAsBit("hasCustomAssemblyFormat")) return; auto *method = opClass.declareMethod( "void", "print", MethodParameter("::mlir::OpAsmPrinter &", "p")); ERROR_IF_PRUNED(method, "print", op); } void OpEmitter::genVerifier() { auto *implMethod = opClass.addMethod("::mlir::LogicalResult", "verifyInvariantsImpl"); ERROR_IF_PRUNED(implMethod, "verifyInvariantsImpl", op); auto &implBody = implMethod->body(); bool useProperties = emitHelper.hasProperties(); populateSubstitutions(emitHelper, verifyCtx); genAttributeVerifier(emitHelper, verifyCtx, implBody, staticVerifierEmitter, useProperties); genOperandResultVerifier(implBody, op.getOperands(), "operand"); genOperandResultVerifier(implBody, op.getResults(), "result"); for (auto &trait : op.getTraits()) { if (auto *t = dyn_cast(&trait)) { implBody << tgfmt(" if (!($0))\n " "return emitOpError(\"failed to verify that $1\");\n", &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx), t->getSummary()); } } genRegionVerifier(implBody); genSuccessorVerifier(implBody); implBody << " return ::mlir::success();\n"; // TODO: Some places use the `verifyInvariants` to do operation verification. // This may not act as their expectation because this doesn't call any // verifiers of native/interface traits. Needs to review those use cases and // see if we should use the mlir::verify() instead. 
auto *method = opClass.addMethod("::mlir::LogicalResult", "verifyInvariants"); ERROR_IF_PRUNED(method, "verifyInvariants", op); auto &body = method->body(); if (def.getValueAsBit("hasVerifier")) { body << " if(::mlir::succeeded(verifyInvariantsImpl()) && " "::mlir::succeeded(verify()))\n"; body << " return ::mlir::success();\n"; body << " return ::mlir::failure();"; } else { body << " return verifyInvariantsImpl();"; } } void OpEmitter::genCustomVerifier() { if (def.getValueAsBit("hasVerifier")) { auto *method = opClass.declareMethod("::mlir::LogicalResult", "verify"); ERROR_IF_PRUNED(method, "verify", op); } if (def.getValueAsBit("hasRegionVerifier")) { auto *method = opClass.declareMethod("::mlir::LogicalResult", "verifyRegions"); ERROR_IF_PRUNED(method, "verifyRegions", op); } } void OpEmitter::genOperandResultVerifier(MethodBody &body, Operator::const_value_range values, StringRef valueKind) { // Check that an optional value is at most 1 element. // // {0}: Value index. // {1}: "operand" or "result" const char *const verifyOptional = R"( if (valueGroup{0}.size() > 1) { return emitOpError("{1} group starting at #") << index << " requires 0 or 1 element, but found " << valueGroup{0}.size(); } )"; // Check the types of a range of values. // // {0}: Value index. // {1}: Type constraint function. 
// {2}: "operand" or "result" const char *const verifyValues = R"( for (auto v : valueGroup{0}) { if (::mlir::failed({1}(*this, v.getType(), "{2}", index++))) return ::mlir::failure(); } )"; const auto canSkip = [](const NamedTypeConstraint &value) { return !value.hasPredicate() && !value.isOptional() && !value.isVariadicOfVariadic(); }; if (values.empty() || llvm::all_of(values, canSkip)) return; FmtContext fctx; body << " {\n unsigned index = 0; (void)index;\n"; for (const auto &staticValue : llvm::enumerate(values)) { const NamedTypeConstraint &value = staticValue.value(); bool hasPredicate = value.hasPredicate(); bool isOptional = value.isOptional(); bool isVariadicOfVariadic = value.isVariadicOfVariadic(); if (!hasPredicate && !isOptional && !isVariadicOfVariadic) continue; body << formatv(" auto valueGroup{2} = getODS{0}{1}s({2});\n", // Capitalize the first letter to match the function name valueKind.substr(0, 1).upper(), valueKind.substr(1), staticValue.index()); // If the constraint is optional check that the value group has at most 1 // value. if (isOptional) { body << formatv(verifyOptional, staticValue.index(), valueKind); } else if (isVariadicOfVariadic) { body << formatv( " if (::mlir::failed(::mlir::OpTrait::impl::verifyValueSizeAttr(" "*this, \"{0}\", \"{1}\", valueGroup{2}.size())))\n" " return ::mlir::failure();\n", value.constraint.getVariadicOfVariadicSegmentSizeAttr(), value.name, staticValue.index()); } // Otherwise, if there is no predicate there is nothing left to do. if (!hasPredicate) continue; // Emit a loop to check all the dynamic values in the pack. StringRef constraintFn = staticVerifierEmitter.getTypeConstraintFn(value.constraint); body << formatv(verifyValues, staticValue.index(), constraintFn, valueKind); } body << " }\n"; } void OpEmitter::genRegionVerifier(MethodBody &body) { /// Code to verify a region. /// /// {0}: Getter for the regions. /// {1}: The region constraint. /// {2}: The region's name. 
/// {3}: The region description. const char *const verifyRegion = R"( for (auto ®ion : {0}) if (::mlir::failed({1}(*this, region, "{2}", index++))) return ::mlir::failure(); )"; /// Get a single region. /// /// {0}: The region's index. const char *const getSingleRegion = "::llvm::MutableArrayRef((*this)->getRegion({0}))"; // If we have no regions, there is nothing more to do. const auto canSkip = [](const NamedRegion ®ion) { return region.constraint.getPredicate().isNull(); }; auto regions = op.getRegions(); if (regions.empty() && llvm::all_of(regions, canSkip)) return; body << " {\n unsigned index = 0; (void)index;\n"; for (const auto &it : llvm::enumerate(regions)) { const auto ®ion = it.value(); if (canSkip(region)) continue; auto getRegion = region.isVariadic() ? formatv("{0}()", op.getGetterName(region.name)).str() : formatv(getSingleRegion, it.index()).str(); auto constraintFn = staticVerifierEmitter.getRegionConstraintFn(region.constraint); body << formatv(verifyRegion, getRegion, constraintFn, region.name); } body << " }\n"; } void OpEmitter::genSuccessorVerifier(MethodBody &body) { const char *const verifySuccessor = R"( for (auto *successor : {0}) if (::mlir::failed({1}(*this, successor, "{2}", index++))) return ::mlir::failure(); )"; /// Get a single successor. /// /// {0}: The successor's name. const char *const getSingleSuccessor = "::llvm::MutableArrayRef({0}())"; // If we have no successors, there is nothing more to do. const auto canSkip = [](const NamedSuccessor &successor) { return successor.constraint.getPredicate().isNull(); }; auto successors = op.getSuccessors(); if (successors.empty() && llvm::all_of(successors, canSkip)) return; body << " {\n unsigned index = 0; (void)index;\n"; for (auto it : llvm::enumerate(successors)) { const auto &successor = it.value(); if (canSkip(successor)) continue; auto getSuccessor = formatv(successor.isVariadic() ? 
"{0}()" : getSingleSuccessor, successor.name, it.index()) .str(); auto constraintFn = staticVerifierEmitter.getSuccessorConstraintFn(successor.constraint); body << formatv(verifySuccessor, getSuccessor, constraintFn, successor.name); } body << " }\n"; } /// Add a size count trait to the given operation class. static void addSizeCountTrait(OpClass &opClass, StringRef traitKind, int numTotal, int numVariadic) { if (numVariadic != 0) { if (numTotal == numVariadic) opClass.addTrait("::mlir::OpTrait::Variadic" + traitKind + "s"); else opClass.addTrait("::mlir::OpTrait::AtLeastN" + traitKind + "s<" + Twine(numTotal - numVariadic) + ">::Impl"); return; } switch (numTotal) { case 0: opClass.addTrait("::mlir::OpTrait::Zero" + traitKind + "s"); break; case 1: opClass.addTrait("::mlir::OpTrait::One" + traitKind); break; default: opClass.addTrait("::mlir::OpTrait::N" + traitKind + "s<" + Twine(numTotal) + ">::Impl"); break; } } void OpEmitter::genTraits() { // Add region size trait. unsigned numRegions = op.getNumRegions(); unsigned numVariadicRegions = op.getNumVariadicRegions(); addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions); // Add result size traits. int numResults = op.getNumResults(); int numVariadicResults = op.getNumVariableLengthResults(); addSizeCountTrait(opClass, "Result", numResults, numVariadicResults); // For single result ops with a known specific type, generate a OneTypedResult // trait. if (numResults == 1 && numVariadicResults == 0) { auto cppName = op.getResults().begin()->constraint.getCPPClassName(); opClass.addTrait("::mlir::OpTrait::OneTypedResult<" + cppName + ">::Impl"); } // Add successor size trait. unsigned numSuccessors = op.getNumSuccessors(); unsigned numVariadicSuccessors = op.getNumVariadicSuccessors(); addSizeCountTrait(opClass, "Successor", numSuccessors, numVariadicSuccessors); // Add variadic size trait and normal op traits. 
int numOperands = op.getNumOperands(); int numVariadicOperands = op.getNumVariableLengthOperands(); // Add operand size trait. addSizeCountTrait(opClass, "Operand", numOperands, numVariadicOperands); // The op traits defined internal are ensured that they can be verified // earlier. for (const auto &trait : op.getTraits()) { if (auto *opTrait = dyn_cast(&trait)) { if (opTrait->isStructuralOpTrait()) opClass.addTrait(opTrait->getFullyQualifiedTraitName()); } } // OpInvariants wrapps the verifyInvariants which needs to be run before // native/interface traits and after all the traits with `StructuralOpTrait`. opClass.addTrait("::mlir::OpTrait::OpInvariants"); if (emitHelper.hasProperties()) opClass.addTrait("::mlir::BytecodeOpInterface::Trait"); // Add the native and interface traits. for (const auto &trait : op.getTraits()) { if (auto *opTrait = dyn_cast(&trait)) { if (!opTrait->isStructuralOpTrait()) opClass.addTrait(opTrait->getFullyQualifiedTraitName()); } else if (auto *opTrait = dyn_cast(&trait)) { opClass.addTrait(opTrait->getFullyQualifiedTraitName()); } } } void OpEmitter::genOpNameGetter() { auto *method = opClass.addStaticMethod( "::llvm::StringLiteral", "getOperationName"); ERROR_IF_PRUNED(method, "getOperationName", op); method->body() << " return ::llvm::StringLiteral(\"" << op.getOperationName() << "\");"; } void OpEmitter::genOpAsmInterface() { // If the user only has one results or specifically added the Asm trait, // then don't generate it for them. We specifically only handle multi result // operations, because the name of a single result in the common case is not // interesting(generally 'result'/'output'/etc.). // TODO: We could also add a flag to allow operations to opt in to this // generation, even if they only have a single operation. 
int numResults = op.getNumResults(); if (numResults <= 1 || op.getTrait("::mlir::OpAsmOpInterface::Trait")) return; SmallVector resultNames(numResults); for (int i = 0; i != numResults; ++i) resultNames[i] = op.getResultName(i); // Don't add the trait if none of the results have a valid name. if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); })) return; opClass.addTrait("::mlir::OpAsmOpInterface::Trait"); // Generate the right accessor for the number of results. auto *method = opClass.addMethod( "void", "getAsmResultNames", MethodParameter("::mlir::OpAsmSetValueNameFn", "setNameFn")); ERROR_IF_PRUNED(method, "getAsmResultNames", op); auto &body = method->body(); for (int i = 0; i != numResults; ++i) { body << " auto resultGroup" << i << " = getODSResults(" << i << ");\n" << " if (!resultGroup" << i << ".empty())\n" << " setNameFn(*resultGroup" << i << ".begin(), \"" << resultNames[i] << "\");\n"; } } //===----------------------------------------------------------------------===// // OpOperandAdaptor emitter //===----------------------------------------------------------------------===// namespace { // Helper class to emit Op operand adaptors to an output stream. Operand // adaptors are wrappers around random access ranges that provide named operand // getters identical to those defined in the Op. // This currently generates 3 classes per Op: // * A Base class within the 'detail' namespace, which contains all logic and // members independent of the random access range that is indexed into. // In other words, it contains all the attribute and region getters. // * A templated class named '{OpName}GenericAdaptor' with a template parameter // 'RangeT' that is indexed into by the getters to access the operands. // It contains all getters to access operands and inherits from the previous // class. // * A class named '{OpName}Adaptor', which inherits from the 'GenericAdaptor' // with 'mlir::ValueRange' as template parameter. 
It adds a constructor from
// an instance of the op type and a verify function.
//
// The static emitDecl/emitDef entry points each construct a fresh emitter for
// a single op and then write the declarations (resp. definitions) of all three
// generated classes; the base class is written inside a `detail` namespace.
class OpOperandAdaptorEmitter {
public:
  /// Writes the declarations of the adaptor classes generated for `op` to
  /// `os`.
  static void
  emitDecl(const Operator &op,
           const StaticVerifierFunctionEmitter &staticVerifierEmitter,
           raw_ostream &os);
  /// Writes the definitions of the adaptor classes generated for `op` to `os`.
  static void
  emitDef(const Operator &op,
          const StaticVerifierFunctionEmitter &staticVerifierEmitter,
          raw_ostream &os);

private:
  /// Populates the three in-memory adaptor classes for `op`. Nothing is
  /// written to a stream here; emitDecl/emitDef do that afterwards via
  /// writeDeclTo/writeDefTo.
  explicit OpOperandAdaptorEmitter(
      const Operator &op,
      const StaticVerifierFunctionEmitter &staticVerifierEmitter);

  // Add verification function. This generates a verify method for the adaptor
  // which verifies all the op-independent attribute constraints.
  void addVerification();

  // The operation for which to emit an adaptor.
  const Operator &op;

  // The generated adaptor classes:
  //  - genericAdaptorBase: `<GenericAdaptorName>Base` (emitted in `detail`),
  //    holding the attribute dictionary, the op name, the properties (when
  //    used) and the regions, plus their getters.
  //  - genericAdaptor: `<GenericAdaptorName>`, templated on `RangeT`, holding
  //    the operand range `odsOperands` and deriving from the base.
  //  - adaptor: `<AdaptorName>`, deriving from the generic adaptor
  //    instantiated with `::mlir::ValueRange`.
  Class genericAdaptorBase;
  Class genericAdaptor;
  Class adaptor;

  // The emitter containing all of the locally emitted verification functions.
  const StaticVerifierFunctionEmitter &staticVerifierEmitter;

  // Helper for emitting adaptor code.
  OpOrAdaptorHelper emitHelper;
};
} // namespace

// Builds the contents of all three adaptor classes for `op`. The sequence of
// declare/add* calls below presumably fixes the order in which members appear
// in the generated classes, so statements should not be reordered casually.
OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
    const Operator &op,
    const StaticVerifierFunctionEmitter &staticVerifierEmitter)
    : op(op), genericAdaptorBase(op.getGenericAdaptorName() + "Base"),
      genericAdaptor(op.getGenericAdaptorName()), adaptor(op.getAdaptorName()),
      staticVerifierEmitter(staticVerifierEmitter),
      emitHelper(op, /*emitForOp=*/false) {

  genericAdaptorBase.declare(Visibility::Public);
  bool useProperties = emitHelper.hasProperties();
  if (useProperties) {
    // Define the properties struct with multiple members.
    // Members come from: every attribute-metadata entry that is not a derived
    // attribute, every declared property, and the operand/result segment-size
    // entries when present.
    using ConstArgument =
        llvm::PointerUnion;
    SmallVector attrOrProperties;
    for (const std::pair &it :
         emitHelper.getAttrMetadata()) {
      if (!it.second.constraint || !it.second.constraint->isDerivedAttr())
        attrOrProperties.push_back(&it.second);
    }
    for (const NamedProperty &prop : op.getProperties())
      attrOrProperties.push_back(&prop);
    if (emitHelper.getOperandSegmentsSize())
      attrOrProperties.push_back(&emitHelper.getOperandSegmentsSize().value());
    if (emitHelper.getResultSegmentsSize())
      attrOrProperties.push_back(&emitHelper.getResultSegmentsSize().value());
    assert(!attrOrProperties.empty());
    // `declarations` accumulates the struct body; `comparator` accumulates a
    // generated operator== of the form `return rhs.a == this->a && ... true;`.
    std::string declarations = " struct Properties {\n";
    llvm::raw_string_ostream os(declarations);
    std::string comparator =
        " bool operator==(const Properties &rhs) const {\n"
        " return \n";
    llvm::raw_string_ostream comparatorOs(comparator);
    for (const auto &attrOrProp : attrOrProperties) {
      if (const auto *namedProperty =
              llvm::dyn_cast_if_present(attrOrProp)) {
        StringRef name = namedProperty->name;
        if (name.empty())
          report_fatal_error("missing name for property");
        std::string camelName =
            convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true);
        auto &prop = namedProperty->prop;
        // Generate the data member using the storage type.
        // Note: no terminating `;` is written here; the accessor format below
        // starts with `;` so that an optional default initializer can be
        // appended to the member declaration first.
        os << " using " << name << "Ty = " << prop.getStorageType() << ";\n"
           << " " << name << "Ty " << name;
        if (prop.hasDefaultValue())
          os << " = " << prop.getDefaultValue();
        comparatorOs << " rhs." << name << " == this->" << name << " &&\n";
        // Emit accessors using the interface type.
        // {0}: interface type, {1}: CamelCase name, {2}: member name,
        // {3}: storage->interface conversion, {4}: value->storage assignment.
        const char *accessorFmt = R"decl(; {0} get{1}() { auto &propStorage = this->{2}; return {3}; } void set{1}(const {0} &propValue) { auto &propStorage = this->{2}; {4}; } )decl";
        FmtContext fctx;
        // The `_storage` placeholder is bound to the accessor-local
        // `propStorage` reference and `_value` to the setter's `propValue`
        // parameter, matching the names used in accessorFmt.
        os << formatv(accessorFmt, prop.getInterfaceType(), camelName, name,
                      tgfmt(prop.getConvertFromStorageCall(),
                            &fctx.addSubst("_storage", propertyStorage)),
                      tgfmt(prop.getAssignToStorageCall(),
                            &fctx.addSubst("_value", propertyValue)
                                 .addSubst("_storage", propertyStorage)));
        continue;
      }
      // Otherwise the entry is attribute metadata.
      const auto *namedAttr =
          llvm::dyn_cast_if_present(attrOrProp);
      const Attribute *attr = nullptr;
      if (namedAttr->constraint)
        attr = &*namedAttr->constraint;
      StringRef name = namedAttr->attrName;
      if (name.empty())
        report_fatal_error("missing name for property attr");
      std::string camelName =
          convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true);
      // Generate the data member using the storage type.
      StringRef storageType;
      if (attr) {
        storageType = attr->getStorageType();
      } else {
        // Only the implicit segment-size entries are expected to lack a
        // constraint.
        if (name != operandSegmentAttrName && name != resultSegmentAttrName) {
          report_fatal_error("unexpected AttributeMetadata");
        }
        // TODO: update to use native integers.
        storageType = "::mlir::DenseI32ArrayAttr";
      }
      os << " using " << name << "Ty = " << storageType << ";\n"
         << " " << name << "Ty " << name << ";\n";
      comparatorOs << " rhs." << name << " == this->" << name << " &&\n";

      // Emit accessors using the interface type.
      // Optional or defaulted attributes may be unset, so the getter uses
      // dyn_cast_or_null instead of cast. No accessors are emitted for the
      // unconstrained segment-size entries.
      if (attr) {
        const char *accessorFmt = R"decl( auto get{0}() { auto &propStorage = this->{1}; return ::llvm::{2}<{3}>(propStorage); } void set{0}(const {3} &propValue) { this->{1} = propValue; } )decl";
        os << formatv(accessorFmt, camelName, name,
                      attr->isOptional() || attr->hasDefaultValue()
                          ? "dyn_cast_or_null"
                          : "cast",
                      storageType);
      }
    }
    // Close the operator== chain with `true` and add operator!= as its
    // negation, then append both to the struct body.
    comparatorOs << " true;\n }\n"
                    " bool operator!=(const Properties &rhs) const {\n"
                    " return !(*this == rhs);\n"
                    " }\n";
    comparatorOs.flush();
    os << comparator;
    os << " };\n";
    os.flush();
    genericAdaptorBase.declare(std::move(declarations));
  }

  // Range-independent storage of the base class.
  genericAdaptorBase.declare(Visibility::Protected);
  genericAdaptorBase.declare("::mlir::DictionaryAttr", "odsAttrs");
  genericAdaptorBase.declare("::std::optional<::mlir::OperationName>",
                             "odsOpName");
  if (useProperties)
    genericAdaptorBase.declare("Properties", "properties");
  genericAdaptorBase.declare("::mlir::RegionRange", "odsRegions");

  // The generic adaptor is templated on the operand range type and derives
  // from the base; `ValueT` and `Base` are declared for use in its members.
  genericAdaptor.addTemplateParam("RangeT");
  genericAdaptor.addField("RangeT", "odsOperands");
  genericAdaptor.addParent(
      ParentClass("detail::" + genericAdaptorBase.getClassName()));
  genericAdaptor.declare(
      "ValueT", "::llvm::detail::ValueOfRange");
  genericAdaptor.declare(
      "Base", "detail::" + genericAdaptorBase.getClassName());

  const auto *attrSizedOperands =
      op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
  {
    // Constructors taking (attrs, properties, regions). When the op has
    // AttrSizedOperandSegments, `attrs` gets an empty default string instead
    // of `nullptr` (presumably meaning "no default", since segment sizes are
    // then needed from it — TODO confirm).
    SmallVector paramList;
    paramList.emplace_back("::mlir::DictionaryAttr", "attrs",
                           attrSizedOperands ? "" : "nullptr");
    if (useProperties)
      paramList.emplace_back("const Properties &", "properties", "{}");
    else
      paramList.emplace_back("const ::mlir::EmptyProperties &", "properties",
                             "{}");
    paramList.emplace_back("::mlir::RegionRange", "regions", "{}");
    auto *baseConstructor = genericAdaptorBase.addConstructor(paramList);
    baseConstructor->addMemberInitializer("odsAttrs", "attrs");
    if (useProperties)
      baseConstructor->addMemberInitializer("properties", "properties");
    baseConstructor->addMemberInitializer("odsRegions", "regions");

    // Generated body: the op name is only recorded when an attribute
    // dictionary (and thus a context) is available.
    MethodBody &body = baseConstructor->body();
    body.indent() << "if (odsAttrs)\n";
    body.indent() << formatv(
        "odsOpName.emplace(\"{0}\", odsAttrs.getContext());\n",
        op.getOperationName());

    // The generic adaptor's constructor takes the operand range first and
    // forwards the remaining parameters to Base.
    paramList.insert(paramList.begin(), MethodParameter("RangeT", "values"));
    auto *constructor = genericAdaptor.addConstructor(paramList);
    constructor->addMemberInitializer("Base", "attrs, properties, regions");
    constructor->addMemberInitializer("odsOperands", "values");

    // Add a forwarding constructor to the previous one that accepts
    // OpaqueProperties instead and check for null and perform the cast to the
    // actual properties type.
    // paramList is now [values, attrs, properties, regions]; indices 1 and 2
    // are replaced without default values.
    paramList[1] = MethodParameter("::mlir::DictionaryAttr", "attrs");
    paramList[2] = MethodParameter("::mlir::OpaqueProperties", "properties");
    auto *opaquePropertiesConstructor =
        genericAdaptor.addConstructor(std::move(paramList));
    if (useProperties) {
      opaquePropertiesConstructor->addMemberInitializer(
          genericAdaptor.getClassName(), "values, "
                                         "attrs, "
                                         "(properties ? *properties.as() : Properties{}), "
                                         "regions");
    } else {
      opaquePropertiesConstructor->addMemberInitializer(
          genericAdaptor.getClassName(), "values, "
                                         "attrs, "
                                         "(properties ? *properties.as<::mlir::EmptyProperties *>() : "
                                         "::mlir::EmptyProperties{}), "
                                         "regions");
    }
  }

  // Create constructors constructing the adaptor from an instance of the op.
  // This takes the attributes, properties and regions from the op instance
  // and the value range from the parameter.
  {
    // Base class is in the cpp file and can simply access the members of the op
    // class to initialize the template independent fields.
    // With properties, only the discardable attributes are passed, since the
    // inherent ones are carried by op.getProperties().
    auto *constructor = genericAdaptorBase.addConstructor(
        MethodParameter(op.getCppClassName(), "op"));
    constructor->addMemberInitializer(
        genericAdaptorBase.getClassName(),
        llvm::Twine(!useProperties ? "op->getAttrDictionary()"
                                   : "op->getDiscardableAttrDictionary()") +
            ", op.getProperties(), op->getRegions()");

    // Generic adaptor is templated and therefore defined inline in the header.
    // We cannot use the Op class here as it is an incomplete type (we have a
    // circular reference between the two).
    // Use a template trick to make the constructor be instantiated at call site
    // when the op class is complete.
    constructor = genericAdaptor.addConstructor(
        MethodParameter("RangeT", "values"), MethodParameter("LateInst", "op"));
    constructor->addTemplateParam("LateInst = " + op.getCppClassName());
    constructor->addTemplateParam(
        "= std::enable_if_t>");
    constructor->addMemberInitializer("Base", "op");
    constructor->addMemberInitializer("odsOperands", "values");
  }

  // For ops with AttrSizedOperandSegments, the operand getters need code that
  // reads the segment sizes: from the properties when the dialect stores
  // attributes as properties, otherwise via an attribute lookup.
  std::string sizeAttrInit;
  if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
    if (op.getDialect().usePropertiesForAttributes())
      sizeAttrInit =
          formatv(adapterSegmentSizeAttrInitCodeProperties,
                  llvm::formatv("getProperties().operandSegmentSizes"));
    else
      sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode,
                             emitHelper.getAttr(operandSegmentAttrName));
  }
  // Named operand getters index into the `odsOperands` range of type RangeT.
  generateNamedOperandGetters(op, genericAdaptor,
                              /*genericAdaptorBase=*/&genericAdaptorBase,
                              /*sizeAttrInit=*/sizeAttrInit,
                              /*rangeType=*/"RangeT",
                              /*rangeElementType=*/"ValueT",
                              /*rangeBeginCall=*/"odsOperands.begin()",
                              /*rangeSizeCall=*/"odsOperands.size()",
                              /*getOperandCallPattern=*/"odsOperands[{0}]");

  // Any invalid overlap for `getOperands` will have been diagnosed before
  // here already.
  if (auto *m = genericAdaptor.addMethod("RangeT", "getOperands"))
    m->body() << " return odsOperands;";

  // Builder expression used when materializing attribute default values in
  // generated getters.
  FmtContext fctx;
  fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())");

  // Generate named accessor with Attribute return type.
  // Adds `<emitName>Attr()` to the base class, returning the attribute's
  // storage type; optional/defaulted attributes use dyn_cast_or_null, and an
  // optional attribute with a default value falls back to that default.
  auto emitAttrWithStorageType = [&](StringRef name, StringRef emitName,
                                     Attribute attr) {
    auto *method =
        genericAdaptorBase.addMethod(attr.getStorageType(), emitName + "Attr");
    ERROR_IF_PRUNED(method, "Adaptor::" + emitName + "Attr", op);
    auto &body = method->body().indent();
    if (!useProperties)
      body << "assert(odsAttrs && \"no attributes when constructing "
              "adapter\");\n";
    body << formatv(
        "auto attr = ::llvm::{1}<{2}>({0});\n", emitHelper.getAttr(name),
        attr.hasDefaultValue() || attr.isOptional() ? "dyn_cast_or_null"
                                                    : "cast",
        attr.getStorageType());

    if (attr.hasDefaultValue() && attr.isOptional()) {
      // Use the default value if attribute is not set.
      // TODO: this is inefficient, we are recreating the attribute for every
      // call. This should be set instead.
      std::string defaultValue = std::string(
          tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
      body << "if (!attr)\n attr = " << defaultValue << ";\n";
    }
    body << "return attr;\n";
  };

  if (useProperties) {
    auto *m = genericAdaptorBase.addInlineMethod("const Properties &",
                                                 "getProperties");
    ERROR_IF_PRUNED(m, "Adaptor::getProperties", op);
    m->body() << " return properties;";
  }
  {
    auto *m =
        genericAdaptorBase.addMethod("::mlir::DictionaryAttr", "getAttributes");
    ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op);
    m->body() << " return odsAttrs;";
  }
  // Per-attribute getters (derived attributes are skipped): one returning
  // the storage type and one returning the attribute's return type.
  for (auto &namedAttr : op.getAttributes()) {
    const auto &name = namedAttr.name;
    const auto &attr = namedAttr.attr;
    if (attr.isDerivedAttr())
      continue;
    std::string emitName = op.getGetterName(name);
    emitAttrWithStorageType(name, emitName, attr);
    emitAttrGetterWithReturnType(fctx, genericAdaptorBase, op, emitName, attr);
  }

  // Named region getters; unnamed regions get no accessor.
  unsigned numRegions = op.getNumRegions();
  for (unsigned i = 0; i < numRegions; ++i) {
    const auto ®ion = op.getRegion(i);
    if (region.name.empty())
      continue;

    // Generate the accessors for a variadic region.
    // A variadic region's getter returns all regions from index `i` onwards.
    std::string name = op.getGetterName(region.name);
    if (region.isVariadic()) {
      auto *m = genericAdaptorBase.addMethod("::mlir::RegionRange", name);
      ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
      m->body() << formatv(" return odsRegions.drop_front({0});", i);
      continue;
    }

    auto *m = genericAdaptorBase.addMethod("::mlir::Region &", name);
    ERROR_IF_PRUNED(m, "Adaptor::" + name, op);
    m->body() << formatv(" return *odsRegions[{0}];", i);
  }
  if (numRegions > 0) {
    // Any invalid overlap for `getRegions` will have been diagnosed before
    // here already.
    if (auto *m =
            genericAdaptorBase.addMethod("::mlir::RegionRange", "getRegions"))
      m->body() << " return odsRegions;";
  }

  // The concrete adaptor derives from GenericAdaptor<::mlir::ValueRange>. The
  // declaration below names the generic adaptor's constructor (presumably a
  // `using` declaration re-exporting it — TODO confirm).
  StringRef genericAdaptorClassName = genericAdaptor.getClassName();
  adaptor.addParent(ParentClass(genericAdaptorClassName))
      .addTemplateParam("::mlir::ValueRange");
  adaptor.declare(Visibility::Public);
  adaptor.declare(genericAdaptorClassName + "::" +
                  genericAdaptorClassName);
  {
    // Constructor taking the Op as single parameter.
    auto *constructor =
        adaptor.addConstructor(MethodParameter(op.getCppClassName(), "op"));
    constructor->addMemberInitializer(genericAdaptorClassName,
                                      "op->getOperands(), op");
  }

  // Add verification function.
  addVerification();

  genericAdaptorBase.finalize();
  genericAdaptor.finalize();
  adaptor.finalize();
}

// Adds `::mlir::LogicalResult verify(::mlir::Location loc)` to the concrete
// adaptor. The body contains the generated attribute checks followed by
// `return ::mlir::success();`.
void OpOperandAdaptorEmitter::addVerification() {
  auto *method = adaptor.addMethod("::mlir::LogicalResult", "verify",
                                   MethodParameter("::mlir::Location", "loc"));
  ERROR_IF_PRUNED(method, "verify", op);
  auto &body = method->body();
  bool useProperties = emitHelper.hasProperties();

  FmtContext verifyCtx;
  populateSubstitutions(emitHelper, verifyCtx);
  genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter,
                       useProperties);

  body << " return ::mlir::success();";
}

// Writes the declarations: the base class inside `namespace detail`, then the
// generic adaptor and the concrete adaptor.
void OpOperandAdaptorEmitter::emitDecl(
    const Operator &op,
    const StaticVerifierFunctionEmitter &staticVerifierEmitter,
    raw_ostream &os) {
  OpOperandAdaptorEmitter emitter(op, staticVerifierEmitter);
  {
    NamespaceEmitter ns(os, "detail");
    emitter.genericAdaptorBase.writeDeclTo(os);
  }
  emitter.genericAdaptor.writeDeclTo(os);
  emitter.adaptor.writeDeclTo(os);
}

// Writes the definitions in the same order and namespace layout as emitDecl.
void OpOperandAdaptorEmitter::emitDef(
    const Operator &op,
    const StaticVerifierFunctionEmitter &staticVerifierEmitter,
    raw_ostream &os) {
  OpOperandAdaptorEmitter emitter(op, staticVerifierEmitter);
  {
    NamespaceEmitter ns(os, "detail");
    emitter.genericAdaptorBase.writeDefTo(os);
  }
  emitter.genericAdaptor.writeDefTo(os);
  emitter.adaptor.writeDefTo(os);
}

// Emits the op classes (declarations or definitions, selected by `emitDecl`),
// preceded in declaration mode by forward declarations of every op class.
static void emitOpClasses(const RecordKeeper &recordKeeper,
                          const std::vector &defs, raw_ostream &os,
                          bool emitDecl) {
  // In declaration mode, forward-declare every op class up front so that the
  // ops can refer to each other (e.g. from traits) before their full
  // declarations appear.
  if (emitDecl) {
    os << "#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)\n"
       << "#undef GET_OP_FWD_DEFINES\n";
    for (auto *record : defs) {
      Operator fwdOp(*record);
      NamespaceEmitter fwdNamespace(os, fwdOp.getCppNamespace());
      os << "class " << fwdOp.getCppClassName() << ";\n";
    }
    os << "#endif\n\n";
  }

  // Everything below is wrapped in the GET_OP_CLASSES guard, which is closed
  // when `classesGuard` goes out of scope (including the early return).
  IfDefScope classesGuard("GET_OP_CLASSES", os);
  if (defs.empty())
    return;

  // The locally instantiated static verifier helpers come before any op.
  StaticVerifierFunctionEmitter verifierEmitter(os, recordKeeper);
  os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
  verifierEmitter.emitOpConstraints(defs, emitDecl);

  // Only the section label and the TypeID macro differ between modes.
  const char *sectionKind = emitDecl ? "declarations" : "definitions";
  const char *typeIdMacro = emitDecl ? "MLIR_DECLARE_EXPLICIT_TYPE_ID("
                                     : "MLIR_DEFINE_EXPLICIT_TYPE_ID(";
  for (auto *record : defs) {
    Operator op(*record);
    {
      NamespaceEmitter nsScope(os, op.getCppNamespace());
      os << formatv(opCommentHeader, op.getQualCppClassName(), sectionKind);
      if (emitDecl) {
        OpOperandAdaptorEmitter::emitDecl(op, verifierEmitter, os);
        OpEmitter::emitDecl(op, os, verifierEmitter);
      } else {
        OpOperandAdaptorEmitter::emitDef(op, verifierEmitter, os);
        OpEmitter::emitDef(op, os, verifierEmitter);
      }
    }
    // The explicit TypeID specialization (a single definition per op) is
    // written after the op's namespace is closed and names the op by its
    // fully qualified name; it is skipped for ops without a namespace.
    if (!op.getCppNamespace().empty())
      os << typeIdMacro << op.getCppNamespace() << "::"
         << op.getCppClassName() << ")\n\n";
  }
}
// Emits a comma-separated list of the fully qualified op class names.
static void emitOpList(const std::vector &defs, raw_ostream &os) { IfDefScope scope("GET_OP_LIST", os); interleave( // TODO: We are constructing the Operator wrapper instance just for // getting it's qualified class name here. Reduce the overhead by having a // lightweight version of Operator class just for that purpose. defs, [&os](Record *def) { os << Operator(def).getQualCppClassName(); }, [&os]() { os << ",\n"; }); } static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) { emitSourceFileHeader("Op Declarations", os, recordKeeper); std::vector defs = getRequestedOpDefinitions(recordKeeper); emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true); return false; } static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) { emitSourceFileHeader("Op Definitions", os, recordKeeper); std::vector defs = getRequestedOpDefinitions(recordKeeper); emitOpList(defs, os); emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false); return false; } static mlir::GenRegistration genOpDecls("gen-op-decls", "Generate op declarations", [](const RecordKeeper &records, raw_ostream &os) { return emitOpDecls(records, os); }); static mlir::GenRegistration genOpDefs("gen-op-defs", "Generate op definitions", [](const RecordKeeper &records, raw_ostream &os) { return emitOpDefs(records, os); });