//===- EmulateWideInt.cpp - Wide integer operation emulation ----*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MathExtras.h" #include namespace mlir::memref { #define GEN_PASS_DEF_MEMREFEMULATEWIDEINT #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" } // namespace mlir::memref using namespace mlir; namespace { //===----------------------------------------------------------------------===// // ConvertMemRefAlloc //===----------------------------------------------------------------------===// struct ConvertMemRefAlloc final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type newTy = getTypeConverter()->convertType(op.getType()); if (!newTy) return rewriter.notifyMatchFailure( op->getLoc(), llvm::formatv("failed to convert memref type: {0}", op.getType())); rewriter.replaceOpWithNewOp( op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(), adaptor.getAlignmentAttr()); return success(); } }; //===----------------------------------------------------------------------===// // ConvertMemRefLoad //===----------------------------------------------------------------------===// struct ConvertMemRefLoad final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type newResTy = getTypeConverter()->convertType(op.getType()); if (!newResTy) return rewriter.notifyMatchFailure( op->getLoc(), llvm::formatv("failed to convert memref type: {0}", op.getMemRefType())); rewriter.replaceOpWithNewOp( op, newResTy, adaptor.getMemref(), adaptor.getIndices(), op.getNontemporal()); return success(); } }; //===----------------------------------------------------------------------===// // ConvertMemRefStore //===----------------------------------------------------------------------===// struct ConvertMemRefStore final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type newTy = getTypeConverter()->convertType(op.getMemRefType()); if (!newTy) return rewriter.notifyMatchFailure( op->getLoc(), llvm::formatv("failed to convert memref type: {0}", op.getMemRefType())); rewriter.replaceOpWithNewOp( op, adaptor.getValue(), adaptor.getMemref(), adaptor.getIndices(), op.getNontemporal()); return success(); } }; //===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// struct EmulateWideIntPass final : memref::impl::MemRefEmulateWideIntBase { using MemRefEmulateWideIntBase::MemRefEmulateWideIntBase; void runOnOperation() override { if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) { signalPassFailure(); return; } Operation *op = getOperation(); MLIRContext *ctx = op->getContext(); arith::WideIntEmulationConverter typeConverter(widestIntSupported); memref::populateMemRefWideIntEmulationConversions(typeConverter); ConversionTarget target(*ctx); target.addDynamicallyLegalDialect< arith::ArithDialect, memref::MemRefDialect, vector::VectorDialect>( [&typeConverter](Operation *op) { return typeConverter.isLegal(op); }); RewritePatternSet patterns(ctx); // Add common pattenrs to support contants, functions, etc. arith::populateArithWideIntEmulationPatterns(typeConverter, patterns); memref::populateMemRefWideIntEmulationPatterns(typeConverter, patterns); if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); } }; } // end anonymous namespace //===----------------------------------------------------------------------===// // Public Interface Definition //===----------------------------------------------------------------------===// void memref::populateMemRefWideIntEmulationPatterns( arith::WideIntEmulationConverter &typeConverter, RewritePatternSet &patterns) { // Populate `memref.*` conversion patterns. patterns.add( typeConverter, patterns.getContext()); } void memref::populateMemRefWideIntEmulationConversions( arith::WideIntEmulationConverter &typeConverter) { typeConverter.addConversion( [&typeConverter](MemRefType ty) -> std::optional { auto intTy = dyn_cast(ty.getElementType()); if (!intTy) return ty; if (intTy.getIntOrFloatBitWidth() <= typeConverter.getMaxTargetIntBitWidth()) return ty; Type newElemTy = typeConverter.convertType(intTy); if (!newElemTy) return std::nullopt; return ty.cloneWith(std::nullopt, newElemTy); }); }