//===- GPUToSPIRV.cpp - GPU to SPIR-V Patterns ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert GPU dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include <optional>
using namespace mlir;
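/// Prefix prepended to the original gpu.module symbol name when creating the
/// corresponding spirv.module, so the new module does not clash with existing
/// symbols in the parent module (see GPUModuleConversion below).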
static constexpr const char kSPIRVModule[] = "__spv__";
namespace {
/// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
/// builtin variables.
template <typename SourceOp, spirv::BuiltIn builtin>
class LaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
using OpConversionPattern<SourceOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Pattern lowering subgroup size/id to loading SPIR-V invocation
/// builtin variables.
template <typename SourceOp, spirv::BuiltIn builtin>
class SingleDimLaunchConfigConversion : public OpConversionPattern<SourceOp> {
public:
using OpConversionPattern<SourceOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// This is separate because in Vulkan workgroup size is exposed to shaders via
/// a constant with WorkgroupSize decoration. So here we cannot generate a
/// builtin variable; instead the information in the `spirv.entry_point_abi`
/// attribute on the surrounding FuncOp is used to replace the gpu::BlockDimOp.
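///
/// For illustration (attribute syntax shown approximately): with
///   spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>
/// on the surrounding function, `gpu.block_dim x` folds to a constant 32.
/// The pattern is registered with a higher benefit than the generic
/// LaunchConfigConversion for gpu::BlockDimOp, so this constant folding is
/// preferred whenever a static workgroup size is available.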
class WorkGroupSizeConversion : public OpConversionPattern<gpu::BlockDimOp> {
public:
WorkGroupSizeConversion(TypeConverter &typeConverter, MLIRContext *context)
      : OpConversionPattern(typeConverter, context, /*benefit=*/10) {}
LogicalResult
matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Pattern to convert a kernel function in GPU dialect within a spirv.module.
class GPUFuncOpConversion final : public OpConversionPattern<gpu::GPUFuncOp> {
public:
using OpConversionPattern<gpu::GPUFuncOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
private:
SmallVector<int32_t, 3> workGroupSizeAsInt32;
};
/// Pattern to convert a gpu.module to a spirv.module.
class GPUModuleConversion final : public OpConversionPattern<gpu::GPUModuleOp> {
public:
using OpConversionPattern<gpu::GPUModuleOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
class GPUModuleEndConversion final
: public OpConversionPattern<gpu::ModuleEndOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(gpu::ModuleEndOp endOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.eraseOp(endOp);
return success();
}
};
/// Pattern to convert a gpu.return into a SPIR-V return.
// TODO: This can go to DRR when GPU return has operands.
class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
public:
using OpConversionPattern<gpu::ReturnOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Pattern to convert a gpu.barrier op into a spirv.ControlBarrier op.
class GPUBarrierConversion final : public OpConversionPattern<gpu::BarrierOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(gpu::BarrierOp barrierOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
/// Pattern to convert a gpu.shuffle op into a spirv.GroupNonUniformShuffle op.
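/// Only the xor and idx shuffle modes are currently handled, and the shuffle
/// width must equal the target's subgroup size (see the implementation below).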
class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace
//===----------------------------------------------------------------------===//
// Builtins.
//===----------------------------------------------------------------------===//
template <typename SourceOp, spirv::BuiltIn builtin>
LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
Type indexType = typeConverter->getIndexType();
// For Vulkan, these SPIR-V builtin variables are required to be a vector of
// type <3xi32> by the spec:
// https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumWorkgroups.html
// https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupId.html
// https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/WorkgroupSize.html
// https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/LocalInvocationId.html
// https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/GlobalInvocationId.html
//
// For OpenCL, it depends on the Physical32/Physical64 addressing model:
// https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
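  //
  // The lowering therefore (a) loads the builtin variable as a 3-element
  // vector, (b) extracts the component selected by `op.getDimension()`, and
  // (c) for shaders, converts the i32 component to the converter's index type
  // (spirv.UConvert) when the two types differ.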
bool forShader =
typeConverter->getTargetEnv().allows(spirv::Capability::Shader);
Type builtinType = forShader ? rewriter.getIntegerType(32) : indexType;
Value vector =
spirv::getBuiltinVariableValue(op, builtin, builtinType, rewriter);
Value dim = rewriter.create<spirv::CompositeExtractOp>(
op.getLoc(), builtinType, vector,
rewriter.getI32ArrayAttr({static_cast<int32_t>(op.getDimension())}));
if (forShader && builtinType != indexType)
dim = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType, dim);
rewriter.replaceOp(op, dim);
return success();
}
template <typename SourceOp, spirv::BuiltIn builtin>
LogicalResult
SingleDimLaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
SourceOp op, typename SourceOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
Type indexType = typeConverter->getIndexType();
Type i32Type = rewriter.getIntegerType(32);
  // For Vulkan, these SPIR-V builtin variables are required to be scalars of
  // type i32 by the spec:
// https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/NumSubgroups.html
// https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupId.html
// https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/SubgroupSize.html
//
// For OpenCL, they are also required to be i32:
// https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_Env.html#_built_in_variables
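  //
  // These builtins are scalar, so no component extraction is needed; the i32
  // value is only converted (spirv.UConvert) when the converter's index type
  // differs from i32.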
Value builtinValue =
spirv::getBuiltinVariableValue(op, builtin, i32Type, rewriter);
if (i32Type != indexType)
builtinValue = rewriter.create<spirv::UConvertOp>(op.getLoc(), indexType,
builtinValue);
rewriter.replaceOp(op, builtinValue);
return success();
}
LogicalResult WorkGroupSizeConversion::matchAndRewrite(
gpu::BlockDimOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
DenseI32ArrayAttr workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op);
if (!workGroupSizeAttr)
return failure();
int val =
workGroupSizeAttr.asArrayRef()[static_cast<int32_t>(op.getDimension())];
auto convertedType =
getTypeConverter()->convertType(op.getResult().getType());
if (!convertedType)
return failure();
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
op, convertedType, IntegerAttr::get(convertedType, val));
return success();
}
//===----------------------------------------------------------------------===//
// GPUFuncOp
//===----------------------------------------------------------------------===//
// Legalizes a GPU function as an entry SPIR-V function.
static spirv::FuncOp
lowerAsEntryFunction(gpu::GPUFuncOp funcOp, const TypeConverter &typeConverter,
ConversionPatternRewriter &rewriter,
spirv::EntryPointABIAttr entryPointInfo,
ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
auto fnType = funcOp.getFunctionType();
if (fnType.getNumResults()) {
funcOp.emitError("SPIR-V lowering only supports entry functions"
"with no return values right now");
return nullptr;
}
if (!argABIInfo.empty() && fnType.getNumInputs() != argABIInfo.size()) {
funcOp.emitError(
"lowering as entry functions requires ABI info for all arguments "
"or none of them");
return nullptr;
}
// Update the signature to valid SPIR-V types and add the ABI
// attributes. These will be "materialized" by using the
// LowerABIAttributesPass.
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
{
for (const auto &argType :
enumerate(funcOp.getFunctionType().getInputs())) {
auto convertedType = typeConverter.convertType(argType.value());
if (!convertedType)
return nullptr;
signatureConverter.addInputs(argType.index(), convertedType);
}
}
auto newFuncOp = rewriter.create<spirv::FuncOp>(
funcOp.getLoc(), funcOp.getName(),
rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
std::nullopt));
for (const auto &namedAttr : funcOp->getAttrs()) {
if (namedAttr.getName() == funcOp.getFunctionTypeAttrName() ||
namedAttr.getName() == SymbolTable::getSymbolAttrName())
continue;
newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
}
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
&signatureConverter)))
return nullptr;
rewriter.eraseOp(funcOp);
  // Set the attributes for the arguments and the function.
StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
newFuncOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
}
newFuncOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
return newFuncOp;
}
/// Populates `argABI` with spirv.interface_var_abi attributes for lowering
/// gpu.func to spirv.func if no arguments have the attributes set
/// already. Returns failure if any argument has the ABI attribute set already.
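///
/// For illustration (attribute syntax shown approximately): a buffer-like
/// argument at index 0 gets #spirv.interface_var_abi<(0, 0)>, while a scalar
/// i32 argument at index 1 gets #spirv.interface_var_abi<(0, 1), StorageBuffer>,
/// i.e. descriptor set 0 and a binding equal to the argument index.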
static LogicalResult
getDefaultABIAttrs(const spirv::TargetEnv &targetEnv, gpu::GPUFuncOp funcOp,
SmallVectorImpl<spirv::InterfaceVarABIAttr> &argABI) {
if (!spirv::needsInterfaceVarABIAttrs(targetEnv))
return success();
for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
if (funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
argIndex, spirv::getInterfaceVarABIAttrName()))
return failure();
    // Vulkan's interface variable requirements need scalars to be wrapped in a
    // struct, and that struct to be held in a storage buffer.
std::optional<spirv::StorageClass> sc;
if (funcOp.getArgument(argIndex).getType().isIntOrIndexOrFloat())
sc = spirv::StorageClass::StorageBuffer;
argABI.push_back(
spirv::getInterfaceVarABIAttr(0, argIndex, sc, funcOp.getContext()));
}
return success();
}
LogicalResult GPUFuncOpConversion::matchAndRewrite(
gpu::GPUFuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (!gpu::GPUDialect::isKernel(funcOp))
return failure();
auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
SmallVector<spirv::InterfaceVarABIAttr, 4> argABI;
if (failed(
getDefaultABIAttrs(typeConverter->getTargetEnv(), funcOp, argABI))) {
argABI.clear();
for (auto argIndex : llvm::seq<unsigned>(0, funcOp.getNumArguments())) {
// If the ABI is already specified, use it.
auto abiAttr = funcOp.getArgAttrOfType<spirv::InterfaceVarABIAttr>(
argIndex, spirv::getInterfaceVarABIAttrName());
if (!abiAttr) {
funcOp.emitRemark(
"match failure: missing 'spirv.interface_var_abi' attribute at "
"argument ")
<< argIndex;
return failure();
}
argABI.push_back(abiAttr);
}
}
auto entryPointAttr = spirv::lookupEntryPointABI(funcOp);
if (!entryPointAttr) {
funcOp.emitRemark(
"match failure: missing 'spirv.entry_point_abi' attribute");
return failure();
}
spirv::FuncOp newFuncOp = lowerAsEntryFunction(
funcOp, *getTypeConverter(), rewriter, entryPointAttr, argABI);
if (!newFuncOp)
return failure();
newFuncOp->removeAttr(
rewriter.getStringAttr(gpu::GPUDialect::getKernelFuncAttrName()));
return success();
}
//===----------------------------------------------------------------------===//
// ModuleOp with gpu.module.
//===----------------------------------------------------------------------===//
LogicalResult GPUModuleConversion::matchAndRewrite(
gpu::GPUModuleOp moduleOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
const spirv::TargetEnv &targetEnv = typeConverter->getTargetEnv();
spirv::AddressingModel addressingModel = spirv::getAddressingModel(
targetEnv, typeConverter->getOptions().use64bitIndex);
FailureOr<spirv::MemoryModel> memoryModel = spirv::getMemoryModel(targetEnv);
if (failed(memoryModel))
return moduleOp.emitRemark(
"cannot deduce memory model from 'spirv.target_env'");
  // Prefix the module name to avoid symbol conflicts with the original module.
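  // For example, `gpu.module @kernels` becomes `spirv.module @__spv__kernels`.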
std::string spvModuleName = (kSPIRVModule + moduleOp.getName()).str();
auto spvModule = rewriter.create<spirv::ModuleOp>(
moduleOp.getLoc(), addressingModel, *memoryModel, std::nullopt,
StringRef(spvModuleName));
// Move the region from the module op into the SPIR-V module.
Region &spvModuleRegion = spvModule.getRegion();
rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvModuleRegion,
spvModuleRegion.begin());
// The spirv.module build method adds a block. Remove that.
rewriter.eraseBlock(&spvModuleRegion.back());
  // Some patterns call `lookupTargetEnv` during conversion; they would fail if
  // run after GPUModuleConversion unless the `TargetEnv` attribute is
  // preserved here. Copy the TargetEnvAttr only if it is attached directly to
  // the GPUModuleOp.
if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
spirv::getTargetEnvAttrName()))
spvModule->setAttr(spirv::getTargetEnvAttrName(), attr);
rewriter.eraseOp(moduleOp);
return success();
}
//===----------------------------------------------------------------------===//
// GPU return inside kernel functions to SPIR-V return.
//===----------------------------------------------------------------------===//
LogicalResult GPUReturnOpConversion::matchAndRewrite(
gpu::ReturnOp returnOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
if (!adaptor.getOperands().empty())
return failure();
rewriter.replaceOpWithNewOp<spirv::ReturnOp>(returnOp);
return success();
}
//===----------------------------------------------------------------------===//
// Barrier.
//===----------------------------------------------------------------------===//
LogicalResult GPUBarrierConversion::matchAndRewrite(
gpu::BarrierOp barrierOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
MLIRContext *context = getContext();
// Both execution and memory scope should be workgroup.
auto scope = spirv::ScopeAttr::get(context, spirv::Scope::Workgroup);
// Require acquire and release memory semantics for workgroup memory.
auto memorySemantics = spirv::MemorySemanticsAttr::get(
context, spirv::MemorySemantics::WorkgroupMemory |
spirv::MemorySemantics::AcquireRelease);
rewriter.replaceOpWithNewOp<spirv::ControlBarrierOp>(barrierOp, scope, scope,
memorySemantics);
return success();
}
//===----------------------------------------------------------------------===//
// Shuffle
//===----------------------------------------------------------------------===//
LogicalResult GPUShuffleConversion::matchAndRewrite(
gpu::ShuffleOp shuffleOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Require the shuffle width to be the same as the target's subgroup size,
// given that for SPIR-V non-uniform subgroup ops, we cannot select
// participating invocations.
auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
unsigned subgroupSize =
targetEnv.getAttr().getResourceLimits().getSubgroupSize();
IntegerAttr widthAttr;
if (!matchPattern(shuffleOp.getWidth(), m_Constant(&widthAttr)) ||
widthAttr.getValue().getZExtValue() != subgroupSize)
return rewriter.notifyMatchFailure(
shuffleOp, "shuffle width and target subgroup size mismatch");
Location loc = shuffleOp.getLoc();
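  // gpu.shuffle also returns an i1 flag indicating whether the shuffled value
  // is valid; because the width equals the subgroup size, every invocation
  // participates, so the flag is a constant true.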
Value trueVal = spirv::ConstantOp::getOne(rewriter.getI1Type(),
shuffleOp.getLoc(), rewriter);
auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
Value result;
switch (shuffleOp.getMode()) {
case gpu::ShuffleMode::XOR:
result = rewriter.create<spirv::GroupNonUniformShuffleXorOp>(
loc, scope, adaptor.getValue(), adaptor.getOffset());
break;
case gpu::ShuffleMode::IDX:
result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
loc, scope, adaptor.getValue(), adaptor.getOffset());
break;
default:
return rewriter.notifyMatchFailure(shuffleOp, "unimplemented shuffle mode");
}
rewriter.replaceOp(shuffleOp, {result, trueVal});
return success();
}
//===----------------------------------------------------------------------===//
// Group ops
//===----------------------------------------------------------------------===//
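/// Creates a SPIR-V group reduction op for `arg`. `isGroup` selects workgroup
/// scope (gpu.all_reduce) versus subgroup scope (gpu.subgroup_reduce), and
/// `isUniform` selects the uniform Group* ops over the GroupNonUniform* ops.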
template <typename UniformOp, typename NonUniformOp>
static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
Value arg, bool isGroup, bool isUniform) {
Type type = arg.getType();
auto scope = mlir::spirv::ScopeAttr::get(builder.getContext(),
isGroup ? spirv::Scope::Workgroup
: spirv::Scope::Subgroup);
auto groupOp = spirv::GroupOperationAttr::get(builder.getContext(),
spirv::GroupOperation::Reduce);
if (isUniform) {
return builder.create<UniformOp>(loc, type, scope, groupOp, arg)
.getResult();
}
return builder.create<NonUniformOp>(loc, type, scope, groupOp, arg, Value{})
.getResult();
}
static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
Location loc, Value arg,
gpu::AllReduceOperation opType,
bool isGroup, bool isUniform) {
enum class ElemType { Float, Boolean, Integer };
using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
struct OpHandler {
gpu::AllReduceOperation kind;
ElemType elemType;
FuncT func;
};
Type type = arg.getType();
ElemType elementType;
if (isa<FloatType>(type)) {
elementType = ElemType::Float;
} else if (auto intTy = dyn_cast<IntegerType>(type)) {
elementType = (intTy.getIntOrFloatBitWidth() == 1) ? ElemType::Boolean
: ElemType::Integer;
} else {
return std::nullopt;
}
  // TODO(https://github.com/llvm/llvm-project/issues/73459): The SPIR-V spec
  // does not specify how -0.0 / +0.0 and NaN values are handled in *FMin/*FMax
  // reduction ops. We should account for possible precision requirements in
  // this conversion.
using ReduceType = gpu::AllReduceOperation;
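  // Dispatch table mapping a (reduction kind, element type) pair to the
  // builder for the corresponding uniform/non-uniform SPIR-V group op.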
const OpHandler handlers[] = {
{ReduceType::ADD, ElemType::Integer,
&createGroupReduceOpImpl<spirv::GroupIAddOp,
spirv::GroupNonUniformIAddOp>},
{ReduceType::ADD, ElemType::Float,
&createGroupReduceOpImpl<spirv::GroupFAddOp,
spirv::GroupNonUniformFAddOp>},
{ReduceType::MUL, ElemType::Integer,
&createGroupReduceOpImpl<spirv::GroupIMulKHROp,
spirv::GroupNonUniformIMulOp>},
{ReduceType::MUL, ElemType::Float,
&createGroupReduceOpImpl<spirv::GroupFMulKHROp,
spirv::GroupNonUniformFMulOp>},
{ReduceType::MINUI, ElemType::Integer,
&createGroupReduceOpImpl<spirv::GroupUMinOp,
spirv::GroupNonUniformUMinOp>},
{ReduceType::MINSI, ElemType::Integer,
&createGroupReduceOpImpl<spirv::GroupSMinOp,
spirv::GroupNonUniformSMinOp>},
{ReduceType::MINNUMF, ElemType::Float,
&createGroupReduceOpImpl<spirv::GroupFMinOp,
spirv::GroupNonUniformFMinOp>},
{ReduceType::MAXUI, ElemType::Integer,
&createGroupReduceOpImpl<spirv::GroupUMaxOp,
spirv::GroupNonUniformUMaxOp>},
{ReduceType::MAXSI, ElemType::Integer,
&createGroupReduceOpImpl<spirv::GroupSMaxOp,
spirv::GroupNonUniformSMaxOp>},
{ReduceType::MAXNUMF, ElemType::Float,
&createGroupReduceOpImpl<spirv::GroupFMaxOp,
spirv::GroupNonUniformFMaxOp>},
{ReduceType::MINIMUMF, ElemType::Float,
&createGroupReduceOpImpl<spirv::GroupFMinOp,
spirv::GroupNonUniformFMinOp>},
{ReduceType::MAXIMUMF, ElemType::Float,
&createGroupReduceOpImpl<spirv::GroupFMaxOp,
spirv::GroupNonUniformFMaxOp>}};
for (const OpHandler &handler : handlers)
if (handler.kind == opType && elementType == handler.elemType)
return handler.func(builder, loc, arg, isGroup, isUniform);
return std::nullopt;
}
/// Pattern to convert a gpu.all_reduce op into a SPIR-V group op.
class GPUAllReduceConversion final
: public OpConversionPattern<gpu::AllReduceOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(gpu::AllReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto opType = op.getOp();
    // gpu.all_reduce can have either a reduction op attribute or a reduction
    // region; only the attribute form is supported here.
if (!opType)
return failure();
auto result =
createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), *opType,
/*isGroup*/ true, op.getUniform());
if (!result)
return failure();
rewriter.replaceOp(op, *result);
return success();
}
};
/// Pattern to convert a gpu.subgroup_reduce op into a SPIR-V group op.
class GPUSubgroupReduceConversion final
: public OpConversionPattern<gpu::SubgroupReduceOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");
auto result = createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(),
adaptor.getOp(),
/*isGroup=*/false, adaptor.getUniform());
if (!result)
return failure();
rewriter.replaceOp(op, *result);
return success();
}
};
//===----------------------------------------------------------------------===//
// GPU To SPIRV Patterns.
//===----------------------------------------------------------------------===//
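// These patterns are typically driven by a conversion pass (e.g. the
// GPU-to-SPIR-V pass in this directory) that builds a SPIRVTypeConverter from
// the `spirv.target_env` attribute and runs a dialect conversion over each
// gpu.module.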
void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<
GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
GPUModuleEndConversion, GPUReturnOpConversion, GPUShuffleConversion,
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
LaunchConfigConversion<gpu::ThreadIdOp,
spirv::BuiltIn::LocalInvocationId>,
LaunchConfigConversion<gpu::GlobalIdOp,
spirv::BuiltIn::GlobalInvocationId>,
SingleDimLaunchConfigConversion<gpu::SubgroupIdOp,
spirv::BuiltIn::SubgroupId>,
SingleDimLaunchConfigConversion<gpu::NumSubgroupsOp,
spirv::BuiltIn::NumSubgroups>,
SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
spirv::BuiltIn::SubgroupSize>,
WorkGroupSizeConversion, GPUAllReduceConversion,
GPUSubgroupReduceConversion>(typeConverter, patterns.getContext());
}