//===- CreateAsyncGroups.cpp - Create async device copies -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

/// Return "true" if the given vector transfer op is contiguous and suitable
/// for replacement with an async copy.
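///
/// Illustrative example (assumed IR, not taken from an actual test): a read
/// such as
///   %v = vector.transfer_read %src[%i, %j], %pad {in_bounds = [true]}
///       : memref<128x64xf32>, vector<4xf32>
/// has a minor-identity permutation map, is in-bounds in its vector dimension,
/// operates on a memref (buffer semantics), and the memref's innermost
/// dimension has unit stride, so it is treated as contiguous here.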
template <typename OpTy>
static bool isContiguousXferOp(OpTy op) {
  return op.getPermutationMap().isMinorIdentity() && op.isDimInBounds(0) &&
         op.hasPureBufferSemantics() &&
         isLastMemrefDimUnitStride(
             cast<MemRefType>(nvgpu::getMemrefOperand(op).getType()));
}

/// Return "true" if the given op is a contiguous and suitable
/// vector.transfer_write or vector.store op.
static bool isContiguousStore(Operation *write) {
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(write))
    return isContiguousXferOp(transferWrite) && !transferWrite.getMask();
  // vector.store ops are always contiguous.
  return isa<vector::StoreOp>(write);
}

/// Return "true" if the given op is a contiguous and suitable
|
||
|
/// vector.transfer_read or vector.load op.
|
||
|
static bool isContiguousRead(Operation *read) {
|
||
|
if (auto transferRead = dyn_cast<vector::TransferReadOp>(read))
|
||
|
return isContiguousXferOp(transferRead);
|
||
|
// vector.load are always contiguous.
|
||
|
return isa<vector::LoadOp>(read);
|
||
|
}
|
||
|
|
||
|
namespace {
/// A vector.create_mask op and extract position.
struct TransferMask {
  vector::CreateMaskOp createMaskOp;
  SmallVector<int64_t> extractPosition;
};
} // namespace

/// If the given vector load op has a mask that is defined by
/// vector.create_mask (directly or through a vector.extract of one), return
/// that op. Return an empty TransferMask if the op is unmasked, and failure
/// if the mask is not supported.
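///
/// Two mask patterns are recognized (sketched, assumed IR):
///   Case 1: %mask = vector.create_mask %sz : vector<4xi1>
///   Case 2: %m    = vector.create_mask %sz0, %sz1 : vector<2x4xi1>
///           %mask = vector.extract %m[0] : vector<4xi1> from vector<2x4xi1>
/// Any other mask producer is rejected with failure.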
static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
  auto transferRead = dyn_cast<vector::TransferReadOp>(loadOp);
  if (!transferRead || !transferRead.getMask())
    return TransferMask{{}, {}};
  assert(transferRead.getMask().getType().getRank() == 1 &&
         "expected 1-D mask");

  // Case 1: Mask is the result of a vector.create_mask.
  if (auto maskOp =
          transferRead.getMask().getDefiningOp<vector::CreateMaskOp>())
    return TransferMask{maskOp, {}};

  // Case 2: Mask is the result of a vector.extract(vector.create_mask).
  if (auto extractOp =
          transferRead.getMask().getDefiningOp<vector::ExtractOp>())
    if (auto maskOp =
            extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
      return TransferMask{maskOp,
                          SmallVector<int64_t>(extractOp.getStaticPosition())};

  // All other cases: not supported.
  return failure();
}

/// Build an SSA value that represents the number of read elements, or return
/// a null Value if the read op is unmasked (i.e., all elements are read).
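///
/// For the vector.extract(vector.create_mask) case, the generated IR is
/// roughly (a sketch, assuming an extract at static position 1 from a 2-D
/// mask created with sizes %sz0 and %sz1):
///   %pos      = arith.constant 1 : index
///   %inbounds = arith.cmpi slt, %pos, %sz0 : index
///   %zero     = arith.constant 0 : index
///   %num      = arith.select %inbounds, %sz1, %zero : index
/// i.e. the innermost mask size when the extracted row lies inside the masked
/// region, and 0 otherwise.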
static Value buildNumReadElements(OpBuilder &b, Location loc,
                                  Operation *readOp) {
  FailureOr<TransferMask> transferMask = getMaskOp(readOp);
  assert(succeeded(transferMask) && "invalid transfer mask");

  // No mask => no num_read_elements.
  if (!transferMask->createMaskOp)
    return Value();

  // No extract: return size of "ones" segment in the mask.
  if (transferMask->extractPosition.empty()) {
    assert(transferMask->createMaskOp.getNumOperands() == 1 &&
           "expected single operand");
    return transferMask->createMaskOp.getOperand(0);
  }

  // vector.extract(vector.create_mask).
  // If extract_pos < num_ones, take number of elements from the least
  // significant dimension. (Do this for all dimensions and bit-AND the
  // conditions.)
  assert(transferMask->createMaskOp.getVectorType().getRank() -
                 transferMask->extractPosition.size() ==
             1 &&
         "expected N-D -> (N-1)-D extract");
  Value cond;
  // Note: There is one more `sz` than `pos`. The loop ends with the last `pos`.
  for (auto [pos, sz] : llvm::zip(transferMask->extractPosition,
                                  transferMask->createMaskOp->getOperands())) {
    Value cmp =
        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                b.create<arith::ConstantIndexOp>(loc, pos), sz);
    if (!cond) {
      cond = cmp;
      continue;
    }
    cond = b.create<arith::AndIOp>(loc, cmp, cond);
  }
  return b.create<arith::SelectOp>(
      loc, cond, transferMask->createMaskOp->getOperands().back(),
      b.create<arith::ConstantIndexOp>(loc, 0));
}

/// Return "true" if the conversion to async copy is supported by "async copy".
|
||
|
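///
/// cp.async only supports 4-, 8- and 16-byte transfers, so e.g. a
/// vector<4xf32> (16 bytes) or vector<2xf16> (4 bytes) copy qualifies, while
/// a vector<3xf32> (12 bytes) copy does not.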
static bool resultsInSupportedAsyncCopy(MemRefType memrefType,
                                        VectorType vecType) {
  assert(vecType.getRank() == 1 && "expected 1-D vector");
  constexpr int64_t kSupportedCpAsyncAlignmentsInBytes[3] = {4, 8, 16};

  // Condition 1: the copy size must be supported.
  bool supportedCopySize = false;
  int64_t numElements = vecType.getNumElements();
  Type elementType = vecType.getElementType();
  for (int64_t alignmentInBytes : kSupportedCpAsyncAlignmentsInBytes) {
    if (alignmentInBytes * 8 ==
        numElements * elementType.getIntOrFloatBitWidth()) {
      supportedCopySize = true;
      break;
    }
  }
  if (!supportedCopySize)
    return false;

  // TODO: Condition 2: the alignments must be supported. For cp.async the
  // NVIDIA doc (section 6.4.1) says: "The address must be naturally aligned to
  // a multiple of the access size. If an address is not properly aligned, the
  // resulting behavior is undefined.".
  return true;
}

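// Rough sketch of the rewrite (assumed IR; indices and types illustrative):
//
//   %v = vector.transfer_read %global[%i], %pad {in_bounds = [true]}
//       : memref<?xf32>, vector<4xf32>
//   vector.transfer_write %v, %shared[%j] {in_bounds = [true]}
//       : vector<4xf32>, memref<128xf32, 3>
//
// becomes, roughly:
//
//   %t = nvgpu.device_async_copy %global[%i], %shared[%j], 4
//       : memref<?xf32> to memref<128xf32, 3>
//   %g = nvgpu.device_async_create_group %t
//   nvgpu.device_async_wait %g
//
// with adjacent eligible copies batched into the same group.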
void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op,
                              bool bypassL1) {
  llvm::SmallSetVector<Operation *, 16> copyToSharedMem;

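  // The rewrite proceeds in two phases: first collect all eligible
  // global -> shared vector copies, then group adjacent copies and replace
  // each group with nvgpu.device_async_copy ops followed by a single
  // group-create / group-wait pair.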
  // Look for all the copies that can be converted to async copy ops.
  op->walk([&](Operation *writeOp) {
    // Look for contiguous 1D vector store into shared memory.
    if (!isContiguousStore(writeOp))
      return;
    Value vectorVal = nvgpu::getValueStored(writeOp);
    if (cast<VectorType>(vectorVal.getType()).getRank() != 1)
      return;
    Value storeBase = nvgpu::getMemrefOperand(writeOp);
    if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(storeBase.getType())))
      return;

    // The stored vector must originate from a contiguous 1D vector load.
    Operation *readOp = vectorVal.getDefiningOp();
    if (readOp == nullptr || !isContiguousRead(readOp))
      return;
    Value loadBase = nvgpu::getMemrefOperand(readOp);
    // Should be reading from global memory (not shared memory).
    if (nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(loadBase.getType())))
      return;

    // Look for compatible mask and padding.
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(readOp)) {
      if (Value mask = transferRead.getMask()) {
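        // Note: the part of the destination not covered by srcElements is
        // zero-filled by the hardware, so the rewrite is only valid when the
        // transfer's padding value is the constant 0.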
        if (getConstantIntValue(transferRead.getPadding()) !=
            static_cast<int64_t>(0))
          return;
        if (failed(getMaskOp(readOp)))
          return;
      }
    }

    // Check whether both accesses are supported before we emit: this is
    // necessary to ensure the correctness of DeviceAsyncCopyOp.
    VectorType vecType = cast<VectorType>(vectorVal.getType());

    if (!resultsInSupportedAsyncCopy(cast<MemRefType>(loadBase.getType()),
                                     vecType) ||
        !resultsInSupportedAsyncCopy(cast<MemRefType>(storeBase.getType()),
                                     vecType))
      return;

    copyToSharedMem.insert(writeOp);
    return;
  });

  while (!copyToSharedMem.empty()) {
    // Start a group with the first write.
    SmallVector<Operation *> group;
    Operation *writeOp = *copyToSharedMem.begin();
    copyToSharedMem.remove(writeOp);
    group.push_back(writeOp);
    Operation *nextNode = writeOp;

    // Look in the next nodes for more copies to add to the same group.
    while ((nextNode = nextNode->getNextNode())) {
      // Ignore ops without side effects.
      auto memInterface = dyn_cast<MemoryEffectOpInterface>(nextNode);
      if (memInterface && memInterface.hasNoEffect() &&
          !nextNode->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
        continue;
      // Ignore reads from a different (non-shared-memory) address space.
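      // Reads from shared memory, in contrast, fall through to the `break`
      // below and end the group, since they may consume data produced by the
      // still-pending async copies.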
      if (isa<vector::TransferReadOp, vector::LoadOp>(nextNode)) {
        Operation *readOp = nextNode;
        Value memrefOperand = nvgpu::getMemrefOperand(readOp);
        if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
                cast<MemRefType>(memrefOperand.getType()))) {
          continue;
        }
      }
      if (copyToSharedMem.count(nextNode)) {
        // Found another copy, add it to the group.
        copyToSharedMem.remove(nextNode);
        group.push_back(nextNode);
        continue;
      }
      // If the op is something else, stop accumulating ops into the group.
      break;
    }

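    // Each copy in the group becomes one nvgpu.device_async_copy; the group
    // then shares a single device_async_create_group followed immediately by
    // a device_async_wait.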
    // Emit the group.
    SmallVector<Value> tokens;
    for (Operation *writeOp : group) {
      rewriter.setInsertionPoint(writeOp);
      Value vectorVal = nvgpu::getValueStored(writeOp);
      auto vectorType = cast<VectorType>(vectorVal.getType());
      int64_t numElements = vectorType.getNumElements();
      Operation *readOp = vectorVal.getDefiningOp();
      Value storeBase = nvgpu::getMemrefOperand(writeOp);
      Value loadBase = nvgpu::getMemrefOperand(readOp);
      Value numReadElements =
          buildNumReadElements(rewriter, writeOp->getLoc(), readOp);
      auto dstMemref = cast<MemRefType>(storeBase.getType());
      int64_t sizeInBytes =
          (dstMemref.getElementTypeBitWidth() * numElements) / 8;
      // bypass_l1 only possible with 16 byte transfer.
      Value token = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
          writeOp->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()),
          /*dst=*/storeBase, /*dstIndices=*/nvgpu::getIndices(writeOp),
          /*src=*/loadBase,
          /*srcIndices=*/nvgpu::getIndices(readOp),
          /*dstElements=*/rewriter.getIndexAttr(numElements),
          /*srcElements=*/numReadElements,
          /*bypassL1=*/bypassL1 && sizeInBytes == 16 ? rewriter.getUnitAttr()
                                                     : UnitAttr());
      tokens.push_back(token);
    }

    // Create the group and wait for it right after.
    Value groupToken = rewriter.create<nvgpu::DeviceAsyncCreateGroupOp>(
        op->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()),
        tokens);
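    // Passing a null numGroups attribute to the wait is assumed to behave
    // like numGroups = 0, i.e. the wait blocks until all outstanding copy
    // groups (including the one just created) have completed.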
    rewriter.create<nvgpu::DeviceAsyncWaitOp>(op->getLoc(), groupToken,
                                              nullptr);
    // Clean up old stores.
    for (Operation *writeOp : group)
      rewriter.eraseOp(writeOp);
  }
}