185 lines
6.9 KiB
C++
185 lines
6.9 KiB
C++
|
//===- InferIntRangeInterfaceImpls.cpp - Integer range impls for gpu -===//
|
||
|
//
|
||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
||
|
#include "mlir/IR/Matchers.h"
|
||
|
#include "mlir/Interfaces/InferIntRangeInterface.h"
|
||
|
#include "llvm/ADT/STLForwardCompat.h"
|
||
|
#include "llvm/Support/ErrorHandling.h"
|
||
|
#include "llvm/Support/MathExtras.h"
|
||
|
#include <optional>
|
||
|
|
||
|
using namespace mlir;
|
||
|
using namespace mlir::gpu;
|
||
|
|
||
|
// Maximum grid and block dimensions of all known GPUs are less than 2^32.
|
||
|
static constexpr uint64_t kMaxDim = std::numeric_limits<uint32_t>::max();
|
||
|
// Maximum cluster size
|
||
|
static constexpr uint64_t kMaxClusterDim = 8;
|
||
|
// Maximum subgroups are no larger than 128.
|
||
|
static constexpr uint64_t kMaxSubgroupSize = 128;
|
||
|
|
||
|
static ConstantIntRanges getIndexRange(uint64_t umin, uint64_t umax) {
|
||
|
unsigned width = IndexType::kInternalStorageBitWidth;
|
||
|
return ConstantIntRanges::fromUnsigned(APInt(width, umin),
|
||
|
APInt(width, umax));
|
||
|
}
|
||
|
|
||
|
namespace {
|
||
|
enum class LaunchDims : uint32_t { Block = 0, Grid = 1 };
|
||
|
} // end namespace
|
||
|
|
||
|
/// If the operation `op` is in a context that is annotated with maximum
|
||
|
/// launch dimensions (a launch op with constant block or grid
|
||
|
/// sizes or a launch_func op with the appropriate dimensions), return
|
||
|
/// the bound on the maximum size of the dimension that the op is querying.
|
||
|
/// IDs will be one less than this bound.
|
||
|
|
||
|
static Value valueByDim(KernelDim3 dims, Dimension dim) {
|
||
|
switch (dim) {
|
||
|
case Dimension::x:
|
||
|
return dims.x;
|
||
|
case Dimension::y:
|
||
|
return dims.y;
|
||
|
case Dimension::z:
|
||
|
return dims.z;
|
||
|
}
|
||
|
llvm_unreachable("All dimension enum cases handled above");
|
||
|
}
|
||
|
|
||
|
static uint64_t zext(uint32_t arg) { return static_cast<uint64_t>(arg); }
|
||
|
|
||
|
template <typename Op>
|
||
|
static std::optional<uint64_t> getKnownLaunchDim(Op op, LaunchDims type) {
|
||
|
Dimension dim = op.getDimension();
|
||
|
if (auto launch = op->template getParentOfType<LaunchOp>()) {
|
||
|
KernelDim3 bounds;
|
||
|
switch (type) {
|
||
|
case LaunchDims::Block:
|
||
|
bounds = launch.getBlockSizeOperandValues();
|
||
|
break;
|
||
|
case LaunchDims::Grid:
|
||
|
bounds = launch.getGridSizeOperandValues();
|
||
|
break;
|
||
|
}
|
||
|
Value maybeBound = valueByDim(bounds, dim);
|
||
|
APInt value;
|
||
|
if (matchPattern(maybeBound, m_ConstantInt(&value)))
|
||
|
return value.getZExtValue();
|
||
|
}
|
||
|
|
||
|
if (auto func = op->template getParentOfType<GPUFuncOp>()) {
|
||
|
switch (type) {
|
||
|
case LaunchDims::Block:
|
||
|
return llvm::transformOptional(func.getKnownBlockSize(dim), zext);
|
||
|
case LaunchDims::Grid:
|
||
|
return llvm::transformOptional(func.getKnownGridSize(dim), zext);
|
||
|
}
|
||
|
}
|
||
|
return std::nullopt;
|
||
|
}
|
||
|
|
||
|
void ClusterDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
|
||
|
SetIntRangeFn setResultRange) {
|
||
|
setResultRange(getResult(), getIndexRange(1, kMaxClusterDim));
|
||
|
}
|
||
|
|
||
|
void ClusterIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
|
||
|
SetIntRangeFn setResultRange) {
|
||
|
uint64_t max = kMaxClusterDim;
|
||
|
setResultRange(getResult(), getIndexRange(0, max - 1ULL));
|
||
|
}
|
||
|
|
||
|
void BlockDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
|
||
|
SetIntRangeFn setResultRange) {
|
||
|
std::optional<uint64_t> knownVal =
|
||
|
getKnownLaunchDim(*this, LaunchDims::Block);
|
||
|
if (knownVal)
|
||
|
setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
|
||
|
else
|
||
|
setResultRange(getResult(), getIndexRange(1, kMaxDim));
|
||
|
}
|
||
|
|
||
|
void BlockIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
|
||
|
SetIntRangeFn setResultRange) {
|
||
|
uint64_t max = getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim);
|
||
|
setResultRange(getResult(), getIndexRange(0, max - 1ULL));
|
||
|
}
|
||
|
|
||
|
void GridDimOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
|
||
|
SetIntRangeFn setResultRange) {
|
||
|
std::optional<uint64_t> knownVal = getKnownLaunchDim(*this, LaunchDims::Grid);
|
||
|
if (knownVal)
|
||
|
setResultRange(getResult(), getIndexRange(*knownVal, *knownVal));
|
||
|
else
|
||
|
setResultRange(getResult(), getIndexRange(1, kMaxDim));
|
||
|
}
|
||
|
|
||
|
void ThreadIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
|
||
|
SetIntRangeFn setResultRange) {
|
||
|
uint64_t max = getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
|
||
|
setResultRange(getResult(), getIndexRange(0, max - 1ULL));
|
||
|
}
|
||
|
|
||
|
void LaneIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
|
||
|
SetIntRangeFn setResultRange) {
|
||
|
setResultRange(getResult(), getIndexRange(0, kMaxSubgroupSize - 1ULL));
|
||
|
}
|
||
|
|
||
|
void SubgroupIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
|
||
|
SetIntRangeFn setResultRange) {
|
||
|
setResultRange(getResult(), getIndexRange(0, kMaxDim - 1ULL));
|
||
|
}
|
||
|
|
||
|
void GlobalIdOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
|
||
|
SetIntRangeFn setResultRange) {
|
||
|
uint64_t blockDimMax =
|
||
|
getKnownLaunchDim(*this, LaunchDims::Block).value_or(kMaxDim);
|
||
|
uint64_t gridDimMax =
|
||
|
getKnownLaunchDim(*this, LaunchDims::Grid).value_or(kMaxDim);
|
||
|
setResultRange(getResult(),
|
||
|
getIndexRange(0, (blockDimMax * gridDimMax) - 1ULL));
|
||
|
}
|
||
|
|
||
|
void NumSubgroupsOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
|
||
|
SetIntRangeFn setResultRange) {
|
||
|
setResultRange(getResult(), getIndexRange(1, kMaxDim));
|
||
|
}
|
||
|
|
||
|
void SubgroupSizeOp::inferResultRanges(ArrayRef<ConstantIntRanges>,
|
||
|
SetIntRangeFn setResultRange) {
|
||
|
setResultRange(getResult(), getIndexRange(1, kMaxSubgroupSize));
|
||
|
}
|
||
|
|
||
|
void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||
|
SetIntRangeFn setResultRange) {
|
||
|
auto setRange = [&](const ConstantIntRanges &argRange, Value dimResult,
|
||
|
Value idxResult) {
|
||
|
if (argRange.umin().getBitWidth() != IndexType::kInternalStorageBitWidth)
|
||
|
return;
|
||
|
ConstantIntRanges dimRange =
|
||
|
argRange.intersection(getIndexRange(1, kMaxDim));
|
||
|
setResultRange(dimResult, dimRange);
|
||
|
ConstantIntRanges idxRange =
|
||
|
getIndexRange(0, dimRange.umax().getZExtValue() - 1);
|
||
|
setResultRange(idxResult, idxRange);
|
||
|
};
|
||
|
|
||
|
argRanges = argRanges.drop_front(getAsyncDependencies().size());
|
||
|
KernelDim3 gridDims = getGridSize();
|
||
|
KernelDim3 blockIds = getBlockIds();
|
||
|
setRange(argRanges[0], gridDims.x, blockIds.x);
|
||
|
setRange(argRanges[1], gridDims.y, blockIds.y);
|
||
|
setRange(argRanges[2], gridDims.z, blockIds.z);
|
||
|
KernelDim3 blockDims = getBlockSize();
|
||
|
KernelDim3 threadIds = getThreadIds();
|
||
|
setRange(argRanges[3], blockDims.x, threadIds.x);
|
||
|
setRange(argRanges[4], blockDims.y, threadIds.y);
|
||
|
setRange(argRanges[5], blockDims.z, threadIds.z);
|
||
|
}
|