//===-- OneToNTypeFuncConversions.cpp - Func 1:N type conversion-*- C++ -*-===// // // Licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // The patterns in this file are heavily inspired (and copied from) // convertFuncOpTypes in lib/Transforms/Utils/DialectConversion.cpp and the // patterns in lib/Dialect/Func/Transforms/FuncConversions.cpp but work for 1:N // type conversions. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Transforms/OneToNTypeConversion.h" using namespace mlir; using namespace mlir::func; namespace { class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern { public: using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult matchAndRewrite(CallOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { Location loc = op->getLoc(); const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); // Nothing to do if the op doesn't have any non-identity conversions for its // operands or results. if (!adaptor.getOperandMapping().hasNonIdentityConversion() && !resultMapping.hasNonIdentityConversion()) return failure(); // Create new CallOp. auto newOp = rewriter.create(loc, resultMapping.getConvertedTypes(), adaptor.getFlatOperands()); newOp->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newOp->getResults(), resultMapping); return success(); } }; class ConvertTypesInFuncFuncOp : public OneToNOpConversionPattern { public: using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult matchAndRewrite(FuncOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { auto *typeConverter = getTypeConverter(); // Construct mapping for function arguments. OneToNTypeMapping argumentMapping(op.getArgumentTypes()); if (failed(typeConverter->computeTypeMapping(op.getArgumentTypes(), argumentMapping))) return failure(); // Construct mapping for function results. OneToNTypeMapping funcResultMapping(op.getResultTypes()); if (failed(typeConverter->computeTypeMapping(op.getResultTypes(), funcResultMapping))) return failure(); // Nothing to do if the op doesn't have any non-identity conversions for its // operands or results. if (!argumentMapping.hasNonIdentityConversion() && !funcResultMapping.hasNonIdentityConversion()) return failure(); // Update the function signature in-place. auto newType = FunctionType::get(rewriter.getContext(), argumentMapping.getConvertedTypes(), funcResultMapping.getConvertedTypes()); rewriter.modifyOpInPlace(op, [&] { op.setType(newType); }); // Update block signatures. if (!op.isExternal()) { Region *region = &op.getBody(); Block *block = ®ion->front(); rewriter.applySignatureConversion(block, argumentMapping); } return success(); } }; class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern { public: using OneToNOpConversionPattern::OneToNOpConversionPattern; LogicalResult matchAndRewrite(ReturnOp op, OpAdaptor adaptor, OneToNPatternRewriter &rewriter) const override { // Nothing to do if there is no non-identity conversion. if (!adaptor.getOperandMapping().hasNonIdentityConversion()) return failure(); // Convert operands. rewriter.modifyOpInPlace( op, [&] { op->setOperands(adaptor.getFlatOperands()); }); return success(); } }; } // namespace namespace mlir { void populateFuncTypeConversionPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add< // clang-format off ConvertTypesInFuncCallOp, ConvertTypesInFuncFuncOp, ConvertTypesInFuncReturnOp // clang-format on >(typeConverter, patterns.getContext()); } } // namespace mlir