//===- TestOpsSyntax.cpp - Operations for testing syntax ------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "TestOpsSyntax.h" #include "TestDialect.h" #include "mlir/IR/OpImplementation.h" #include "llvm/Support/Base64.h" using namespace mlir; using namespace test; //===----------------------------------------------------------------------===// // Test Format* operations //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// // Parsing static ParseResult parseCustomOptionalOperand( OpAsmParser &parser, std::optional &optOperand) { if (succeeded(parser.parseOptionalLParen())) { optOperand.emplace(); if (parser.parseOperand(*optOperand) || parser.parseRParen()) return failure(); } return success(); } static ParseResult parseCustomDirectiveOperands( OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, std::optional &optOperand, SmallVectorImpl &varOperands) { if (parser.parseOperand(operand)) return failure(); if (succeeded(parser.parseOptionalComma())) { optOperand.emplace(); if (parser.parseOperand(*optOperand)) return failure(); } if (parser.parseArrow() || parser.parseLParen() || parser.parseOperandList(varOperands) || parser.parseRParen()) return failure(); return success(); } static ParseResult parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, Type &optOperandType, SmallVectorImpl &varOperandTypes) { if (parser.parseColon()) return failure(); if (parser.parseType(operandType)) return failure(); if (succeeded(parser.parseOptionalComma())) { if (parser.parseType(optOperandType)) return failure(); } if (parser.parseArrow() || parser.parseLParen() || parser.parseTypeList(varOperandTypes) || parser.parseRParen()) return failure(); return success(); } static ParseResult parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, Type optOperandType, const SmallVectorImpl &varOperandTypes) { if (parser.parseKeyword("type_refs_capture")) return failure(); Type operandType2, optOperandType2; SmallVector varOperandTypes2; if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, varOperandTypes2)) return failure(); if (operandType != operandType2 || optOperandType != optOperandType2 || varOperandTypes != varOperandTypes2) return failure(); return success(); } static ParseResult parseCustomDirectiveOperandsAndTypes( OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, std::optional &optOperand, SmallVectorImpl &varOperands, Type &operandType, Type &optOperandType, SmallVectorImpl &varOperandTypes) { if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || parseCustomDirectiveResults(parser, operandType, optOperandType, varOperandTypes)) return failure(); return success(); } static ParseResult parseCustomDirectiveRegions( OpAsmParser &parser, Region ®ion, SmallVectorImpl> &varRegions) { if (parser.parseRegion(region)) return failure(); if (failed(parser.parseOptionalComma())) return success(); std::unique_ptr varRegion = std::make_unique(); if (parser.parseRegion(*varRegion)) return failure(); varRegions.emplace_back(std::move(varRegion)); return success(); } static ParseResult parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, SmallVectorImpl &varSuccessors) { if (parser.parseSuccessor(successor)) return failure(); if (failed(parser.parseOptionalComma())) return success(); Block *varSuccessor; if (parser.parseSuccessor(varSuccessor)) return failure(); varSuccessors.append(2, varSuccessor); return success(); } static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, IntegerAttr &attr, IntegerAttr &optAttr) { if (parser.parseAttribute(attr)) return failure(); if (succeeded(parser.parseOptionalComma())) { if (parser.parseAttribute(optAttr)) return failure(); } return success(); } static ParseResult parseCustomDirectiveSpacing(OpAsmParser &parser, mlir::StringAttr &attr) { return parser.parseAttribute(attr); } static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, NamedAttrList &attrs) { return parser.parseOptionalAttrDict(attrs); } static ParseResult parseCustomDirectiveOptionalOperandRef( OpAsmParser &parser, std::optional &optOperand) { int64_t operandCount = 0; if (parser.parseInteger(operandCount)) return failure(); bool expectedOptionalOperand = operandCount == 0; return success(expectedOptionalOperand != optOperand.has_value()); } //===----------------------------------------------------------------------===// // Printing static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *, Value optOperand) { if (optOperand) printer << "(" << optOperand << ") "; } static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, Value operand, Value optOperand, OperandRange varOperands) { printer << operand; if (optOperand) printer << ", " << optOperand; printer << " -> (" << varOperands << ")"; } static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, Type operandType, Type optOperandType, TypeRange varOperandTypes) { printer << " : " << operandType; if (optOperandType) printer << ", " << optOperandType; printer << " -> (" << varOperandTypes << ")"; } static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, Operation *op, Type operandType, Type optOperandType, TypeRange varOperandTypes) { printer << " type_refs_capture "; printCustomDirectiveResults(printer, op, operandType, optOperandType, varOperandTypes); } static void printCustomDirectiveOperandsAndTypes( OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, OperandRange varOperands, Type operandType, Type optOperandType, TypeRange varOperandTypes) { printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); printCustomDirectiveResults(printer, op, operandType, optOperandType, varOperandTypes); } static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, Region ®ion, MutableArrayRef varRegions) { printer.printRegion(region); if (!varRegions.empty()) { printer << ", "; for (Region ®ion : varRegions) printer.printRegion(region); } } static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, Block *successor, SuccessorRange varSuccessors) { printer << successor; if (!varSuccessors.empty()) printer << ", " << varSuccessors.front(); } static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, Attribute attribute, Attribute optAttribute) { printer << attribute; if (optAttribute) printer << ", " << optAttribute; } static void printCustomDirectiveSpacing(OpAsmPrinter &printer, Operation *op, Attribute attribute) { printer << attribute; } static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, DictionaryAttr attrs) { printer.printOptionalAttrDict(attrs.getValue()); } static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, Operation *op, Value optOperand) { printer << (optOperand ? "1" : "0"); } //===----------------------------------------------------------------------===// // Test parser. //===----------------------------------------------------------------------===// ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalColon()) return success(); uint64_t numResults; if (parser.parseInteger(numResults)) return failure(); IndexType type = parser.getBuilder().getIndexType(); for (unsigned i = 0; i < numResults; ++i) result.addTypes(type); return success(); } void ParseIntegerLiteralOp::print(OpAsmPrinter &p) { if (unsigned numResults = getNumResults()) p << " : " << numResults; } ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser, OperationState &result) { StringRef keyword; if (parser.parseKeyword(&keyword)) return failure(); result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); return success(); } void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); } ParseResult ParseB64BytesOp::parse(OpAsmParser &parser, OperationState &result) { std::vector bytes; if (parser.parseBase64Bytes(&bytes)) return failure(); result.addAttribute("b64", parser.getBuilder().getStringAttr( StringRef(&bytes.front(), bytes.size()))); return success(); } void ParseB64BytesOp::print(OpAsmPrinter &p) { p << " \"" << llvm::encodeBase64(getB64()) << "\""; } ::mlir::LogicalResult FormatInferType2Op::inferReturnTypes( ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)}); return ::mlir::success(); } //===----------------------------------------------------------------------===// // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. ParseResult WrappingRegionOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseKeyword("wraps")) return failure(); // Parse the wrapped op in a region Region &body = *result.addRegion(); body.push_back(new Block); Block &block = body.back(); Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin()); if (!wrappedOp) return failure(); // Create a return terminator in the inner region, pass as operand to the // terminator the returned values from the wrapped operation. SmallVector returnOperands(wrappedOp->getResults()); OpBuilder builder(parser.getContext()); builder.setInsertionPointToEnd(&block); builder.create(wrappedOp->getLoc(), returnOperands); // Get the results type for the wrapping op from the terminator operands. Operation &returnOp = body.back().back(); result.types.append(returnOp.operand_type_begin(), returnOp.operand_type_end()); // Use the location of the wrapped op for the "test.wrapping_region" op. result.location = wrappedOp->getLoc(); return success(); } void WrappingRegionOp::print(OpAsmPrinter &p) { p << " wraps "; p.printGenericOp(&getRegion().front().front()); } //===----------------------------------------------------------------------===// // Test PrettyPrintedRegionOp - exercising the following parser APIs // parseGenericOperationAfterOpName // parseCustomOperationName //===----------------------------------------------------------------------===// ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser, OperationState &result) { SMLoc loc = parser.getCurrentLocation(); Location currLocation = parser.getEncodedSourceLoc(loc); // Parse the operands. SmallVector operands; if (parser.parseOperandList(operands)) return failure(); // Check if we are parsing the pretty-printed version // test.pretty_printed_region start end : // Else fallback to parsing the "non pretty-printed" version. if (!succeeded(parser.parseOptionalKeyword("start"))) return parser.parseGenericOperationAfterOpName(result, llvm::ArrayRef(operands)); FailureOr parseOpNameInfo = parser.parseCustomOperationName(); if (failed(parseOpNameInfo)) return failure(); StringAttr innerOpName = parseOpNameInfo->getIdentifier(); FunctionType opFntype; std::optional explicitLoc; if (parser.parseKeyword("end") || parser.parseColon() || parser.parseType(opFntype) || parser.parseOptionalLocationSpecifier(explicitLoc)) return failure(); // If location of the op is explicitly provided, then use it; Else use // the parser's current location. Location opLoc = explicitLoc.value_or(currLocation); // Derive the SSA-values for op's operands. if (parser.resolveOperands(operands, opFntype.getInputs(), loc, result.operands)) return failure(); // Add a region for op. Region ®ion = *result.addRegion(); // Create a basic-block inside op's region. Block &block = region.emplaceBlock(); // Create and insert an "inner-op" operation in the block. // Just for testing purposes, we can assume that inner op is a binary op with // result and operand types all same as the test-op's first operand. Type innerOpType = opFntype.getInput(0); Value lhs = block.addArgument(innerOpType, opLoc); Value rhs = block.addArgument(innerOpType, opLoc); OpBuilder builder(parser.getBuilder().getContext()); builder.setInsertionPointToStart(&block); Operation *innerOp = builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType); // Insert a return statement in the block returning the inner-op's result. builder.create(innerOp->getLoc(), innerOp->getResults()); // Populate the op operation-state with result-type and location. result.addTypes(opFntype.getResults()); result.location = innerOp->getLoc(); return success(); } void PrettyPrintedRegionOp::print(OpAsmPrinter &p) { p << ' '; p.printOperands(getOperands()); Operation &innerOp = getRegion().front().front(); // Assuming that region has a single non-terminator inner-op, if the inner-op // meets some criteria (which in this case is a simple one based on the name // of inner-op), then we can print the entire region in a succinct way. // Here we assume that the prototype of "test.special.op" can be trivially // derived while parsing it back. if (innerOp.getName().getStringRef().equals("test.special.op")) { p << " start test.special.op end"; } else { p << " ("; p.printRegion(getRegion()); p << ")"; } p << " : "; p.printFunctionalType(*this); } //===----------------------------------------------------------------------===// // Test PolyForOp - parse list of region arguments. //===----------------------------------------------------------------------===// ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector ivsInfo; // Parse list of region arguments without a delimiter. if (parser.parseArgumentList(ivsInfo, OpAsmParser::Delimiter::None)) return failure(); // Parse the body region. Region *body = result.addRegion(); for (auto &iv : ivsInfo) iv.type = parser.getBuilder().getIndexType(); return parser.parseRegion(*body, ivsInfo); } void PolyForOp::print(OpAsmPrinter &p) { p << " "; llvm::interleaveComma(getRegion().getArguments(), p, [&](auto arg) { p.printRegionArgument(arg, /*argAttrs =*/{}, /*omitType=*/true); }); p << " "; p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); } void PolyForOp::getAsmBlockArgumentNames(Region ®ion, OpAsmSetValueNameFn setNameFn) { auto arrayAttr = getOperation()->getAttrOfType("arg_names"); if (!arrayAttr) return; auto args = getRegion().front().getArguments(); auto e = std::min(arrayAttr.size(), args.size()); for (unsigned i = 0; i < e; ++i) { if (auto strAttr = dyn_cast(arrayAttr[i])) setNameFn(args[i], strAttr.getValue()); } } //===----------------------------------------------------------------------===// // TestAttrWithLoc - parse/printOptionalLocationSpecifier //===----------------------------------------------------------------------===// static ParseResult parseOptionalLoc(OpAsmParser &p, Attribute &loc) { std::optional result; SMLoc sourceLoc = p.getCurrentLocation(); if (p.parseOptionalLocationSpecifier(result)) return failure(); if (result) loc = *result; else loc = p.getEncodedSourceLoc(sourceLoc); return success(); } static void printOptionalLoc(OpAsmPrinter &p, Operation *op, Attribute loc) { p.printOptionalLocationSpecifier(cast(loc)); } #define GET_OP_CLASSES #include "TestOpsSyntax.cpp.inc" void TestDialect::registerOpsSyntax() { addOperations< #define GET_OP_LIST #include "TestOpsSyntax.cpp.inc" >(); }