//===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {
struct ArithToAMDGPUConversionPass final
    : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
  using impl::ArithToAMDGPUConversionPassBase<
      ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;

  void runOnOperation() override;
};
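
// Rewrites `arith.extf` whose source element type is one of the 8-bit FNUZ
// float types (f8E5M2FNUZ / f8E4M3FNUZ) into `amdgpu::ExtPackedFp8Op`
// extractions followed by a cast from f32 to the requested destination type.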
struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult match(arith::ExtFOp op) const override;
  void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
};
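
// Rewrites `arith.truncf` to one of the 8-bit FNUZ float types into
// `amdgpu::PackedTrunc2xFp8Op` packing. When `saturateFP8` is set, finite
// inputs are first clamped to the destination type's representable range so
// that out-of-range values saturate instead of becoming NaN.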
struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
  bool saturateFP8 = false;
  TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8)
      : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8) {}

  LogicalResult match(arith::TruncFOp op) const override;
  void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
};
} // end namespace
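
// Casts an f32 value to `elementType`, emitting `arith.truncf` for narrower
// float types and `arith.extf` for wider ones; f32 is returned unchanged.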
static Value castF32To(Type elementType, Value f32, Location loc,
                       PatternRewriter &rewriter) {
  if (elementType.isF32())
    return f32;
  if (elementType.getIntOrFloatBitWidth() < 32)
    return rewriter.create<arith::TruncFOp>(loc, elementType, f32);
  if (elementType.getIntOrFloatBitWidth() > 32)
    return rewriter.create<arith::ExtFOp>(loc, elementType, f32);
  llvm_unreachable("The only 32-bit float type is f32");
}

LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
  Type inType = op.getIn().getType();
  if (auto inVecType = inType.dyn_cast<VectorType>()) {
    if (inVecType.isScalable())
      return failure();
    if (inVecType.getShape().size() > 1)
      // Multi-dimensional vectors are currently unsupported.
      return failure();
    inType = inVecType.getElementType();
  }
  return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
}
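
// Each `amdgpu::ExtPackedFp8Op` reads one lane out of a packed group of at
// most four fp8 values, so vector inputs are processed in 4-element slices.
// As a rough sketch, `%r = arith.extf %v : vector<6xf8E4M3FNUZ> to
// vector<6xf32>` becomes six packed-extract ops (indices 0-3 of the first
// slice, 0-1 of the second) whose f32 results are inserted back into a
// vector<6xf32>.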
void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
                                         PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  if (!in.getType().isa<VectorType>()) {
    Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
        loc, rewriter.getF32Type(), in, 0);
    Value result = castF32To(outElemType, asFloat, loc, rewriter);
    return rewriter.replaceOp(op, result);
  }
  VectorType inType = in.getType().cast<VectorType>();
  int64_t numElements = inType.getNumElements();
  Value zero = rewriter.createOrFold<arith::ConstantOp>(
      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  Value result =
      rewriter.createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero);
  if (inType.getShape().empty()) {
    Value scalarIn =
        rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
    // Recurse: the scalar extf created here is picked up again by the scalar
    // branch of this pattern.
    Value scalarExt =
        rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
    // Insert into the 0-D splat (`result`), not the scalar zero constant.
    result = rewriter.create<vector::InsertOp>(loc, scalarExt, result,
                                               ArrayRef<int64_t>{});
    return rewriter.replaceOp(op, result);
  }
  for (int64_t i = 0; i < numElements; i += 4) {
    int64_t elemsThisOp = std::min(numElements, i + 4) - i;
    Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
        loc, in, i, elemsThisOp, 1);
    for (int64_t j = 0; j < elemsThisOp; ++j) {
      Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
          loc, rewriter.getF32Type(), inSlice, j);
      Value asType = castF32To(outElemType, asFloat, loc, rewriter);
      result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
    }
  }
  rewriter.replaceOp(op, result);
}
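
// Casts `value` to f32, extending narrower float types and truncating wider
// ones; f32 values are returned unchanged.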
static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
  Type type = value.getType();
  if (type.isF32())
    return value;
  if (type.getIntOrFloatBitWidth() < 32)
    return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
  if (type.getIntOrFloatBitWidth() > 32)
    return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
  llvm_unreachable("The only 32-bit float type is f32");
}

// If `source` is a finite value, clamp it between the maximum and minimum
// values of `outElemType` so that subsequent conversion instructions don't
// overflow those out-of-range values to NaN. These semantics are commonly
// used in machine-learning contexts where failure to clamp would lead to
// excessive NaN production.
static Value clampInput(PatternRewriter &rewriter, Location loc,
                        Type outElemType, Value source) {
  Type sourceType = source.getType();
  const llvm::fltSemantics &sourceSem =
      cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
  const llvm::fltSemantics &targetSem =
      cast<FloatType>(outElemType).getFloatSemantics();

  APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
  APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
  bool ignoredLosesInfo = false;
  // We can ignore conversion failures here because this conversion promotes
  // from a smaller type to a larger one - e.g. there can be no loss of
  // precision when casting fp8 to f16.
  (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
  (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);

  Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min);
  Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max);

  Value inf = createScalarOrSplatConstant(
      rewriter, loc, sourceType,
      APFloat::getInf(sourceSem, /*Negative=*/false));
  Value negInf = createScalarOrSplatConstant(
      rewriter, loc, sourceType, APFloat::getInf(sourceSem, /*Negative=*/true));
  Value isInf = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::OEQ, source, inf);
  Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::OEQ, source, negInf);
  Value isNan = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::UNO, source, source);
  Value isNonFinite = rewriter.create<arith::OrIOp>(
      loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);

  Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
  Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
  Value res =
      rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
  return res;
}

LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
  Type outType = op.getOut().getType();
  if (auto outVecType = outType.dyn_cast<VectorType>()) {
    if (outVecType.isScalable())
      return failure();
    if (outVecType.getShape().size() > 1)
      // Multi-dimensional vectors are currently unsupported.
      return failure();
    outType = outVecType.getElementType();
  }
  auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
  if (inType && inType.getWidth() <= 8 && saturateFP8)
    // Conversion between 8-bit floats is not supported when saturation is
    // enabled.
    return failure();
  return success(outType.isFloat8E5M2FNUZ() || outType.isFloat8E4M3FNUZ());
}
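
// `amdgpu::PackedTrunc2xFp8Op` converts up to two f32 values into two lanes
// of a packed 4 x fp8 word, so the loop below fills each 4-element result
// chunk two lanes at a time. As a rough sketch, truncating a vector<4xf32>
// to vector<4xf8E5M2FNUZ> takes two packed-trunc ops (lanes 0-1, then 2-3)
// accumulated into one word, which is then inserted into the result vector.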
void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
                                           PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  if (saturateFP8)
    in = clampInput(rewriter, loc, outElemType, in);
  VectorType truncResType = VectorType::get(4, outElemType);
  if (!in.getType().isa<VectorType>()) {
    Value asFloat = castToF32(in, loc, rewriter);
    Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
        loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
        /*existing=*/nullptr);
    Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
    return rewriter.replaceOp(op, result);
  }
  VectorType outType = op.getOut().getType().cast<VectorType>();
  int64_t numElements = outType.getNumElements();
  Value zero = rewriter.createOrFold<arith::ConstantOp>(
      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
  if (outType.getShape().empty()) {
    Value scalarIn =
        rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
    // Recurse: the scalar truncf created here is picked up again by the
    // scalar branch of this pattern.
    Value scalarTrunc =
        rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
    // Insert into the 0-D splat (`result`), not the scalar zero constant.
    result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, result,
                                               ArrayRef<int64_t>{});
    return rewriter.replaceOp(op, result);
  }

  for (int64_t i = 0; i < numElements; i += 4) {
    int64_t elemsThisOp = std::min(numElements, i + 4) - i;
    Value thisResult = nullptr;
    for (int64_t j = 0; j < elemsThisOp; j += 2) {
      Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
      Value asFloatA = castToF32(elemA, loc, rewriter);
      Value asFloatB = nullptr;
      if (j + 1 < elemsThisOp) {
        Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
        asFloatB = castToF32(elemB, loc, rewriter);
      }
      thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
          loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
    }
    if (elemsThisOp < 4)
      thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, thisResult, 0, elemsThisOp, 1);
    result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
                                                           result, i, 1);
  }
  rewriter.replaceOp(op, result);
}
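
// Registers both fp8 patterns; `saturateFP8TruncF` enables the clamping
// performed by `clampInput` before packed truncation.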
void mlir::arith::populateArithToAMDGPUConversionPatterns(
    RewritePatternSet &patterns, bool saturateFP8TruncF) {
  patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext());
  patterns.add<TruncFToFloat8RewritePattern>(patterns.getContext(),
                                             saturateFP8TruncF);
}

void ArithToAMDGPUConversionPass::runOnOperation() {
  Operation *op = getOperation();
  RewritePatternSet patterns(op->getContext());
  arith::populateArithToAMDGPUConversionPatterns(patterns, saturateFP8Truncf);
  if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
    return signalPassFailure();
}