//===- FuncOps.cpp - Func Dialect Operations ------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Interfaces/FunctionImplementation.h" #include "mlir/Support/MathExtras.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" #include #include "mlir/Dialect/Func/IR/FuncOpsDialect.cpp.inc" using namespace mlir; using namespace mlir::func; //===----------------------------------------------------------------------===// // FuncDialect //===----------------------------------------------------------------------===// void FuncDialect::initialize() { addOperations< #define GET_OP_LIST #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc" >(); declarePromisedInterface(); declarePromisedInterface(); } /// Materialize a single constant operation from a given attribute value with /// the desired resultant type. Operation *FuncDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { if (ConstantOp::isBuildableWith(value, type)) return builder.create(loc, type, llvm::cast(value)); return nullptr; } //===----------------------------------------------------------------------===// // CallOp //===----------------------------------------------------------------------===// LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { // Check that the callee attribute was specified. auto fnAttr = (*this)->getAttrOfType("callee"); if (!fnAttr) return emitOpError("requires a 'callee' symbol reference attribute"); FuncOp fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); if (!fn) return emitOpError() << "'" << fnAttr.getValue() << "' does not reference a valid function"; // Verify that the operand and result types match the callee. auto fnType = fn.getFunctionType(); if (fnType.getNumInputs() != getNumOperands()) return emitOpError("incorrect number of operands for callee"); for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) if (getOperand(i).getType() != fnType.getInput(i)) return emitOpError("operand type mismatch: expected operand type ") << fnType.getInput(i) << ", but provided " << getOperand(i).getType() << " for operand number " << i; if (fnType.getNumResults() != getNumResults()) return emitOpError("incorrect number of results for callee"); for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) if (getResult(i).getType() != fnType.getResult(i)) { auto diag = emitOpError("result type mismatch at index ") << i; diag.attachNote() << " op result types: " << getResultTypes(); diag.attachNote() << "function result types: " << fnType.getResults(); return diag; } return success(); } FunctionType CallOp::getCalleeType() { return FunctionType::get(getContext(), getOperandTypes(), getResultTypes()); } //===----------------------------------------------------------------------===// // CallIndirectOp //===----------------------------------------------------------------------===// /// Fold indirect calls that have a constant function as the callee operand. LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall, PatternRewriter &rewriter) { // Check that the callee is a constant callee. SymbolRefAttr calledFn; if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) return failure(); // Replace with a direct call. rewriter.replaceOpWithNewOp(indirectCall, calledFn, indirectCall.getResultTypes(), indirectCall.getArgOperands()); return success(); } //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// LogicalResult ConstantOp::verify() { StringRef fnName = getValue(); Type type = getType(); // Try to find the referenced function. auto fn = (*this)->getParentOfType().lookupSymbol(fnName); if (!fn) return emitOpError() << "reference to undefined function '" << fnName << "'"; // Check that the referenced function has the correct type. if (fn.getFunctionType() != type) return emitOpError("reference to function with mismatched type"); return success(); } OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } void ConstantOp::getAsmResultNames( function_ref setNameFn) { setNameFn(getResult(), "f"); } bool ConstantOp::isBuildableWith(Attribute value, Type type) { return llvm::isa(value) && llvm::isa(type); } //===----------------------------------------------------------------------===// // FuncOp //===----------------------------------------------------------------------===// FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, ArrayRef attrs) { OpBuilder builder(location->getContext()); OperationState state(location, getOperationName()); FuncOp::build(builder, state, name, type, attrs); return cast(Operation::create(state)); } FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, Operation::dialect_attr_range attrs) { SmallVector attrRef(attrs); return create(location, name, type, llvm::ArrayRef(attrRef)); } FuncOp FuncOp::create(Location location, StringRef name, FunctionType type, ArrayRef attrs, ArrayRef argAttrs) { FuncOp func = create(location, name, type, attrs); func.setAllArgAttrs(argAttrs); return func; } void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, FunctionType type, ArrayRef attrs, ArrayRef argAttrs) { state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); state.attributes.append(attrs.begin(), attrs.end()); state.addRegion(); if (argAttrs.empty()) return; assert(type.getNumInputs() == argAttrs.size()); function_interface_impl::addArgAndResultAttrs( builder, state, argAttrs, /*resultAttrs=*/std::nullopt, getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name)); } ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { auto buildFuncType = [](Builder &builder, ArrayRef argTypes, ArrayRef results, function_interface_impl::VariadicFlag, std::string &) { return builder.getFunctionType(argTypes, results); }; return function_interface_impl::parseFunctionOp( parser, result, /*allowVariadic=*/false, getFunctionTypeAttrName(result.name), buildFuncType, getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { function_interface_impl::printFunctionOp( p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName()); } /// Clone the internal blocks from this function into dest and all attributes /// from this function to dest. void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) { // Add the attributes of this function to dest. llvm::MapVector newAttrMap; for (const auto &attr : dest->getAttrs()) newAttrMap.insert({attr.getName(), attr.getValue()}); for (const auto &attr : (*this)->getAttrs()) newAttrMap.insert({attr.getName(), attr.getValue()}); auto newAttrs = llvm::to_vector(llvm::map_range( newAttrMap, [](std::pair attrPair) { return NamedAttribute(attrPair.first, attrPair.second); })); dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs)); // Clone the body. getBody().cloneInto(&dest.getBody(), mapper); } /// Create a deep copy of this function and all of its blocks, remapping /// any operands that use values outside of the function using the map that is /// provided (leaving them alone if no entry is present). Replaces references /// to cloned sub-values with the corresponding value that is copied, and adds /// those mappings to the mapper. FuncOp FuncOp::clone(IRMapping &mapper) { // Create the new function. FuncOp newFunc = cast(getOperation()->cloneWithoutRegions()); // If the function has a body, then the user might be deleting arguments to // the function by specifying them in the mapper. If so, we don't add the // argument to the input type vector. if (!isExternal()) { FunctionType oldType = getFunctionType(); unsigned oldNumArgs = oldType.getNumInputs(); SmallVector newInputs; newInputs.reserve(oldNumArgs); for (unsigned i = 0; i != oldNumArgs; ++i) if (!mapper.contains(getArgument(i))) newInputs.push_back(oldType.getInput(i)); /// If any of the arguments were dropped, update the type and drop any /// necessary argument attributes. if (newInputs.size() != oldNumArgs) { newFunc.setType(FunctionType::get(oldType.getContext(), newInputs, oldType.getResults())); if (ArrayAttr argAttrs = getAllArgAttrs()) { SmallVector newArgAttrs; newArgAttrs.reserve(newInputs.size()); for (unsigned i = 0; i != oldNumArgs; ++i) if (!mapper.contains(getArgument(i))) newArgAttrs.push_back(argAttrs[i]); newFunc.setAllArgAttrs(newArgAttrs); } } } /// Clone the current function into the new one and return it. cloneInto(newFunc, mapper); return newFunc; } FuncOp FuncOp::clone() { IRMapping mapper; return clone(mapper); } //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// LogicalResult ReturnOp::verify() { auto function = cast((*this)->getParentOp()); // The operand number and types must match the function signature. const auto &results = function.getFunctionType().getResults(); if (getNumOperands() != results.size()) return emitOpError("has ") << getNumOperands() << " operands, but enclosing function (@" << function.getName() << ") returns " << results.size(); for (unsigned i = 0, e = results.size(); i != e; ++i) if (getOperand(i).getType() != results[i]) return emitError() << "type of return operand " << i << " (" << getOperand(i).getType() << ") doesn't match function result type (" << results[i] << ")" << " in function @" << function.getName(); return success(); } //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "mlir/Dialect/Func/IR/FuncOps.cpp.inc"