//===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "TestDialect.h" #include "TestAttributes.h" #include "TestInterfaces.h" #include "TestTypes.h" #include "mlir/Bytecode/BytecodeImplementation.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/ODSSupport.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/FunctionImplementation.h" #include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Base64.h" #include "llvm/Support/Casting.h" #include #include #include // Include this before the using namespace lines below to // test that we don't have namespace dependencies. 
#include "TestOpsDialect.cpp.inc"

using namespace mlir;
using namespace test;

/// Converts the property into a StringAttr holding `content`.
Attribute MyPropStruct::asAttribute(MLIRContext *ctx) const {
  return StringAttr::get(ctx, content);
}

/// Populates `prop` from `attr`. `attr` must be a StringAttr; anything else is
/// reported through `emitError` and yields failure.
LogicalResult MyPropStruct::setFromAttr(MyPropStruct &prop, Attribute attr,
                                        function_ref emitError) {
  StringAttr strAttr = dyn_cast(attr);
  if (!strAttr) {
    emitError() << "Expect StringAttr but got " << attr;
    return failure();
  }
  prop.content = strAttr.getValue();
  return success();
}

/// Hashes the property by its string content.
llvm::hash_code MyPropStruct::hash() const {
  return hash_value(StringRef(content));
}

/// Bytecode reader for MyPropStruct: the content is stored as one string.
static LogicalResult readFromMlirBytecode(DialectBytecodeReader &reader,
                                          MyPropStruct &prop) {
  StringRef str;
  if (failed(reader.readString(str)))
    return failure();
  prop.content = str.str();
  return success();
}

/// Bytecode writer matching the reader above.
static void writeToMlirBytecode(::mlir::DialectBytecodeWriter &writer,
                                MyPropStruct &prop) {
  writer.writeOwnedString(prop.content);
}

/// Reads a var-int element count followed by that many var-int elements into
/// the fixed-size `prop`. Fails if the encoded count differs from
/// `prop.size()`.
static LogicalResult readFromMlirBytecode(DialectBytecodeReader &reader,
                                          MutableArrayRef prop) {
  uint64_t size;
  if (failed(reader.readVarInt(size)))
    return failure();
  if (size != prop.size())
    // NOTE(review): "mismach" is a typo in the diagnostic; left unchanged in
    // case tests match on this exact text.
    return reader.emitError("array size mismach when reading properties: ")
           << size << " vs expected " << prop.size();
  for (auto &elt : prop) {
    uint64_t value;
    if (failed(reader.readVarInt(value)))
      return failure();
    elt = value;
  }
  return success();
}

/// Writes the element count and then each element as a var-int.
static void writeToMlirBytecode(::mlir::DialectBytecodeWriter &writer,
                                ArrayRef prop) {
  writer.writeVarInt(prop.size());
  for (auto elt : prop)
    writer.writeVarInt(elt);
}

// Custom property hooks for PropertiesWithCustomPrint and VersionedProperties.
// These are `static`, so their definitions live elsewhere in this translation
// unit.
static LogicalResult setPropertiesFromAttribute(PropertiesWithCustomPrint &prop,
                                                Attribute attr,
                                                function_ref emitError);
static DictionaryAttr
getPropertiesAsAttribute(MLIRContext *ctx,
                         const PropertiesWithCustomPrint &prop);
static llvm::hash_code computeHash(const PropertiesWithCustomPrint &prop);
static void customPrintProperties(OpAsmPrinter &p,
                                  const PropertiesWithCustomPrint &prop);
static ParseResult customParseProperties(OpAsmParser &parser,
                                         PropertiesWithCustomPrint &prop);
static LogicalResult setPropertiesFromAttribute(VersionedProperties &prop,
                                                Attribute attr,
                                                function_ref emitError);
static DictionaryAttr getPropertiesAsAttribute(MLIRContext *ctx,
                                               const VersionedProperties &prop);
static llvm::hash_code computeHash(const VersionedProperties &prop);
static void customPrintProperties(OpAsmPrinter &p,
                                  const VersionedProperties &prop);
static ParseResult customParseProperties(OpAsmParser &parser,
                                         VersionedProperties &prop);
static ParseResult
parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases,
                 SmallVectorImpl> &caseRegions);
static void printSwitchCases(OpAsmPrinter &p, Operation *op,
                             DenseI64ArrayAttr cases, RegionRange caseRegions);

/// Registers the test dialect into `registry`.
void test::registerTestDialect(DialectRegistry &registry) {
  registry.insert();
}

//===----------------------------------------------------------------------===//
// Dynamic operations
//===----------------------------------------------------------------------===//

/// `dynamic_generic`: a dynamic op whose verifier and region verifier always
/// succeed.
std::unique_ptr getDynamicGenericOp(TestDialect *dialect) {
  return DynamicOpDefinition::get(
      "dynamic_generic", dialect, [](Operation *op) { return success(); },
      [](Operation *op) { return success(); });
}

/// `dynamic_one_operand_two_results`: the verifier requires exactly one
/// operand and exactly two results.
std::unique_ptr
getDynamicOneOperandTwoResultsOp(TestDialect *dialect) {
  return DynamicOpDefinition::get(
      "dynamic_one_operand_two_results", dialect,
      [](Operation *op) {
        if (op->getNumOperands() != 1) {
          op->emitOpError()
              << "expected 1 operand, but had " << op->getNumOperands();
          return failure();
        }
        if (op->getNumResults() != 2) {
          op->emitOpError()
              << "expected 2 results, but had " << op->getNumResults();
          return failure();
        }
        return success();
      },
      [](Operation *op) { return success(); });
}

/// `dynamic_custom_parser_printer`: takes no operands or results and uses the
/// custom syntax `custom_keyword`.
std::unique_ptr
getDynamicCustomParserPrinterOp(TestDialect *dialect) {
  auto verifier = [](Operation *op) {
    if (op->getNumOperands() == 0 && op->getNumResults() == 0)
      return success();
    op->emitError() << "operation should have no operands and no results";
    return failure();
  };
  auto regionVerifier = [](Operation *op) { return success(); };

  auto parser = [](OpAsmParser &parser, OperationState &state) {
    return parser.parseKeyword("custom_keyword");
  };

  auto printer = [](Operation *op, OpAsmPrinter &printer, llvm::StringRef) {
    printer << op->getName() << " custom_keyword";
  };

  return DynamicOpDefinition::get("dynamic_custom_parser_printer", dialect,
                                  verifier, regionVerifier, parser, printer);
}

//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//

static void testSideEffectOpGetEffect(
    Operation *op,
    SmallVectorImpl> &effects);

// This is the implementation of a dialect fallback for `TestEffectOpInterface`.
struct TestOpEffectInterfaceFallback
    : public TestEffectOpInterface::FallbackModel<
          TestOpEffectInterfaceFallback> {
  /// Only `test.unregistered_side_effect_op` is expected to reach this model.
  static bool classof(Operation *op) {
    bool isSupportedOp =
        op->getName().getStringRef() == "test.unregistered_side_effect_op";
    assert(isSupportedOp && "Unexpected dispatch");
    return isSupportedOp;
  }

  void getEffects(Operation *op,
                  SmallVectorImpl> &effects) const {
    testSideEffectOpGetEffect(op, effects);
  }
};

void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  registerOpsSyntax();
  addOperations();
  registerDynamicOp(getDynamicGenericOp(this));
  registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
  registerDynamicOp(getDynamicCustomParserPrinterOp(this));
  registerInterfaces();
  allowUnknownOperations();

  // Instantiate our fallback op interface that we'll use on specific
  // unregistered op.
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}

/// Releases the fallback interface model allocated in initialize().
TestDialect::~TestDialect() {
  delete static_cast(
      fallbackEffectOpInterfaces);
}

/// Materializes `value` as a new operation with result type `type` at `loc`.
Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  return builder.create(loc, type, value);
}

/// Serves the fallback effect-interface model for
/// `test.unregistered_side_effect_op` when the queried interface TypeID
/// matches; returns nullptr for every other op/interface pair.
void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
                                               OperationName opName) {
  if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
      typeID == TypeID::get())
    return fallbackEffectOpInterfaces;
  return nullptr;
}

/// Rejects any discardable attribute named `test.invalid_attr`.
LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute namedAttr) {
  if (namedAttr.getName() == "test.invalid_attr")
    return op->emitError() << "invalid to use 'test.invalid_attr'";
  return success();
}

/// Rejects any region-argument attribute named `test.invalid_attr`.
LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute namedAttr) {
  if (namedAttr.getName() == "test.invalid_attr")
    return op->emitError() << "invalid to use 'test.invalid_attr'";
  return success();
}

/// Rejects any region-result attribute named `test.invalid_attr`.
LogicalResult
TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
                                         unsigned resultIndex,
                                         NamedAttribute namedAttr) {
  if (namedAttr.getName() == "test.invalid_attr")
    return op->emitError() << "invalid to use 'test.invalid_attr'";
  return success();
}

/// Dialect-level parse hooks keyed by op name; std::nullopt means the dialect
/// provides no hook for `opName`.
std::optional
TestDialect::getParseOperationHook(StringRef opName) const {
  if (opName == "test.dialect_custom_printer") {
    return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
      return parser.parseKeyword("custom_format");
    }};
  }
  if (opName == "test.dialect_custom_format_fallback") {
    return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
      return parser.parseKeyword("custom_format_fallback");
    }};
  }
  if (opName == "test.dialect_custom_printer.with.dot") {
    return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
      return ParseResult::success();
    }};
  }
  return std::nullopt;
}

/// Printing counterpart for two of the names handled by
/// getParseOperationHook. `test.dialect_custom_printer.with.dot` has a parse
/// hook but no printer here, so an empty function is returned for it.
llvm::unique_function
TestDialect::getOperationPrinter(Operation *op) const {
  StringRef opName = op->getName().getStringRef();
  if (opName == "test.dialect_custom_printer") {
    return [](Operation *op, OpAsmPrinter &printer) {
      printer.getStream() << " custom_format";
    };
  }
  if (opName == "test.dialect_custom_format_fallback") {
    return [](Operation *op, OpAsmPrinter &printer) {
      printer.getStream() << " custom_format_fallback";
    };
  }
  return {};
}

//===----------------------------------------------------------------------===//
// TypedAttrOp
//===----------------------------------------------------------------------===//

/// Parse an attribute with a given type.
static ParseResult parseAttrElideType(AsmParser &parser, TypeAttr type,
                                      Attribute &attr) {
  return parser.parseAttribute(attr, type.getValue());
}

/// Print an attribute without its type.
static void printAttrElideType(AsmPrinter &printer, Operation *op,
                               TypeAttr type, Attribute attr) {
  printer.printAttributeWithoutType(attr);
}

//===----------------------------------------------------------------------===//
// TestBranchOp
//===----------------------------------------------------------------------===//

/// The single successor receives the target operands.
SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
  assert(index == 0 && "invalid successor index");
  return SuccessorOperands(getTargetOperandsMutable());
}

//===----------------------------------------------------------------------===//
// TestProducingBranchOp
//===----------------------------------------------------------------------===//

/// Successor 1 receives the "first" operand group and successor 0 the
/// "second".
/// NOTE(review): this index-to-group mapping looks inverted; confirm it is
/// intentional before relying on it.
SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
  assert(index <= 1 && "invalid successor index");
  if (index == 1)
    return SuccessorOperands(getFirstOperandsMutable());
  return SuccessorOperands(getSecondOperandsMutable());
}

//===----------------------------------------------------------------------===//
// TestInternalBranchOp
//===----------------------------------------------------------------------===//

/// The leading integer passed to SuccessorOperands is the number of operands
/// produced by the op itself rather than forwarded. The success successor
/// gets none; the error successor gets one ahead of the error operands.
SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
  assert(index <= 1 && "invalid successor index");
  if (index == 0)
    return SuccessorOperands(0, getSuccessOperandsMutable());
  return SuccessorOperands(1, getErrorOperandsMutable());
}

//===----------------------------------------------------------------------===//
// TestDialectCanonicalizerOp
//===----------------------------------------------------------------------===//

/// Unconditionally replaces the op with a new op built from the i32 constant
/// attribute 42.
static LogicalResult
dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
                               PatternRewriter &rewriter) {
  rewriter.replaceOpWithNewOp(
      op, rewriter.getI32IntegerAttr(42));
  return success();
}

/// Dialect-wide canonicalization patterns.
void TestDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add(&dialectCanonicalizationPattern);
}

//===----------------------------------------------------------------------===//
// TestCallOp
//===----------------------------------------------------------------------===//

LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Check that the callee attribute was specified.
  auto fnAttr = (*this)->getAttrOfType("callee");
  if (!fnAttr)
    return emitOpError("requires a 'callee' symbol reference attribute");
  // The callee must resolve from the nearest symbol table.
  if (!symbolTable.lookupNearestSymbolFrom(*this, fnAttr))
    return emitOpError() << "'" << fnAttr.getValue()
                         << "' does not reference a valid function";
  return success();
}

//===----------------------------------------------------------------------===//
// ConversionFuncOp
//===----------------------------------------------------------------------===//

/// Parses with the shared function-op syntax; variadic signatures are not
/// allowed.
ParseResult ConversionFuncOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  auto buildFuncType =
      [](Builder &builder, ArrayRef argTypes, ArrayRef results,
         function_interface_impl::VariadicFlag,
         std::string &) { return builder.getFunctionType(argTypes, results); };

  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false,
      getFunctionTypeAttrName(result.name), buildFuncType,
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}

void ConversionFuncOp::print(OpAsmPrinter &p) {
  function_interface_impl::printFunctionOp(
      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
      getArgAttrsAttrName(), getResAttrsAttrName());
}

//===----------------------------------------------------------------------===//
// TestFoldToCallOp
//===----------------------------------------------------------------------===//

namespace {
/// Replaces FoldToCallOp with an op built from its callee attribute, with no
/// result types and no operands.
struct FoldToCallOpPattern : public OpRewritePattern {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(FoldToCallOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp(op, TypeRange(), op.getCalleeAttr(),
                                ValueRange());
    return success();
  }
};
} // namespace

void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add(context);
}

//===----------------------------------------------------------------------===//
// Test IsolatedRegionOp - parse passthrough region arguments.
//===----------------------------------------------------------------------===//

ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  // Parse the input operand.
  OpAsmParser::Argument argInfo;
  argInfo.type = parser.getBuilder().getIndexType();
  if (parser.parseOperand(argInfo.ssaName) ||
      parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands))
    return failure();

  // Parse the body region, and reuse the operand info as the argument info.
  // Shadowing is enabled so the block argument may reuse the operand's name.
  Region *body = result.addRegion();
  return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true);
}

void IsolatedRegionOp::print(OpAsmPrinter &p) {
  p << ' ';
  p.printOperand(getOperand());
  // Print the region's entry argument under the operand's name, mirroring
  // the shadowing used by the parser.
  p.shadowRegionArgs(getRegion(), getOperand());
  p << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
}

//===----------------------------------------------------------------------===//
// Test SSACFGRegionOp
//===----------------------------------------------------------------------===//

/// Every region of this op is an SSACFG region.
RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
  return RegionKind::SSACFG;
}

//===----------------------------------------------------------------------===//
// Test GraphRegionOp
//===----------------------------------------------------------------------===//

/// Every region of this op is a graph region.
RegionKind GraphRegionOp::getRegionKind(unsigned index) {
  return RegionKind::Graph;
}

//===----------------------------------------------------------------------===//
// Test AffineScopeOp
//===----------------------------------------------------------------------===//

ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse the body region; it declares no explicit entry arguments and does
  // not enable name shadowing.
  Region *body = result.addRegion();
  return parser.parseRegion(*body, /*arguments=*/{},
                            /*enableNameShadowing=*/{});
}

void AffineScopeOp::print(OpAsmPrinter &p) {
  p << " ";
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
}

//===----------------------------------------------------------------------===//
// Test removing op with inner ops.
//===----------------------------------------------------------------------===//

namespace {
/// Erases every TestOpWithRegionPattern op, together with the ops nested in
/// its regions, without any match condition.
struct TestRemoveOpWithInnerOps
    : public OpRewritePattern {
  using OpRewritePattern::OpRewritePattern;

  void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }

  LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
                                PatternRewriter &rewriter) const override {
    rewriter.eraseOp(op);
    return success();
  }
};
} // namespace

void TestOpWithRegionPattern::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add(context);
}

/// Always folds to the op's operand.
OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) {
  return getOperand();
}

/// Folds to the op's `value` attribute.
OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); }

/// Folds each result to the operand at the same position. Pairing results
/// with operands one-to-one is the intent; the operand/result count match is
/// not checked here.
LogicalResult TestOpWithVariadicResultsAndFolder::fold(
    FoldAdaptor adaptor, SmallVectorImpl &results) {
  for (Value input : this->getOperands()) {
    results.push_back(input);
  }
  return success();
}

/// In-place folder. When the `op` operand is a constant and the `attr`
/// property is still unset, it stores the constant into `attr`. It then
/// returns the op's own result, signalling an in-place update rather than a
/// replacement.
OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
  // Exercise the fact that an operation created with createOrFold should be
  // allowed to access its parent block.
  assert(getOperation()->getBlock() &&
         "expected that operation is not unlinked");

  if (adaptor.getOp() && !getProperties().attr) {
    // The folder adds "attr" if not present.
    getProperties().attr = dyn_cast_or_null(adaptor.getOp());
    return getResult();
  }
  return {};
}

/// Folds to an IntegerAttr of the result type holding a weighted sum:
///   1 x the constant `op` operand,
///   2 x each constant in the `variadic` group,
///   3 x each constant in every `var_of_var` sub-group,
///   4 x the number of blocks in the body region.
/// Operands that are not constant attributes of the expected kind
/// contribute nothing.
OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) {
  int64_t sum = 0;
  if (auto value = dyn_cast_or_null(adaptor.getOp()))
    sum += value.getValue().getSExtValue();

  for (Attribute attr : adaptor.getVariadic())
    if (auto value = dyn_cast_or_null(attr))
      sum += 2 * value.getValue().getSExtValue();

  for (ArrayRef attrs : adaptor.getVarOfVar())
    for (Attribute attr : attrs)
      if (auto value = dyn_cast_or_null(attr))
        sum += 3 * value.getValue().getSExtValue();

  // Iterating a Region visits its blocks, so this counts body blocks.
  sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end());

  return IntegerAttr::get(getType(), sum);
}

/// The result type equals the type of the first operand. The first two
/// operands must have identical types.
/// NOTE(review): operands[0] and operands[1] are indexed without a size
/// check; this assumes at least two operands are always supplied.
LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
    MLIRContext *, std::optional location, ValueRange operands,
    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl &inferredReturnTypes) {
  if (operands[0].getType() != operands[1].getType()) {
    return emitOptionalError(location, "operand type mismatch ",
                             operands[0].getType(), " vs ",
                             operands[1].getType());
  }
  inferredReturnTypes.assign({operands[0].getType()});
  return success();
}

/// Adaptor-based variant of the rule above: the types of `x` and `y` must
/// match, and the result takes that type.
LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes(
    MLIRContext *, std::optional location,
    OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor,
    SmallVectorImpl &inferredReturnTypes) {
  if (adaptor.getX().getType() != adaptor.getY().getType()) {
    return emitOptionalError(location, "operand type mismatch ",
                             adaptor.getX().getType(), " vs ",
                             adaptor.getY().getType());
  }
  inferredReturnTypes.assign({adaptor.getX().getType()});
  return success();
}

// TODO: We should be able to only define either inferReturnType or
// refineReturnType, currently only refineReturnType can be omitted.
/// Inference is expressed as refinement that starts from an empty result-type
/// list.
LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
    MLIRContext *context, std::optional location, ValueRange operands,
    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl &returnTypes) {
  returnTypes.clear();
  return OpWithRefineTypeInterfaceOp::refineReturnTypes(
      context, location, operands, attributes, properties, regions,
      returnTypes);
}

/// Refines `returnTypes` to a single type equal to the first operand's type.
/// Fails if the first two operands differ in type, or if a non-null
/// pre-existing first return type conflicts with the first operand's type.
LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
    MLIRContext *, std::optional location, ValueRange operands,
    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl &returnTypes) {
  if (operands[0].getType() != operands[1].getType()) {
    return emitOptionalError(location, "operand type mismatch ",
                             operands[0].getType(), " vs ",
                             operands[1].getType());
  }
  // TODO: Add helper to make this more concise to write.
  if (returnTypes.empty())
    returnTypes.resize(1, nullptr);
  if (returnTypes[0] && returnTypes[0] != operands[0].getType())
    return emitOptionalError(location,
                             "required first operand and result to match");
  returnTypes[0] = operands[0].getType();
  return success();
}

/// Infers a 1-D result shape with an i17 element type.
LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
    MLIRContext *context, std::optional location, ValueShapeRange operands,
    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl &inferredReturnShapes) {
  // The single extent is the first operand's leading (front) dimension, or
  // dynamic if that operand is unranked. Any encoding of a ranked operand is
  // carried over.
  // NOTE(review): for a rank-0 operand getShape() is empty and front() is out
  // of bounds; confirm such operands are excluded before this runs.
  auto operandType = operands.front().getType();
  auto sval = dyn_cast(operandType);
  if (!sval)
    return emitOptionalError(location, "only shaped type operands allowed");
  int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
  auto type = IntegerType::get(context, 17);

  Attribute encoding;
  if (auto rankedTy = dyn_cast(sval))
    encoding = rankedTy.getEncoding();
  inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
  return success();
}

/// Reifies the result shape as a single value created from the first operand
/// and index 0, presumably its leading dimension (the op kind is not visible
/// here).
LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
    OpBuilder &builder, ValueRange operands,
    llvm::SmallVectorImpl &shapes) {
  shapes = SmallVector{
      builder.createOrFold(getLoc(), operands.front(), 0)};
  return success();
}

/// Adaptor-based twin of the op above, keyed on `operand1`: a 1-D i17 result
/// whose extent is operand1's leading dimension (dynamic when unranked).
LogicalResult
OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents(
    MLIRContext *context, std::optional location,
    OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor,
    SmallVectorImpl &inferredReturnShapes) {
  // The single extent is operand1's leading (front) dimension, or dynamic if
  // it is unranked. Any encoding of a ranked operand1 is carried over.
  // NOTE(review): same rank-0 front() concern as the non-adaptor variant.
  auto operandType = adaptor.getOperand1().getType();
  auto sval = dyn_cast(operandType);
  if (!sval)
    return emitOptionalError(location, "only shaped type operands allowed");
  int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic;
  auto type = IntegerType::get(context, 17);

  Attribute encoding;
  if (auto rankedTy = dyn_cast(sval))
    encoding = rankedTy.getEncoding();
  inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding));
  return success();
}

/// Same reification as the non-adaptor variant: one value built from the
/// first operand and index 0.
LogicalResult
OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes(
    OpBuilder &builder, ValueRange operands,
    llvm::SmallVectorImpl &shapes) {
  shapes = SmallVector{
      builder.createOrFold(getLoc(), operands.front(), 0)};
  return success();
}

/// For each operand, builds a rank-1 index tensor value holding that
/// operand's per-dimension sizes.
/// NOTE(review): operands are walked with llvm::reverse, so shapes are
/// appended in reverse operand order; confirm callers expect that.
LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
    OpBuilder &builder, ValueRange operands,
    llvm::SmallVectorImpl &shapes) {
  Location loc = getLoc();
  shapes.reserve(operands.size());
  for (Value operand : llvm::reverse(operands)) {
    auto rank = cast(operand.getType()).getRank();
    auto currShape = llvm::to_vector<4>(
        llvm::map_range(llvm::seq(0, rank), [&](int64_t dim) -> Value {
          return builder.createOrFold(loc, operand, dim);
        }));
    shapes.push_back(builder.create(
        getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
        currShape));
  }
  return success();
}

/// Per-dimension reification. Dynamic dimensions become SSA values created
/// from the operand and dimension index; static dimensions become index
/// attributes. Operands are walked in reverse order.
LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
  Location loc = getLoc();
  shapes.reserve(getNumOperands());
  for (Value operand : llvm::reverse(getOperands())) {
    auto tensorType = cast(operand.getType());
    auto currShape = llvm::to_vector<4>(llvm::map_range(
        llvm::seq(0, tensorType.getRank()),
        [&](int64_t dim) -> OpFoldResult {
          return tensorType.isDynamicDim(dim)
                     ? static_cast(
                           builder.createOrFold(loc, operand, dim))
                     : static_cast(
                           builder.getIndexAttr(tensorType.getDimSize(dim)));
        }));
    shapes.emplace_back(std::move(currShape));
  }
  return success();
}

//===----------------------------------------------------------------------===//
// Test SideEffect interfaces
//===----------------------------------------------------------------------===//

namespace {
/// A test resource for side effects.
struct TestResource : public SideEffects::Resource::Base {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)

  // NOTE(review): the name is the empty string here; confirm this is
  // intended, since it makes the resource anonymous in printed output.
  StringRef getName() final { return ""; }
};
} // namespace

/// Reports one concrete test effect parameterized by the op's
/// `effect_parameter` attribute. Reports nothing if that attribute is absent.
static void testSideEffectOpGetEffect(
    Operation *op,
    SmallVectorImpl> &effects) {
  auto effectsAttr = op->getAttrOfType("effect_parameter");
  if (!effectsAttr)
    return;

  effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
}

/// Builds memory effects from the `effects` array attribute. Each element is
/// a dictionary with:
///   - "effect": one of "allocate", "free", "read", "write";
///   - "test_resource" (optional): use TestResource instead of the default;
///   - "on_result" (optional): the effect applies to the op's result;
///   - "on_reference" (optional): the effect applies to that attribute.
void SideEffectOp::getEffects(
    SmallVectorImpl &effects) {
  // Check for an effects attribute on the op instance.
  ArrayAttr effectsAttr = (*this)->getAttrOfType("effects");
  if (!effectsAttr)
    return;

  // If there is one, it is an array of dictionary attributes that hold
  // information on the effects of this operation.
  for (Attribute element : effectsAttr) {
    DictionaryAttr effectElement = cast(element);

    // Get the specific memory effect.
    // NOTE(review): the StringSwitch has no Default, and "effect" is cast
    // unconditionally. A missing key or an unrecognized string therefore
    // trips an assertion in debug builds; confirm inputs are always
    // well-formed.
    MemoryEffects::Effect *effect =
        StringSwitch(
            cast(effectElement.get("effect")).getValue())
            .Case("allocate", MemoryEffects::Allocate::get())
            .Case("free", MemoryEffects::Free::get())
            .Case("read", MemoryEffects::Read::get())
            .Case("write", MemoryEffects::Write::get());

    // Check for a non-default resource to use.
    SideEffects::Resource *resource = SideEffects::DefaultResource::get();
    if (effectElement.get("test_resource"))
      resource = TestResource::get();

    // Check for a result to affect.
    if (effectElement.get("on_result"))
      effects.emplace_back(effect, getResult(), resource);
    else if (Attribute ref = effectElement.get("on_reference"))
      effects.emplace_back(effect, cast(ref), resource);
    else
      effects.emplace_back(effect, resource);
  }
}

/// Test-effect interface hook; delegates to the shared helper above.
void SideEffectOp::getEffects(
    SmallVectorImpl &effects) {
  testSideEffectOpGetEffect(getOperation(), effects);
}

//===----------------------------------------------------------------------===//
// StringAttrPrettyNameOp
//===----------------------------------------------------------------------===//

// This op has fancy handling of its SSA result name.
ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser, OperationState &result) { // Add the result types. for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) result.addTypes(parser.getBuilder().getIntegerType(32)); if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) return failure(); // If the attribute dictionary contains no 'names' attribute, infer it from // the SSA name (if specified). bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { return attr.getName() == "names"; }); // If there was no name specified, check to see if there was a useful name // specified in the asm file. if (hadNames || parser.getNumResults() == 0) return success(); SmallVector names; auto *context = result.getContext(); for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { auto resultName = parser.getResultName(i); StringRef nameStr; if (!resultName.first.empty() && !isdigit(resultName.first[0])) nameStr = resultName.first; names.push_back(nameStr); } auto namesAttr = parser.getBuilder().getStrArrayAttr(names); result.attributes.push_back({StringAttr::get(context, "names"), namesAttr}); return success(); } void StringAttrPrettyNameOp::print(OpAsmPrinter &p) { // Note that we only need to print the "name" attribute if the asmprinter // result name disagrees with it. This can happen in strange cases, e.g. // when there are conflicts. 
bool namesDisagree = getNames().size() != getNumResults(); SmallString<32> resultNameStr; for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) { resultNameStr.clear(); llvm::raw_svector_ostream tmpStream(resultNameStr); p.printOperand(getResult(i), tmpStream); auto expectedName = dyn_cast(getNames()[i]); if (!expectedName || tmpStream.str().drop_front() != expectedName.getValue()) { namesDisagree = true; } } if (namesDisagree) p.printOptionalAttrDictWithKeyword((*this)->getAttrs()); else p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"}); } // We set the SSA name in the asm syntax to the contents of the name // attribute. void StringAttrPrettyNameOp::getAsmResultNames( function_ref setNameFn) { auto value = getNames(); for (size_t i = 0, e = value.size(); i != e; ++i) if (auto str = dyn_cast(value[i])) if (!str.getValue().empty()) setNameFn(getResult(i), str.getValue()); } void CustomResultsNameOp::getAsmResultNames( function_ref setNameFn) { ArrayAttr value = getNames(); for (size_t i = 0, e = value.size(); i != e; ++i) if (auto str = dyn_cast(value[i])) if (!str.empty()) setNameFn(getResult(i), str.getValue()); } //===----------------------------------------------------------------------===// // ResultTypeWithTraitOp //===----------------------------------------------------------------------===// LogicalResult ResultTypeWithTraitOp::verify() { if ((*this)->getResultTypes()[0].hasTrait()) return success(); return emitError("result type should have trait 'TestTypeTrait'"); } //===----------------------------------------------------------------------===// // AttrWithTraitOp //===----------------------------------------------------------------------===// LogicalResult AttrWithTraitOp::verify() { if (getAttr().hasTrait()) return success(); return emitError("'attr' attribute should have trait 'TestAttrTrait'"); } //===----------------------------------------------------------------------===// // RegionIfOp 
//===----------------------------------------------------------------------===// void RegionIfOp::print(OpAsmPrinter &p) { p << " "; p.printOperands(getOperands()); p << ": " << getOperandTypes(); p.printArrowTypeList(getResultTypes()); p << " then "; p.printRegion(getThenRegion(), /*printEntryBlockArgs=*/true, /*printBlockTerminators=*/true); p << " else "; p.printRegion(getElseRegion(), /*printEntryBlockArgs=*/true, /*printBlockTerminators=*/true); p << " join "; p.printRegion(getJoinRegion(), /*printEntryBlockArgs=*/true, /*printBlockTerminators=*/true); } ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector operandInfos; SmallVector operandTypes; result.regions.reserve(3); Region *thenRegion = result.addRegion(); Region *elseRegion = result.addRegion(); Region *joinRegion = result.addRegion(); // Parse operand, type and arrow type lists. if (parser.parseOperandList(operandInfos) || parser.parseColonTypeList(operandTypes) || parser.parseArrowTypeList(result.types)) return failure(); // Parse all attached regions. if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) return failure(); return parser.resolveOperands(operandInfos, operandTypes, parser.getCurrentLocation(), result.operands); } OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) { assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) && "invalid region index"); return getOperands(); } void RegionIfOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { // We always branch to the join region. if (!point.isParent()) { if (point != getJoinRegion()) regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); else regions.push_back(RegionSuccessor(getResults())); return; } // The then and else regions are the entry regions of this op. 
regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs())); regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); } void RegionIfOp::getRegionInvocationBounds( ArrayRef operands, SmallVectorImpl &invocationBounds) { // Each region is invoked at most once. invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1}); } //===----------------------------------------------------------------------===// // AnyCondOp //===----------------------------------------------------------------------===// void AnyCondOp::getSuccessorRegions(RegionBranchPoint point, SmallVectorImpl ®ions) { // The parent op branches into the only region, and the region branches back // to the parent op. if (point.isParent()) regions.emplace_back(&getRegion()); else regions.emplace_back(getResults()); } void AnyCondOp::getRegionInvocationBounds( ArrayRef operands, SmallVectorImpl &invocationBounds) { invocationBounds.emplace_back(1, 1); } //===----------------------------------------------------------------------===// // LoopBlockOp //===----------------------------------------------------------------------===// void LoopBlockOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { regions.emplace_back(&getBody(), getBody().getArguments()); if (point.isParent()) return; regions.emplace_back((*this)->getResults()); } OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) { assert(point == getBody()); return MutableOperandRange(getInitMutable()); } //===----------------------------------------------------------------------===// // LoopBlockTerminatorOp //===----------------------------------------------------------------------===// MutableOperandRange LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) { if (point.isParent()) return getExitArgMutable(); return getNextIterArgMutable(); } //===----------------------------------------------------------------------===// // SwitchWithNoBreakOp 
//===----------------------------------------------------------------------===// void TestNoTerminatorOp::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) {} //===----------------------------------------------------------------------===// // SingleNoTerminatorCustomAsmOp //===----------------------------------------------------------------------===// ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser, OperationState &state) { Region *body = state.addRegion(); if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) return failure(); return success(); } void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) { printer.printRegion( getRegion(), /*printEntryBlockArgs=*/false, // This op has a single block without terminators. But explicitly mark // as not printing block terminators for testing. /*printBlockTerminators=*/false); } //===----------------------------------------------------------------------===// // TestVerifiersOp //===----------------------------------------------------------------------===// LogicalResult TestVerifiersOp::verify() { if (!getRegion().hasOneBlock()) return emitOpError("`hasOneBlock` trait hasn't been verified"); Operation *definingOp = getInput().getDefiningOp(); if (definingOp && failed(mlir::verify(definingOp))) return emitOpError("operand hasn't been verified"); // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier // loop. mlir::emitRemark(getLoc(), "success run of verifier"); return success(); } LogicalResult TestVerifiersOp::verifyRegions() { if (!getRegion().hasOneBlock()) return emitOpError("`hasOneBlock` trait hasn't been verified"); for (Block &block : getRegion()) for (Operation &op : block) if (failed(mlir::verify(&op))) return emitOpError("nested op hasn't been verified"); // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier // loop. 
mlir::emitRemark(getLoc(), "success run of region verifier"); return success(); } //===----------------------------------------------------------------------===// // Test InferIntRangeInterface //===----------------------------------------------------------------------===// void TestWithBoundsOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRanges) { setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()}); } ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser, OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); // Parse the input argument OpAsmParser::Argument argInfo; argInfo.type = parser.getBuilder().getIndexType(); if (failed(parser.parseArgument(argInfo))) return failure(); // Parse the body region, and reuse the operand info as the argument info. Region *body = result.addRegion(); return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false); } void TestWithBoundsRegionOp::print(OpAsmPrinter &p) { p.printOptionalAttrDict((*this)->getAttrs()); p << ' '; p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{}, /*omitType=*/true); p << ' '; p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); } void TestWithBoundsRegionOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRanges) { Value arg = getRegion().getArgument(0); setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()}); } void TestIncrementOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRanges) { const ConstantIntRanges &range = argRanges[0]; APInt one(range.umin().getBitWidth(), 1); setResultRanges(getResult(), {range.umin().uadd_sat(one), range.umax().uadd_sat(one), range.smin().sadd_sat(one), range.smax().sadd_sat(one)}); } void TestReflectBoundsOp::inferResultRanges( ArrayRef argRanges, SetIntRangeFn setResultRanges) { const ConstantIntRanges &range = argRanges[0]; MLIRContext *ctx = getContext(); Builder b(ctx); 
setUminAttr(b.getIndexAttr(range.umin().getZExtValue())); setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue())); setSminAttr(b.getIndexAttr(range.smin().getSExtValue())); setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue())); setResultRanges(getResult(), range); } OpFoldResult ManualCppOpWithFold::fold(ArrayRef attributes) { // Just a simple fold for testing purposes that reads an operands constant // value and returns it. if (!attributes.empty()) return attributes.front(); return nullptr; } static LogicalResult setPropertiesFromAttribute(PropertiesWithCustomPrint &prop, Attribute attr, function_ref emitError) { DictionaryAttr dict = dyn_cast(attr); if (!dict) { emitError() << "expected DictionaryAttr to set TestProperties"; return failure(); } auto label = dict.getAs("label"); if (!label) { emitError() << "expected StringAttr for key `label`"; return failure(); } auto valueAttr = dict.getAs("value"); if (!valueAttr) { emitError() << "expected IntegerAttr for key `value`"; return failure(); } prop.label = std::make_shared(label.getValue()); prop.value = valueAttr.getValue().getSExtValue(); return success(); } static DictionaryAttr getPropertiesAsAttribute(MLIRContext *ctx, const PropertiesWithCustomPrint &prop) { SmallVector attrs; Builder b{ctx}; attrs.push_back(b.getNamedAttr("label", b.getStringAttr(*prop.label))); attrs.push_back(b.getNamedAttr("value", b.getI32IntegerAttr(prop.value))); return b.getDictionaryAttr(attrs); } static llvm::hash_code computeHash(const PropertiesWithCustomPrint &prop) { return llvm::hash_combine(prop.value, StringRef(*prop.label)); } static void customPrintProperties(OpAsmPrinter &p, const PropertiesWithCustomPrint &prop) { p.printKeywordOrString(*prop.label); p << " is " << prop.value; } static ParseResult customParseProperties(OpAsmParser &parser, PropertiesWithCustomPrint &prop) { std::string label; if (parser.parseKeywordOrString(&label) || parser.parseKeyword("is") || parser.parseInteger(prop.value)) return failure(); 
prop.label = std::make_shared(std::move(label)); return success(); } static ParseResult parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, SmallVectorImpl> &caseRegions) { SmallVector caseValues; while (succeeded(p.parseOptionalKeyword("case"))) { int64_t value; Region ®ion = *caseRegions.emplace_back(std::make_unique()); if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{})) return failure(); caseValues.push_back(value); } cases = p.getBuilder().getDenseI64ArrayAttr(caseValues); return success(); } static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions) { for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) { p.printNewline(); p << "case " << value << ' '; p.printRegion(*region, /*printEntryBlockArgs=*/false); } } static LogicalResult setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr, function_ref emitError) { DictionaryAttr dict = dyn_cast(attr); if (!dict) { emitError() << "expected DictionaryAttr to set VersionedProperties"; return failure(); } auto value1Attr = dict.getAs("value1"); if (!value1Attr) { emitError() << "expected IntegerAttr for key `value1`"; return failure(); } auto value2Attr = dict.getAs("value2"); if (!value2Attr) { emitError() << "expected IntegerAttr for key `value2`"; return failure(); } prop.value1 = value1Attr.getValue().getSExtValue(); prop.value2 = value2Attr.getValue().getSExtValue(); return success(); } static DictionaryAttr getPropertiesAsAttribute(MLIRContext *ctx, const VersionedProperties &prop) { SmallVector attrs; Builder b{ctx}; attrs.push_back(b.getNamedAttr("value1", b.getI32IntegerAttr(prop.value1))); attrs.push_back(b.getNamedAttr("value2", b.getI32IntegerAttr(prop.value2))); return b.getDictionaryAttr(attrs); } static llvm::hash_code computeHash(const VersionedProperties &prop) { return llvm::hash_combine(prop.value1, prop.value2); } static void customPrintProperties(OpAsmPrinter &p, const 
VersionedProperties &prop) { p << prop.value1 << " | " << prop.value2; } static ParseResult customParseProperties(OpAsmParser &parser, VersionedProperties &prop) { if (parser.parseInteger(prop.value1) || parser.parseVerticalBar() || parser.parseInteger(prop.value2)) return failure(); return success(); } static bool parseUsingPropertyInCustom(OpAsmParser &parser, int64_t value[3]) { return parser.parseLSquare() || parser.parseInteger(value[0]) || parser.parseComma() || parser.parseInteger(value[1]) || parser.parseComma() || parser.parseInteger(value[2]) || parser.parseRSquare(); } static void printUsingPropertyInCustom(OpAsmPrinter &printer, Operation *op, ArrayRef value) { printer << '[' << value << ']'; } static bool parseIntProperty(OpAsmParser &parser, int64_t &value) { return failed(parser.parseInteger(value)); } static void printIntProperty(OpAsmPrinter &printer, Operation *op, int64_t value) { printer << value; } static bool parseSumProperty(OpAsmParser &parser, int64_t &second, int64_t first) { int64_t sum; auto loc = parser.getCurrentLocation(); if (parser.parseInteger(second) || parser.parseEqual() || parser.parseInteger(sum)) return true; if (sum != second + first) { parser.emitError(loc, "Expected sum to equal first + second"); return true; } return false; } static void printSumProperty(OpAsmPrinter &printer, Operation *op, int64_t second, int64_t first) { printer << second << " = " << (second + first); } //===----------------------------------------------------------------------===// // Tensor/Buffer Ops //===----------------------------------------------------------------------===// void ReadBufferOp::getEffects( SmallVectorImpl> &effects) { // The buffer operand is read. effects.emplace_back(MemoryEffects::Read::get(), getBuffer(), SideEffects::DefaultResource::get()); // The buffer contents are dumped. 
effects.emplace_back(MemoryEffects::Write::get(), SideEffects::DefaultResource::get()); } //===----------------------------------------------------------------------===// // Test Dataflow //===----------------------------------------------------------------------===// CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() { return getCallee(); } void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) { setCalleeAttr(callee.get()); } Operation::operand_range TestCallAndStoreOp::getArgOperands() { return getCalleeOperands(); } MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() { return getCalleeOperandsMutable(); } CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() { return getCallee(); } void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) { setCalleeAttr(callee.get()); } Operation::operand_range TestCallOnDeviceOp::getArgOperands() { return getForwardedOperands(); } MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() { return getForwardedOperandsMutable(); } void TestStoreWithARegion::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { if (point.isParent()) regions.emplace_back(&getBody(), getBody().front().getArguments()); else regions.emplace_back(); } void TestStoreWithALoopRegion::getSuccessorRegions( RegionBranchPoint point, SmallVectorImpl ®ions) { // Both the operation itself and the region may be branching into the body or // back into the operation itself. It is possible for the operation not to // enter the body. regions.emplace_back( RegionSuccessor(&getBody(), getBody().front().getArguments())); regions.emplace_back(); } LogicalResult TestVersionedOpA::readProperties(::mlir::DialectBytecodeReader &reader, ::mlir::OperationState &state) { auto &prop = state.getOrAddProperties(); if (::mlir::failed(reader.readAttribute(prop.dims))) return ::mlir::failure(); // Check if we have a version. If not, assume we are parsing the current // version. 
auto maybeVersion = reader.getDialectVersion(); if (succeeded(maybeVersion)) { // If version is less than 2.0, there is no additional attribute to parse. // We can materialize missing properties post parsing before verification. const auto *version = reinterpret_cast(*maybeVersion); if ((version->major_ < 2)) { return success(); } } if (::mlir::failed(reader.readAttribute(prop.modifier))) return ::mlir::failure(); return ::mlir::success(); } void TestVersionedOpA::writeProperties(::mlir::DialectBytecodeWriter &writer) { auto &prop = getProperties(); writer.writeAttribute(prop.dims); auto maybeVersion = writer.getDialectVersion(); if (succeeded(maybeVersion)) { // If version is less than 2.0, there is no additional attribute to write. const auto *version = reinterpret_cast(*maybeVersion); if ((version->major_ < 2)) { llvm::outs() << "downgrading op properties...\n"; return; } } writer.writeAttribute(prop.modifier); } ::mlir::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode( ::mlir::DialectBytecodeReader &reader, test::VersionedProperties &prop) { uint64_t value1, value2 = 0; if (failed(reader.readVarInt(value1))) return failure(); // Check if we have a version. If not, assume we are parsing the current // version. auto maybeVersion = reader.getDialectVersion(); bool needToParseAnotherInt = true; if (succeeded(maybeVersion)) { // If version is less than 2.0, there is no additional attribute to parse. // We can materialize missing properties post parsing before verification. 
const auto *version = reinterpret_cast(*maybeVersion); if ((version->major_ < 2)) needToParseAnotherInt = false; } if (needToParseAnotherInt && failed(reader.readVarInt(value2))) return failure(); prop.value1 = value1; prop.value2 = value2; return success(); } void TestOpWithVersionedProperties::writeToMlirBytecode( ::mlir::DialectBytecodeWriter &writer, const test::VersionedProperties &prop) { writer.writeVarInt(prop.value1); writer.writeVarInt(prop.value2); } #include "TestOpEnums.cpp.inc" #include "TestOpInterfaces.cpp.inc" #include "TestTypeInterfaces.cpp.inc" #define GET_OP_CLASSES #include "TestOps.cpp.inc"