//===- CreateAsyncGroups.cpp - Create async device copies -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
using namespace mlir;
/// Return "true" if the given vector transfer op is contiguous and suitable
/// for replacement with an async copy.
template <typename OpTy>
static bool isContiguousXferOp(OpTy op) {
  return op.getPermutationMap().isMinorIdentity() && op.isDimInBounds(0) &&
         op.hasPureBufferSemantics() &&
         isLastMemrefDimUnitStride(
             cast<MemRefType>(nvgpu::getMemrefOperand(op).getType()));
}
/// Return "true" if the given op is a contiguous and suitable
/// vector.transfer_write or vector.store op.
static bool isContiguousStore(Operation *write) {
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(write))
    return isContiguousXferOp(transferWrite) && !transferWrite.getMask();
  // vector.store ops are always contiguous.
  return isa<vector::StoreOp>(write);
}
/// Return "true" if the given op is a contiguous and suitable
/// vector.transfer_read or vector.load op.
static bool isContiguousRead(Operation *read) {
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(read))
    return isContiguousXferOp(transferRead);
  // vector.load ops are always contiguous.
  return isa<vector::LoadOp>(read);
}
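
// For orientation, a minimal sketch of a qualifying copy pattern (assumed IR;
// shapes, SSA names and the #gpu.address_space<workgroup> space are
// illustrative only): a contiguous, in-bounds 1-D read from global memory
// whose result is stored contiguously into shared memory.
//
//   %v = vector.transfer_read %global[%i, %j], %pad {in_bounds = [true]}
//       : memref<128x64xf16>, vector<8xf16>
//   vector.transfer_write %v, %shared[%k, %l] {in_bounds = [true]}
//       : vector<8xf16>, memref<64x72xf16, #gpu.address_space<workgroup>>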
namespace {
/// A vector.create_mask op and extract position.
struct TransferMask {
  vector::CreateMaskOp createMaskOp;
  SmallVector<int64_t> extractPosition;
};
} // namespace
/// If the mask of the given vector load op is defined by a vector.create_mask
/// op (either directly or through a vector.extract of one), return that op and
/// the extract position. Return an empty TransferMask if there is no mask, and
/// failure for unsupported mask patterns.
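///
/// As an illustrative sketch (assumed SSA names and shapes), the two
/// recognized mask patterns look like:
///   %m = vector.create_mask %sz : vector<4xi1>
/// and
///   %m0 = vector.create_mask %sz0, %sz1 : vector<2x4xi1>
///   %m  = vector.extract %m0[1] : vector<4xi1> from vector<2x4xi1>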
static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
  auto transferRead = dyn_cast<vector::TransferReadOp>(loadOp);
  if (!transferRead || !transferRead.getMask())
    return TransferMask{{}, {}};
  assert(transferRead.getMask().getType().getRank() == 1 &&
         "expected 1-D mask");

  // Case 1: Mask is the result of a vector.create_mask.
  if (auto maskOp =
          transferRead.getMask().getDefiningOp<vector::CreateMaskOp>())
    return TransferMask{maskOp, {}};

  // Case 2: Mask is the result of a vector.extract(vector.create_mask).
  if (auto extractOp =
          transferRead.getMask().getDefiningOp<vector::ExtractOp>())
    if (auto maskOp =
            extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
      return TransferMask{maskOp,
                          SmallVector<int64_t>(extractOp.getStaticPosition())};

  // All other cases: not supported.
  return failure();
}
/// Build an SSA value that represents the number of read elements.
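///
/// A minimal sketch of the extract case (assumed SSA names, for illustration
/// only): for a 2-D mask %m0 = vector.create_mask %sz0, %sz1 extracted at
/// position [1], the returned value is materialized roughly as:
///   %c1   = arith.constant 1 : index
///   %cond = arith.cmpi slt, %c1, %sz0 : index
///   %c0   = arith.constant 0 : index
///   %n    = arith.select %cond, %sz1, %c0 : index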
static Value buildNumReadElements(OpBuilder &b, Location loc,
                                  Operation *readOp) {
  FailureOr<TransferMask> transferMask = getMaskOp(readOp);
  assert(succeeded(transferMask) && "invalid transfer mask");

  // No mask => no num_read_elements.
  if (!transferMask->createMaskOp)
    return Value();

  // No extract: return size of "ones" segment in the mask.
  if (transferMask->extractPosition.empty()) {
    assert(transferMask->createMaskOp.getNumOperands() == 1 &&
           "expected single operand");
    return transferMask->createMaskOp.getOperand(0);
  }

  // vector.extract(vector.create_mask).
  // If extract_pos < num_ones, take number of elements from the least
  // significant dimension. (Do this for all dimensions and bit-AND the
  // conditions.)
  assert(transferMask->createMaskOp.getVectorType().getRank() -
                 transferMask->extractPosition.size() ==
             1 &&
         "expected N-D -> (N-1)-D extract");
  Value cond;
  // Note: There is one more `sz` than `pos`. The loop ends with the last
  // `pos`.
  for (auto [pos, sz] : llvm::zip(transferMask->extractPosition,
                                  transferMask->createMaskOp->getOperands())) {
    Value cmp =
        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                b.create<arith::ConstantIndexOp>(loc, pos), sz);
    if (!cond) {
      cond = cmp;
      continue;
    }
    cond = b.create<arith::AndIOp>(loc, cmp, cond);
  }
  return b.create<arith::SelectOp>(
      loc, cond, transferMask->createMaskOp->getOperands().back(),
      b.create<arith::ConstantIndexOp>(loc, 0));
}
/// Return "true" if the conversion to async copy is supported by "async copy".
static bool resultsInSupportedAsyncCopy(MemRefType memrefType,
                                        VectorType vecType) {
  assert(vecType.getRank() == 1 && "expected 1-D vector");
  constexpr int64_t kSupportedCpAsyncAlignmentsInBytes[3] = {4, 8, 16};

  // Condition 1: the copy size must be supported.
  bool supportedCopySize = false;
  int64_t numElements = vecType.getNumElements();
  Type elementType = vecType.getElementType();
  for (int64_t alignmentInBytes : kSupportedCpAsyncAlignmentsInBytes) {
    if (alignmentInBytes * 8 ==
        numElements * elementType.getIntOrFloatBitWidth()) {
      supportedCopySize = true;
      break;
    }
  }
  if (!supportedCopySize)
    return false;

  // TODO: Condition 2: the alignments must be supported. For cp.async the
  // NVIDIA doc (section 6.4.1) says: "The address must be naturally aligned to
  // a multiple of the access size. If an address is not properly aligned, the
  // resulting behavior is undefined.".
  return true;
}
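
// For a qualifying copy such as the transfer_read/transfer_write pair sketched
// above (a 16-byte vector<8xf16> transfer), the rewrite below emits roughly
// (illustrative IR; the exact printed form of the nvgpu ops may differ):
//
//   %tok = nvgpu.device_async_copy %global[%i, %j], %shared[%k, %l], 8
//       : memref<128x64xf16> to memref<64x72xf16, #gpu.address_space<workgroup>>
//   %grp = nvgpu.device_async_create_group %tok
//   nvgpu.device_async_wait %grp
//
// Adjacent qualifying copies are batched into the same group so a single wait
// covers all of them.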
void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op,
                              bool bypassL1) {
  llvm::SmallSetVector<Operation *, 16> copyToSharedMem;

  // Look for all the copies that can be converted to async copy ops.
  op->walk([&](Operation *writeOp) {
    // Look for contiguous 1D vector store into shared memory.
    if (!isContiguousStore(writeOp))
      return;
    Value vectorVal = nvgpu::getValueStored(writeOp);
    if (cast<VectorType>(vectorVal.getType()).getRank() != 1)
      return;
    Value storeBase = nvgpu::getMemrefOperand(writeOp);
    if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(storeBase.getType())))
      return;

    // The stored vector must originate from a contiguous 1D vector load.
    Operation *readOp = vectorVal.getDefiningOp();
    if (readOp == nullptr || !isContiguousRead(readOp))
      return;
    Value loadBase = nvgpu::getMemrefOperand(readOp);
    // Should be reading from global memory (not shared memory).
    if (nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(loadBase.getType())))
      return;

    // Look for compatible mask and padding.
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(readOp)) {
      if (Value mask = transferRead.getMask()) {
        // Only zero padding is compatible: the async copy zero-fills the
        // elements that are not read.
        if (getConstantIntValue(transferRead.getPadding()) !=
            static_cast<int64_t>(0))
          return;
        if (failed(getMaskOp(readOp)))
          return;
      }
    }

    // Check whether both accesses are supported before we emit: this is
    // necessary to ensure the correctness of DeviceAsyncCopyOp.
    VectorType vecType = cast<VectorType>(vectorVal.getType());
    if (!resultsInSupportedAsyncCopy(cast<MemRefType>(loadBase.getType()),
                                     vecType) ||
        !resultsInSupportedAsyncCopy(cast<MemRefType>(storeBase.getType()),
                                     vecType))
      return;

    copyToSharedMem.insert(writeOp);
    return;
  });
  while (!copyToSharedMem.empty()) {
    // Start a group with the first write.
    SmallVector<Operation *> group;
    Operation *writeOp = *copyToSharedMem.begin();
    copyToSharedMem.remove(writeOp);
    group.push_back(writeOp);
    Operation *nextNode = writeOp;

    // Look in the next nodes for more copies to add to the same group.
    while ((nextNode = nextNode->getNextNode())) {
      // Ignore ops without side effects.
      auto memInterface = dyn_cast<MemoryEffectOpInterface>(nextNode);
      if (memInterface && memInterface.hasNoEffect() &&
          !nextNode->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
        continue;
      // Ignore reads from address spaces other than shared memory.
      if (isa<vector::TransferReadOp, vector::LoadOp>(nextNode)) {
        Operation *readOp = nextNode;
        Value memrefOperand = nvgpu::getMemrefOperand(readOp);
        if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
                cast<MemRefType>(memrefOperand.getType()))) {
          continue;
        }
      }
      if (copyToSharedMem.count(nextNode)) {
        // Found another copy, add it to the group.
        copyToSharedMem.remove(nextNode);
        group.push_back(nextNode);
        continue;
      }
      // If the op is anything else, stop accumulating ops into the group.
      break;
    }
    // Emit the group.
    SmallVector<Value> tokens;
    for (Operation *writeOp : group) {
      rewriter.setInsertionPoint(writeOp);
      Value vectorVal = nvgpu::getValueStored(writeOp);
      auto vectorType = cast<VectorType>(vectorVal.getType());
      int64_t numElements = vectorType.getNumElements();
      Operation *readOp = vectorVal.getDefiningOp();
      Value storeBase = nvgpu::getMemrefOperand(writeOp);
      Value loadBase = nvgpu::getMemrefOperand(readOp);
      Value numReadElements =
          buildNumReadElements(rewriter, writeOp->getLoc(), readOp);
      auto dstMemref = cast<MemRefType>(storeBase.getType());
      int64_t sizeInBytes =
          (dstMemref.getElementTypeBitWidth() * numElements) / 8;
      // bypass_l1 is only possible with 16-byte transfers.
      Value token = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
          writeOp->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()),
          /*dst=*/storeBase, /*dstIndices=*/nvgpu::getIndices(writeOp),
          /*src=*/loadBase,
          /*srcIndices=*/nvgpu::getIndices(readOp),
          /*dstElements=*/rewriter.getIndexAttr(numElements),
          /*srcElements=*/numReadElements,
          /*bypassL1=*/bypassL1 && sizeInBytes == 16 ? rewriter.getUnitAttr()
                                                     : UnitAttr());
      tokens.push_back(token);
    }
    // Create the group and wait for it right after.
    Value groupToken = rewriter.create<nvgpu::DeviceAsyncCreateGroupOp>(
        op->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()),
        tokens);
    rewriter.create<nvgpu::DeviceAsyncWaitOp>(op->getLoc(), groupToken,
                                              nullptr);
    // Clean up old stores.
    for (Operation *writeOp : group)
      rewriter.eraseOp(writeOp);
  }
}