//===- CreateAsyncGroups.cpp - Create async device copies -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

/// Return "true" if the given vector transfer op is contiguous and suitable
/// for replacement with an async copy.
template <typename OpTy>
static bool isContiguousXferOp(OpTy op) {
  return op.getPermutationMap().isMinorIdentity() && op.isDimInBounds(0) &&
         op.hasPureBufferSemantics() &&
         isLastMemrefDimUnitStride(
             cast<MemRefType>(nvgpu::getMemrefOperand(op).getType()));
}

/// Return "true" if the given op is a contiguous and suitable
/// vector.transfer_write or vector.store op.
static bool isContiguousStore(Operation *write) {
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(write))
    return isContiguousXferOp(transferWrite) && !transferWrite.getMask();
  // vector.store ops are always contiguous.
  return isa<vector::StoreOp>(write);
}

/// Return "true" if the given op is a contiguous and suitable
/// vector.transfer_read or vector.load op.
static bool isContiguousRead(Operation *read) {
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(read))
    return isContiguousXferOp(transferRead);
  // vector.load ops are always contiguous.
  return isa<vector::LoadOp>(read);
}

namespace {
/// A vector.create_mask op and extract position.
struct TransferMask {
  vector::CreateMaskOp createMaskOp;
  SmallVector<int64_t> extractPosition;
};
} // namespace

/// If the given vector load op has a mask that is defined by
/// vector.create_mask, return that op.
static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
  auto transferRead = dyn_cast<vector::TransferReadOp>(loadOp);
  if (!transferRead || !transferRead.getMask())
    return TransferMask{{}, {}};
  assert(transferRead.getMask().getType().getRank() == 1 &&
         "expected 1-D mask");

  // Case 1: Mask is the result of a vector.create_mask.
  if (auto maskOp =
          transferRead.getMask().getDefiningOp<vector::CreateMaskOp>())
    return TransferMask{maskOp, {}};

  // Case 2: Mask is the result of a vector.extract(vector.create_mask).
  if (auto extractOp =
          transferRead.getMask().getDefiningOp<vector::ExtractOp>())
    if (auto maskOp =
            extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
      return TransferMask{maskOp,
                          SmallVector<int64_t>(extractOp.getStaticPosition())};

  // All other cases: not supported.
  return failure();
}

/// Build an SSA value that represents the number of read elements.
static Value buildNumReadElements(OpBuilder &b, Location loc,
                                  Operation *readOp) {
  FailureOr<TransferMask> transferMask = getMaskOp(readOp);
  assert(succeeded(transferMask) && "invalid transfer mask");

  // No mask => no num_read_elements.
  if (!transferMask->createMaskOp)
    return Value();

  // No extract: return size of "ones" segment in the mask.
  if (transferMask->extractPosition.empty()) {
    assert(transferMask->createMaskOp.getNumOperands() == 1 &&
           "expected single operand");
    return transferMask->createMaskOp.getOperand(0);
  }

  // vector.extract(vector.create_mask).
  // If extract_pos < num_ones, take number of elements from the least
  // significant dimension. (Do this for all dimensions and bit-AND the
  // conditions.)
  assert(transferMask->createMaskOp.getVectorType().getRank() -
                 transferMask->extractPosition.size() ==
             1 &&
         "expected N-D -> (N-1)-D extract");
  Value cond;
  // Note: There is one more `sz` than `pos`. The loop ends with the last `pos`.
  for (auto [pos, sz] : llvm::zip(transferMask->extractPosition,
                                  transferMask->createMaskOp->getOperands())) {
    Value cmp =
        b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                b.create<arith::ConstantIndexOp>(loc, pos), sz);
    if (!cond) {
      cond = cmp;
      continue;
    }
    cond = b.create<arith::AndIOp>(loc, cmp, cond);
  }
  return b.create<arith::SelectOp>(
      loc, cond, transferMask->createMaskOp->getOperands().back(),
      b.create<arith::ConstantIndexOp>(loc, 0));
}

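// Illustration (not verbatim IR; SSA names are made up for this sketch): for a
// read whose mask is `vector.extract %m[1]` with
// `%m = vector.create_mask %sz0, %sz1 : vector<2x4xi1>`, the helper above
// emits roughly:
//
//   %pos  = arith.constant 1 : index
//   %cond = arith.cmpi slt, %pos, %sz0 : index
//   %zero = arith.constant 0 : index
//   %n    = arith.select %cond, %sz1, %zero : index
//
// i.e., the innermost mask size is used when the extracted row lies inside the
// masked region, and 0 otherwise.
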
/// Return "true" if a copy of `vecType` elements from or to a buffer of type
/// `memrefType` can be lowered to a supported async copy (cp.async).
static bool resultsInSupportedAsyncCopy(MemRefType memrefType,
                                        VectorType vecType) {
  assert(vecType.getRank() == 1 && "expected 1-D vector");
  constexpr int64_t kSupportedCpAsyncAlignmentsInBytes[3] = {4, 8, 16};

  // Condition 1: the copy size must be supported.
  bool supportedCopySize = false;
  int64_t numElements = vecType.getNumElements();
  Type elementType = vecType.getElementType();
  for (int64_t alignmentInBytes : kSupportedCpAsyncAlignmentsInBytes) {
    if (alignmentInBytes * 8 ==
        numElements * elementType.getIntOrFloatBitWidth()) {
      supportedCopySize = true;
      break;
    }
  }
  if (!supportedCopySize)
    return false;

  // TODO: Condition 2: the alignments must be supported. For cp.async the
  // NVIDIA doc (section 6.4.1) says: "The address must be naturally aligned to
  // a multiple of the access size. If an address is not properly aligned, the
  // resulting behavior is undefined.".
  return true;
}

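// Note: with standard element bit widths, the size check in
// resultsInSupportedAsyncCopy accepts e.g. vector<1xf32>/vector<2xf16>
// (4 bytes), vector<2xf32>/vector<4xf16> (8 bytes) and
// vector<4xf32>/vector<8xf16> (16 bytes); other sizes are rejected.
//
// Rough sketch (not verbatim IR; types, indices and attribute syntax are
// abbreviated) of the rewrite performed by createAsyncGroups below. A pair
// such as
//
//   %v = vector.transfer_read %global[%i], %pad {in_bounds = [true]}
//       : memref<1024xf32>, vector<4xf32>
//   vector.transfer_write %v, %shared[%j] {in_bounds = [true]}
//       : vector<4xf32>, memref<128xf32, #gpu.address_space<workgroup>>
//
// becomes approximately
//
//   %t = nvgpu.device_async_copy %global[%i], %shared[%j], 4
//       : memref<1024xf32> to memref<128xf32, #gpu.address_space<workgroup>>
//   %g = nvgpu.device_async_create_group %t
//   nvgpu.device_async_wait %g
//
// with adjacent qualifying copies collected into the same group.
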
void nvgpu::createAsyncGroups(RewriterBase &rewriter, Operation *op,
                              bool bypassL1) {
  llvm::SmallSetVector<Operation *, 16> copyToSharedMem;

  // Look for all the copies that can be converted to async copy ops.
  op->walk([&](Operation *writeOp) {
    // Look for a contiguous 1D vector store into shared memory.
    if (!isContiguousStore(writeOp))
      return;
    Value vectorVal = nvgpu::getValueStored(writeOp);
    if (cast<VectorType>(vectorVal.getType()).getRank() != 1)
      return;
    Value storeBase = nvgpu::getMemrefOperand(writeOp);
    if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(storeBase.getType())))
      return;

    // The stored vector must originate from a contiguous 1D vector load.
    Operation *readOp = vectorVal.getDefiningOp();
    if (readOp == nullptr || !isContiguousRead(readOp))
      return;
    Value loadBase = nvgpu::getMemrefOperand(readOp);
    // Should be reading from global memory (not shared memory).
    if (nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
            cast<MemRefType>(loadBase.getType())))
      return;

    // Look for a compatible mask and padding: the padding must be a constant
    // zero, because the async copy fills the non-read elements with zeros.
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(readOp)) {
      if (Value mask = transferRead.getMask()) {
        if (getConstantIntValue(transferRead.getPadding()) !=
            static_cast<int64_t>(0))
          return;
        if (failed(getMaskOp(readOp)))
          return;
      }
    }

    // Check whether both accesses are supported before we emit: this is
    // necessary to ensure the correctness of DeviceAsyncCopyOp.
    VectorType vecType = cast<VectorType>(vectorVal.getType());
    if (!resultsInSupportedAsyncCopy(cast<MemRefType>(loadBase.getType()),
                                     vecType) ||
        !resultsInSupportedAsyncCopy(cast<MemRefType>(storeBase.getType()),
                                     vecType))
      return;

    copyToSharedMem.insert(writeOp);
    return;
  });

  while (!copyToSharedMem.empty()) {
    // Start a group with the first write.
    SmallVector<Operation *> group;
    Operation *writeOp = *copyToSharedMem.begin();
    copyToSharedMem.remove(writeOp);
    group.push_back(writeOp);
    Operation *nextNode = writeOp;

    // Look at the subsequent ops for more copies to add to the same group.
    while ((nextNode = nextNode->getNextNode())) {
      // Ignore ops without side effects.
      auto memInterface = dyn_cast<MemoryEffectOpInterface>(nextNode);
      if (memInterface && memInterface.hasNoEffect() &&
          !nextNode->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
        continue;
      // Ignore reads from a different address space (not shared memory).
      if (isa<vector::TransferReadOp, vector::LoadOp>(nextNode)) {
        Operation *readOp = nextNode;
        Value memrefOperand = nvgpu::getMemrefOperand(readOp);
        if (!nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(
                cast<MemRefType>(memrefOperand.getType()))) {
          continue;
        }
      }
      if (copyToSharedMem.count(nextNode)) {
        // Found another copy, add it to the group.
        copyToSharedMem.remove(nextNode);
        group.push_back(nextNode);
        continue;
      }
      // Any other op stops the accumulation of copies into the group.
      break;
    }

    // Emit the group.
    SmallVector<Value> tokens;
    for (Operation *writeOp : group) {
      rewriter.setInsertionPoint(writeOp);
      Value vectorVal = nvgpu::getValueStored(writeOp);
      auto vectorType = cast<VectorType>(vectorVal.getType());
      int64_t numElements = vectorType.getNumElements();
      Operation *readOp = vectorVal.getDefiningOp();
      Value storeBase = nvgpu::getMemrefOperand(writeOp);
      Value loadBase = nvgpu::getMemrefOperand(readOp);
      Value numReadElements =
          buildNumReadElements(rewriter, writeOp->getLoc(), readOp);
      auto dstMemref = cast<MemRefType>(storeBase.getType());
      int64_t sizeInBytes =
          (dstMemref.getElementTypeBitWidth() * numElements) / 8;
      // bypass_l1 is only possible with a 16-byte transfer.
      Value token = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
          writeOp->getLoc(),
          nvgpu::DeviceAsyncTokenType::get(op->getContext()),
          /*dst=*/storeBase, /*dstIndices=*/nvgpu::getIndices(writeOp),
          /*src=*/loadBase, /*srcIndices=*/nvgpu::getIndices(readOp),
          /*dstElements=*/rewriter.getIndexAttr(numElements),
          /*srcElements=*/numReadElements,
          /*bypassL1=*/bypassL1 && sizeInBytes == 16 ? rewriter.getUnitAttr()
                                                     : UnitAttr());
      tokens.push_back(token);
    }

    // Create the group and wait for it right after.
    Value groupToken = rewriter.create<nvgpu::DeviceAsyncCreateGroupOp>(
        op->getLoc(), nvgpu::DeviceAsyncTokenType::get(op->getContext()),
        tokens);
    rewriter.create<nvgpu::DeviceAsyncWaitOp>(op->getLoc(), groupToken,
                                              /*numGroups=*/nullptr);

    // Clean up the old stores.
    for (Operation *writeOp : group)
      rewriter.eraseOp(writeOp);
  }
}
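
// Minimal usage sketch (illustrative only; `runOnFunc` below is a hypothetical
// driver, not part of this file): the transform is typically invoked from a
// pass or from the transform dialect with a rewriter rooted at the target op:
//
//   void runOnFunc(func::FuncOp func, bool bypassL1) {
//     IRRewriter rewriter(func.getContext());
//     nvgpu::createAsyncGroups(rewriter, func.getOperation(), bypassL1);
//   }
//
// `IRRewriter` derives from `RewriterBase`, which is the type that
// createAsyncGroups expects.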