//===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { #define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; namespace { struct ArithToAMDGPUConversionPass final : impl::ArithToAMDGPUConversionPassBase { using impl::ArithToAMDGPUConversionPassBase< ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase; void runOnOperation() override; }; struct ExtFOnFloat8RewritePattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult match(arith::ExtFOp op) const override; void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override; }; struct TruncFToFloat8RewritePattern final : OpRewritePattern { bool saturateFP8 = false; TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8) : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8) {} LogicalResult match(arith::TruncFOp op) const override; void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override; }; } // end namespace static Value castF32To(Type elementType, Value f32, Location loc, PatternRewriter &rewriter) { if (elementType.isF32()) return f32; if (elementType.getIntOrFloatBitWidth() < 32) return rewriter.create(loc, elementType, f32); if (elementType.getIntOrFloatBitWidth() > 32) return rewriter.create(loc, elementType, f32); llvm_unreachable("The only 32-bit float type is f32"); } LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const { Type inType = op.getIn().getType(); if (auto inVecType = inType.dyn_cast()) { if (inVecType.isScalable()) return failure(); if (inVecType.getShape().size() > 1) // Multi-dimensional vectors are currently unsupported. return failure(); inType = inVecType.getElementType(); } return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ()); } void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const { Location loc = op.getLoc(); Value in = op.getIn(); Type outElemType = getElementTypeOrSelf(op.getOut().getType()); if (!in.getType().isa()) { Value asFloat = rewriter.create( loc, rewriter.getF32Type(), in, 0); Value result = castF32To(outElemType, asFloat, loc, rewriter); return rewriter.replaceOp(op, result); } VectorType inType = in.getType().cast(); int64_t numElements = inType.getNumElements(); Value zero = rewriter.createOrFold( loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); Value result = rewriter.createOrFold(loc, op.getOut().getType(), zero); if (inType.getShape().empty()) { Value scalarIn = rewriter.create(loc, in, ArrayRef{}); // Recurse to send the 0-D vector case to the 1-D vector case Value scalarExt = rewriter.create(loc, outElemType, scalarIn); result = rewriter.create(loc, scalarExt, zero, ArrayRef{}); return rewriter.replaceOp(op, result); } for (int64_t i = 0; i < numElements; i += 4) { int64_t elemsThisOp = std::min(numElements, i + 4) - i; Value inSlice = rewriter.create( loc, in, i, elemsThisOp, 1); for (int64_t j = 0; j < elemsThisOp; ++j) { Value asFloat = rewriter.create( loc, rewriter.getF32Type(), inSlice, j); Value asType = castF32To(outElemType, asFloat, loc, rewriter); result = rewriter.create(loc, asType, result, i + j); } } rewriter.replaceOp(op, result); } static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) { Type type = value.getType(); if (type.isF32()) return value; if (type.getIntOrFloatBitWidth() < 32) return rewriter.create(loc, rewriter.getF32Type(), value); if (type.getIntOrFloatBitWidth() > 32) return rewriter.create(loc, rewriter.getF32Type(), value); llvm_unreachable("The only 32-bit float type is f32"); } // If `in` is a finite value, clamp it between the maximum and minimum values // of `outElemType` so that subsequent conversion instructions don't // overflow those out-of-range values to NaN. These semantics are commonly // used in machine-learning contexts where failure to clamp would lead to // excessive NaN production. static Value clampInput(PatternRewriter &rewriter, Location loc, Type outElemType, Value source) { Type sourceType = source.getType(); const llvm::fltSemantics &sourceSem = cast(getElementTypeOrSelf(sourceType)).getFloatSemantics(); const llvm::fltSemantics &targetSem = cast(outElemType).getFloatSemantics(); APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true); APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false); bool ignoredLosesInfo = false; // We can ignore conversion failures here because this conversion promotes // from a smaller type to a larger one - ex. there can be no loss of precision // when casting fp8 to f16. (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo); (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo); Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min); Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max); Value inf = createScalarOrSplatConstant( rewriter, loc, sourceType, APFloat::getInf(sourceSem, /*Negative=*/false)); Value negInf = createScalarOrSplatConstant( rewriter, loc, sourceType, APFloat::getInf(sourceSem, /*Negative=*/true)); Value isInf = rewriter.createOrFold( loc, arith::CmpFPredicate::OEQ, source, inf); Value isNegInf = rewriter.createOrFold( loc, arith::CmpFPredicate::OEQ, source, negInf); Value isNan = rewriter.createOrFold( loc, arith::CmpFPredicate::UNO, source, source); Value isNonFinite = rewriter.create( loc, rewriter.create(loc, isInf, isNegInf), isNan); Value clampedBelow = rewriter.create(loc, source, minCst); Value clamped = rewriter.create(loc, clampedBelow, maxCst); Value res = rewriter.create(loc, isNonFinite, source, clamped); return res; } LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const { Type outType = op.getOut().getType(); if (auto outVecType = outType.dyn_cast()) { if (outVecType.isScalable()) return failure(); if (outVecType.getShape().size() > 1) // Multi-dimensional vectors are currently unsupported. return failure(); outType = outVecType.getElementType(); } auto inType = dyn_cast(getElementTypeOrSelf(op.getIn().getType())); if (inType && inType.getWidth() <= 8 && saturateFP8) // Conversion between 8-bit floats is not supported with truncation enabled. return failure(); return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ()); } void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const { Location loc = op.getLoc(); Value in = op.getIn(); Type outElemType = getElementTypeOrSelf(op.getOut().getType()); if (saturateFP8) in = clampInput(rewriter, loc, outElemType, in); VectorType truncResType = VectorType::get(4, outElemType); if (!in.getType().isa()) { Value asFloat = castToF32(in, loc, rewriter); Value asF8s = rewriter.create( loc, truncResType, asFloat, /*sourceB=*/nullptr, 0, /*existing=*/nullptr); Value result = rewriter.create(loc, asF8s, 0); return rewriter.replaceOp(op, result); } VectorType outType = op.getOut().getType().cast(); int64_t numElements = outType.getNumElements(); Value zero = rewriter.createOrFold( loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0)); Value result = rewriter.createOrFold(loc, outType, zero); if (outType.getShape().empty()) { Value scalarIn = rewriter.create(loc, in, ArrayRef{}); // Recurse to send the 0-D vector case to the 1-D vector case Value scalarTrunc = rewriter.create(loc, outElemType, scalarIn); result = rewriter.create(loc, scalarTrunc, zero, ArrayRef{}); return rewriter.replaceOp(op, result); } for (int64_t i = 0; i < numElements; i += 4) { int64_t elemsThisOp = std::min(numElements, i + 4) - i; Value thisResult = nullptr; for (int64_t j = 0; j < elemsThisOp; j += 2) { Value elemA = rewriter.create(loc, in, i + j); Value asFloatA = castToF32(elemA, loc, rewriter); Value asFloatB = nullptr; if (j + 1 < elemsThisOp) { Value elemB = rewriter.create(loc, in, i + j + 1); asFloatB = castToF32(elemB, loc, rewriter); } thisResult = rewriter.create( loc, truncResType, asFloatA, asFloatB, j / 2, thisResult); } if (elemsThisOp < 4) thisResult = rewriter.create( loc, thisResult, 0, elemsThisOp, 1); result = rewriter.create(loc, thisResult, result, i, 1); } rewriter.replaceOp(op, result); } void mlir::arith::populateArithToAMDGPUConversionPatterns( RewritePatternSet &patterns, bool saturateFP8TruncF) { patterns.add(patterns.getContext()); patterns.add(patterns.getContext(), saturateFP8TruncF); } void ArithToAMDGPUConversionPass::runOnOperation() { Operation *op = getOperation(); RewritePatternSet patterns(op->getContext()); arith::populateArithToAMDGPUConversionPatterns(patterns, saturateFP8Truncf); if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) return signalPassFailure(); }