//===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to GPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToGPU/VectorToGPU.h"

#include <type_traits>

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Region.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"

#define DEBUG_TYPE "vector-to-gpu"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

/// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
/// AffineMap representing offsets to apply to indices, the function fills
/// `indices` with the original indices plus the offsets. The offsets are
/// applied by taking into account the permutation map of the transfer op. If
/// the `offsetMap` has dimension placeholders, those should be provided in
/// `dimValues`.
template <typename TransferOpType>
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
                           AffineMap offsetMap, ArrayRef<Value> dimValues,
                           SmallVector<Value, 4> &indices) {
  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
  Location loc = xferOp.getLoc();
  unsigned offsetsIdx = 0;
  for (auto expr : xferOp.getPermutationMap().getResults()) {
    if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
      Value prevIdx = indices[dim.getPosition()];
      SmallVector<Value, 3> dims(dimValues.begin(), dimValues.end());
      dims.push_back(prevIdx);
      AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
      indices[dim.getPosition()] = affine::makeComposedAffineApply(
          rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
      continue;
    }
  }
}

// Return true if the contract op can be converted to MMA matmul.
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
                                          bool useNvGpu) {
  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
  AffineExpr m, n, k;
  bindDims(contract.getContext(), m, n, k);
  auto iteratorTypes = contract.getIteratorTypes().getValue();
  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
        vector::isParallelIterator(iteratorTypes[1]) &&
        vector::isReductionIterator(iteratorTypes[2])))
    return false;

  // The contract needs to represent a matmul to be able to convert to
  // MMAMatrix matmul.
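  // Illustrative shapes of the accepted indexing maps (derived from the checks
  // below): the wmma path expects the canonical row-major matmul
  //   (m, k) x (k, n) -> (m, n),
  // while the mma.sync path expects the B operand in transposed (n, k) form,
  //   (m, k) x (n, k) -> (m, n).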
  if (!useNvGpu &&
      contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
    return false;
  if (useNvGpu &&
      contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
    return false;

  return true;
}

// Return true if the given map represents a transposed matrix load,
// i.e. (d0, d1, ...) -> (dn-1, dn-2).
static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
  MLIRContext *ctx = permutationMap.getContext();
  // Local OpBuilder is fine here, we just build attributes.
  OpBuilder b(ctx);
  auto nDim = permutationMap.getNumDims();
  AffineExpr zero = b.getAffineConstantExpr(0);
  if (nDim < 2) {
    // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
    AffineExpr dim0 = b.getAffineDimExpr(0);
    return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
  }

  AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
  AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
  // Support both transposed and transposed+broadcasted cases.
  return permutationMap ==
             AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
         permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
}

// Return the stride for the second-to-last dimension of |type| if it is a
// memref and has a constant stride.
static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
  auto memrefType = dyn_cast<MemRefType>(type);
  if (!memrefType)
    return std::nullopt;
  // If the memref is 0 or 1D the horizontal stride is 0.
  if (memrefType.getRank() < 2)
    return 0;
  int64_t offset = 0;
  SmallVector<int64_t, 2> strides;
  if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
      strides.back() != 1)
    return std::nullopt;
  int64_t stride = strides[strides.size() - 2];
  if (stride == ShapedType::kDynamic)
    return std::nullopt;
  return stride;
}

// Return true if the transfer op can be converted to a MMA matrix load.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
      readOp.getVectorType().getRank() != 2)
    return false;
  if (!getStaticallyKnownRowStride(readOp.getShapedType()))
    return false;

  // Only allow integer types if the signedness can be inferred.
  if (readOp.getVectorType().getElementType().isInteger(8))
    if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
                                 !isa<arith::ExtUIOp>(*readOp->user_begin())))
      return false;

  AffineMap map = readOp.getPermutationMap();
  MLIRContext *ctx = readOp.getContext();
  AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
  AffineExpr zero = getAffineConstantExpr(0, ctx);
  auto broadcastInnerDim =
      AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
  return map.isMinorIdentity() || map == broadcastInnerDim ||
         isTransposeMatrixLoadMap(map);
}

// Return true if the transfer op can be converted to a MMA matrix store.
static bool
transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
  // TODO: support 0-d corner case.
  if (writeOp.getTransferRank() == 0)
    return false;

  if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
      writeOp.getVectorType().getRank() != 2)
    return false;
  if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
    return false;
  // TODO: Support transpose once it is added to GPU dialect ops.
  if (!writeOp.getPermutationMap().isMinorIdentity())
    return false;
  return true;
}

/// Return true if the constant is a splat to a 2D vector so that it can be
/// converted to a MMA constant matrix op.
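/// For example (illustrative), `arith.constant dense<0.0> : vector<16x16xf16>`
/// qualifies, while a non-splat dense constant does not.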
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
  auto vecType = dyn_cast<VectorType>(constantOp.getType());
  if (!vecType || vecType.getRank() != 2)
    return false;
  return isa<SplatElementsAttr>(constantOp.getValue());
}

/// Return true if this is a broadcast from scalar to a 2D vector.
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
  return broadcastOp.getResultVectorType().getRank() == 2;
}

/// Return true if this integer extend op can be folded into a contract op.
template <typename ExtOpTy>
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
  if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
    return false;
  return llvm::all_of(extOp->getUsers(), [](Operation *user) {
    return isa<vector::ContractionOp>(user);
  });
}

static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }

/// Return the MMA elementwise enum associated with `op` if it is supported.
/// Return `std::nullopt` otherwise.
static std::optional<gpu::MMAElementwiseOp>
convertElementwiseOpToMMA(Operation *op) {
  if (isa<arith::AddFOp>(op))
    return gpu::MMAElementwiseOp::ADDF;
  if (isa<arith::MulFOp>(op))
    return gpu::MMAElementwiseOp::MULF;
  if (isa<arith::SubFOp>(op))
    return gpu::MMAElementwiseOp::SUBF;
  if (isa<arith::MaximumFOp>(op))
    return gpu::MMAElementwiseOp::MAXF;
  if (isa<arith::MinimumFOp>(op))
    return gpu::MMAElementwiseOp::MINF;
  if (isa<arith::DivFOp>(op))
    return gpu::MMAElementwiseOp::DIVF;
  if (isa<arith::AddIOp>(op))
    return gpu::MMAElementwiseOp::ADDI;
  if (isa<arith::MulIOp>(op))
    return gpu::MMAElementwiseOp::MULI;
  if (isa<arith::SubIOp>(op))
    return gpu::MMAElementwiseOp::SUBI;
  if (isa<arith::DivSIOp>(op))
    return gpu::MMAElementwiseOp::DIVS;
  if (isa<arith::DivUIOp>(op))
    return gpu::MMAElementwiseOp::DIVU;
  if (isa<arith::NegFOp>(op))
    return gpu::MMAElementwiseOp::NEGATEF;
  if (isa<arith::ExtFOp>(op))
    return gpu::MMAElementwiseOp::EXTF;
  return std::nullopt;
}

/// Return true if the op is supported as elementwise op on MMAMatrix type.
static bool elementwiseSupportsMMAMatrixType(Operation *op) {
  return convertElementwiseOpToMMA(op).has_value();
}

/// Returns true if the extract strided slice op is supported with `mma.sync`
/// path.
static bool
extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return false;

  FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
  if (failed(contractOp))
    return false;

  // Handle vector.extract_strided_slice on registers containing
  // matrixB and matrixC operands. vector.extract_strided_slice op
  // is not supported on registers containing matrixA operands.
  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
    return (cast<VectorType>(op->getResult(0).getType()) ==
            cast<VectorType>((*contractOp).getRhs().getType()));
  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
    return (cast<VectorType>(op->getResult(0).getType()) ==
            cast<VectorType>((*contractOp).getAcc().getType()));

  return false;
}

static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
  if (isa<scf::ForOp, scf::YieldOp>(op))
    return true;
  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
    return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
                    : transferReadSupportsMMAMatrixType(transferRead);
  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
    return useNvGpu
               ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
               : transferWriteSupportsMMAMatrixType(transferWrite);
  if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
    return useNvGpu &&
           extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
  if (auto contract = dyn_cast<vector::ContractionOp>(op))
    return contractSupportsMMAMatrixType(contract, useNvGpu);
  if (auto constant = dyn_cast<arith::ConstantOp>(op))
    return constantSupportsMMAMatrixType(constant);
  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
    return broadcastSupportsMMAMatrixType(broadcast);
  if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
    return integerExtendSupportsMMAMatrixType(signedExtend);
  if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
    return integerExtendSupportsMMAMatrixType(unsignedExtend);
  if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
    return fpExtendSupportsMMAMatrixType(fpExtend);
  return elementwiseSupportsMMAMatrixType(op);
}

/// Return an unsorted slice handling scf.for region differently than
/// `getSlice`. In scf.for we only want to include as part of the slice
/// elements that are part of the use/def chain.
static SetVector<Operation *>
getSliceContract(Operation *op,
                 const BackwardSliceOptions &backwardSliceOptions,
                 const ForwardSliceOptions &forwardSliceOptions) {
  SetVector<Operation *> slice;
  slice.insert(op);
  unsigned currentIndex = 0;
  SetVector<Operation *> backwardSlice;
  SetVector<Operation *> forwardSlice;
  while (currentIndex != slice.size()) {
    auto *currentOp = (slice)[currentIndex];
    // Compute and insert the backwardSlice starting from currentOp.
    backwardSlice.clear();
    getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
    slice.insert(backwardSlice.begin(), backwardSlice.end());

    // Compute and insert the forwardSlice starting from currentOp.
    forwardSlice.clear();
    // Special case for ForOp, we don't want to include the whole region but
    // only the value using the region arguments.
    // TODO: We should refine this to only care about the region arguments
    // being converted to matrix type.
    if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
      for (Value forOpResult : forOp.getResults())
        getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
      for (BlockArgument &arg : forOp.getRegionIterArgs())
        getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
    } else {
      getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
    }
    slice.insert(forwardSlice.begin(), forwardSlice.end());
    ++currentIndex;
  }
  return slice;
}

// Analyze slice of operations based on convert op to figure out if the whole
// slice can be converted to MMA operations.
static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
                                             bool useNvGpu) {
  auto hasVectorDest = [](Operation *op) {
    return llvm::any_of(op->getResultTypes(),
                        [](Type t) { return isa<VectorType>(t); });
  };
  BackwardSliceOptions backwardSliceOptions;
  backwardSliceOptions.filter = hasVectorDest;

  auto hasVectorSrc = [](Operation *op) {
    return llvm::any_of(op->getOperandTypes(),
                        [](Type t) { return isa<VectorType>(t); });
  };
  ForwardSliceOptions forwardSliceOptions;
  forwardSliceOptions.filter = hasVectorSrc;

  SetVector<Operation *> opToConvert;
  op->walk([&](vector::ContractionOp contract) {
    if (opToConvert.contains(contract.getOperation()))
      return;
    SetVector<Operation *> dependentOps =
        getSliceContract(contract, backwardSliceOptions, forwardSliceOptions);
    // If any instruction cannot use MMA matrix type drop the whole
    // chain. MMA matrices are stored in an opaque type so they cannot be used
    // by all operations.
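    // For instance (illustrative), a vector.extract anywhere in the slice
    // would need the value back as a plain vector, which the opaque
    // !gpu.mma_matrix type cannot provide, so the entire chain is left
    // untouched.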
    if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
          if (!supportsMMaMatrixType(op, useNvGpu)) {
            LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
            return true;
          }
          return false;
        }))
      return;

    opToConvert.insert(dependentOps.begin(), dependentOps.end());
  });
  // Sort the operations so that we can convert them in topological order.
  return topologicalSort(opToConvert);
}

namespace {
// Transform contract into (m, k)x(k, n)x(m, n) form so that it can be
// converted to MMA matmul.
struct PrepareContractToGPUMMA
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();

    // Set up the parallel/reduction structure in right form.
    using MapList = ArrayRef<ArrayRef<AffineExpr>>;
    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);
    static constexpr std::array<int64_t, 2> perm = {1, 0};
    auto iteratorTypes = op.getIteratorTypes().getValue();
    SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
    if (!(vector::isParallelIterator(iteratorTypes[0]) &&
          vector::isParallelIterator(iteratorTypes[1]) &&
          vector::isReductionIterator(iteratorTypes[2])))
      return rewriter.notifyMatchFailure(op, "not a gemm contraction");
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    // This is the classical row-major matmul, nothing to do.
    if (maps == infer({{m, k}, {k, n}, {m, n}}))
      return rewriter.notifyMatchFailure(op, "contraction already prepared");
    if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(rhs, lhs);
      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      std::swap(lhs, rhs);
      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else {
      // TODO: llvm_unreachable ?
      return rewriter.notifyMatchFailure(op, "unexpected contraction case");
    }
    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
        op, lhs, rhs, res,
        rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
        op.getIteratorTypes());
    return success();
  }
};

// Fold transpose op into the transfer read op. Nvgpu mma.sync op only supports
// row-, column-, and row-major layout for matrixA, matrixB, and matrixC,
// respectively. We can fold the transpose operation when loading the data from
// Shared Memory to registers.
struct CombineTransferReadOpTranspose final
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    // Look through integer extend ops.
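    // The extend (arith.extsi/extui/extf) is peeled off here and re-applied
    // after the new transfer_read is created below, so the transpose is
    // effectively folded across it. Sketch (illustrative):
    //   %r = vector.transfer_read ... : vector<16x16xi8>
    //   %e = arith.extsi %r : vector<16x16xi8> to vector<16x16xi32>
    //   %t = vector.transpose %e, [1, 0]
    // becomes a single transposed transfer_read followed by arith.extsi.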
    Value source = op.getVector();
    Type resultType = op.getType();
    Operation *extOp;
    if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
        (extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
        (extOp = source.getDefiningOp<arith::ExtFOp>())) {
      source = extOp->getOperand(0);
      resultType =
          VectorType::get(cast<VectorType>(resultType).getShape(),
                          cast<VectorType>(source.getType()).getElementType());
    }

    auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
    if (!transferReadOp)
      return rewriter.notifyMatchFailure(op, "no transfer read");

    // TODO: support 0-d corner case.
    if (transferReadOp.getTransferRank() == 0)
      return rewriter.notifyMatchFailure(op, "0-D transfer read");

    if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(op, "not inbounds transfer read");

    AffineMap permutationMap =
        AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
    AffineMap newMap =
        permutationMap.compose(transferReadOp.getPermutationMap());

    auto loc = op.getLoc();
    Value result =
        rewriter
            .create<vector::TransferReadOp>(
                loc, resultType, transferReadOp.getSource(),
                transferReadOp.getIndices(), AffineMapAttr::get(newMap),
                transferReadOp.getPadding(), transferReadOp.getMask(),
                transferReadOp.getInBoundsAttr())
            .getResult();

    // Fuse through the integer extend op.
    if (extOp) {
      if (isa<arith::ExtSIOp>(extOp))
        result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
                     .getResult();
      else if (isa<arith::ExtUIOp>(extOp))
        result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
                     .getResult();
      else
        result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
                     .getResult();
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};

} // namespace

// MMA types have different layout based on how they are used in matmul ops.
// Figure the right layout to use by looking at op uses.
// TODO: Change the GPU dialect to abstract the layout at this level and
// only care about it during lowering to NVVM.
static const char *inferFragType(Operation *op) {
  for (Operation *user : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(user);
    if (!contract)
      continue;
    assert(op->getNumResults() == 1);
    if (contract.getLhs() == op->getResult(0))
      return "AOp";
    if (contract.getRhs() == op->getResult(0))
      return "BOp";
  }
  return "COp";
}

static LogicalResult
convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
                      llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
  assert(transferReadSupportsMMAMatrixType(op) &&
         "expected convertible operation");

  std::optional<int64_t> stride =
      getStaticallyKnownRowStride(op.getShapedType());
  if (!stride.has_value()) {
    LLVM_DEBUG(DBGS() << "no stride\n");
    return rewriter.notifyMatchFailure(op, "no stride");
  }

  AffineMap map = op.getPermutationMap();
  bool isTranspose = isTransposeMatrixLoadMap(map);

  // Handle broadcast by setting the stride to 0.
  if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
    assert(cstExpr.getValue() == 0);
    stride = 0;
  }

  Value mappingResult = op.getResult();
  auto elType = op.getVectorType().getElementType();
  const char *fragType = inferFragType(op);
  if (op->hasOneUse()) {
    auto *user = *op->user_begin();
    // Infer the signedness of the mma type from the integer extend.
    bool isSignedExtend = isa<arith::ExtSIOp>(user);
    if (isSignedExtend || isa<arith::ExtUIOp>(user)) {
      elType = IntegerType::get(op.getContext(),
                                cast<IntegerType>(elType).getWidth(),
                                isSignedExtend ? IntegerType::Signed
                                               : IntegerType::Unsigned);
      mappingResult = user->getResult(0);
      fragType = inferFragType(user);
    }
  }
  gpu::MMAMatrixType type =
      gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
  Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
      op.getLoc(), type, op.getSource(), op.getIndices(),
      rewriter.getIndexAttr(*stride),
      isTranspose ? rewriter.getUnitAttr() : UnitAttr());
  valueMapping[mappingResult] = load;

  LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
  return success();
}

static LogicalResult
convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
                       llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  assert(transferWriteSupportsMMAMatrixType(op));
  std::optional<int64_t> stride =
      getStaticallyKnownRowStride(op.getShapedType());
  if (!stride.has_value()) {
    LLVM_DEBUG(DBGS() << "no stride\n");
    return rewriter.notifyMatchFailure(op, "no stride");
  }

  auto it = valueMapping.find(op.getVector());
  if (it == valueMapping.end()) {
    LLVM_DEBUG(DBGS() << "no mapping\n");
    return rewriter.notifyMatchFailure(op, "no mapping");
  }

  Value matrix = it->second;
  auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
      op.getLoc(), matrix, op.getSource(), op.getIndices(),
      rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
  (void)store;

  LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
  return success();
}

/// Returns the vector type which represents a matrix fragment.
static VectorType
getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
  SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
                             regInfo.elementsPerRegister};
  Type elType = regInfo.registerLLVMType;
  if (auto vecType = dyn_cast<VectorType>(elType))
    elType = vecType.getElementType();
  return VectorType::get(shape, elType);
}

/// Convert a 2D splat ConstantOp to a vector constant of the mma.sync
/// fragment type.
static LogicalResult
convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
                         llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo)) {
    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  }

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
  }

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
  auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
  if (!dense) {
    LLVM_DEBUG(DBGS() << "not a splat\n");
    return rewriter.notifyMatchFailure(op, "not a splat");
  }

  Value result = rewriter.create<arith::ConstantOp>(
      op.getLoc(), vectorType,
      DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
  valueMapping[op.getResult()] = result;
  return success();
}

/// Check if the loaded matrix operand requires transposing.
/// Transposed Map Example:
/// Example 1   : (..., d0, d1) -> (d1 * 1, d0 * 2)
/// Example 2   : (d0, d1, d2, d3) -> (d3, d2)
/// The code below checks if the output 2D is transposed using a generalized
/// version     : (d0, d1, dn, ..., dm, ...) -> (dm, dn)
/// Returns     : true if m > n, false otherwise.
static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
  mlir::AffineMap map = op.getPermutationMap();

  if (map.getNumResults() != 2) {
    LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
                         "is not a 2d operand\n");
    return failure();
  }

  // Output 2D matrix dimensions in the order of d0, d1.
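  // For example (illustrative), with map (d0, d1, d2, d3) -> (d3, d2) we get
  // exprM = d3 and exprN = d2, so position 3 > 2 and the load is transposed.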
  mlir::AffineExpr dM = map.getResult(0);
  mlir::AffineExpr dN = map.getResult(1);

  // Find the position of these expressions in the input.
  auto exprM = dyn_cast<AffineDimExpr>(dM);
  auto exprN = dyn_cast<AffineDimExpr>(dN);

  if (!exprM || !exprN) {
    LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
                         "expressions, then transpose cannot be determined.\n");
    return failure();
  }

  return exprM.getPosition() > exprN.getPosition();
}

static LogicalResult
createLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                              llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  Location loc = op->getLoc();

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo)) {
    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  }

  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");
  }

  FailureOr<bool> transpose = isTransposed(op);
  if (failed(transpose)) {
    LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
    return rewriter.notifyMatchFailure(
        op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
  }

  FailureOr<nvgpu::LdMatrixParams> params =
      nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
  if (failed(params)) {
    LLVM_DEBUG(
        DBGS()
        << "failed to convert vector.transfer_read to ldmatrix. "
        << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
    return rewriter.notifyMatchFailure(
        op, "failed to convert vector.transfer_read to ldmatrix; this op "
            "likely should not be converted to a nvgpu.ldmatrix call.");
  }

  // Adjust the load offset.
  auto laneId = rewriter.create<gpu::LaneIdOp>(loc);
  FailureOr<AffineMap> offsets =
      nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
  if (failed(offsets)) {
    LLVM_DEBUG(DBGS() << "no offsets\n");
    return rewriter.notifyMatchFailure(op, "no offsets");
  }

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);

  SmallVector<Value, 4> indices;
  getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
                                         indices);

  nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
      loc, vectorType, op.getSource(), indices, *transpose, params->numTiles);
  valueMapping[op] = newOp->getResult(0);
  return success();
}

static LogicalResult
createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                       llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  Location loc = op.getLoc();
  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo)) {
    return rewriter.notifyMatchFailure(
        op, "Failed to deduce register fragment type during "
            "conversion to distributed non-ldmatrix compatible load");
  }

  Value laneId = rewriter.create<gpu::LaneIdOp>(loc);
  SmallVector<Value, 4> elements;

  // This is the individual element type.
  Type loadedElType = regInfo->registerLLVMType;
  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);

  Value fill = rewriter.create<arith::ConstantOp>(
      op.getLoc(), vectorType.getElementType(),
      rewriter.getZeroAttr(vectorType.getElementType()));
  Value result =
      rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);

  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();

  // If we are not transposing, then we can use vectorized loads. Otherwise, we
  // must load each element individually.
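  // The distributed result is shaped numRegistersPerFragment x
  // elementsPerRegister (see getMmaSyncVectorOperandType). In the
  // non-transposed case each register can be filled with a single vector.load;
  // the transposed case below falls back to one memref.load per element.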
  if (!isTransposeLoad) {
    if (!isa<VectorType>(loadedElType)) {
      loadedElType = VectorType::get({1}, loadedElType);
    }

    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
          rewriter, op.getLoc(), *warpMatrixInfo);
      if (failed(coords))
        return rewriter.notifyMatchFailure(op, "no coords");

      Value logicalValueId = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIndexType(),
          rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
      SmallVector<Value, 4> newIndices;
      getXferIndices<vector::TransferReadOp>(
          rewriter, op, *coords, {laneId, logicalValueId}, newIndices);

      Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
                                                 op.getSource(), newIndices);
      result = rewriter.create<vector::InsertOp>(loc, el, result, i);
    }
  } else {
    if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
      loadedElType = vecType.getElementType();
    }
    for (int i = 0; i < vectorType.getShape()[0]; i++) {
      for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
           innerIdx++) {

        Value logicalValueId = rewriter.create<arith::ConstantOp>(
            loc, rewriter.getIndexType(),
            rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
        FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
            rewriter, op.getLoc(), *warpMatrixInfo);
        if (failed(coords))
          return rewriter.notifyMatchFailure(op, "no coords");

        SmallVector<Value, 4> newIndices;
        getXferIndices<vector::TransferReadOp>(
            rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
        Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
                                                   op.getSource(), newIndices);
        result = rewriter.create<vector::InsertOp>(
            op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
      }
    }
  }

  valueMapping[op.getResult()] = result;
  return success();
}

/// Return true if this is a shared memory memref type.
static bool isSharedMemory(MemRefType type) {
  auto addressSpace =
      dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
  return addressSpace &&
         addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
}

/// Converts a `vector.transfer_read` operation directly to either a
/// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
/// used when converting to `nvgpu.mma.sync` operations.
static LogicalResult
convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");

  bool isLdMatrixCompatible =
      isSharedMemory(cast<MemRefType>(op.getSource().getType())) &&
      nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;

  VectorType vecTy = op.getVectorType();
  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();

  // When we are transposing the B operand, ldmatrix will only work if we have
  // at least 8 rows to read and the width to read for the transpose is 128
  // bits.
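  // Concretely, this mirrors the check below: the transposed ldmatrix path is
  // only kept for 16-bit element types with an inner dimension of at least 8
  // and at least 128 bits along the outer dimension.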
  if (!op.getPermutationMap().isMinorIdentity() &&
      (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
       vecTy.getDimSize(0) * bitWidth < 128))
    isLdMatrixCompatible = false;

  if (!isLdMatrixCompatible)
    return createNonLdMatrixLoads(rewriter, op, valueMapping);

  return createLdMatrixCompatibleLoads(rewriter, op, valueMapping);
}

static LogicalResult
convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
                             llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  Location loc = op->getLoc();
  auto it = valueMapping.find(op.getVector());
  if (it == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value matrix = it->second;

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
  FailureOr<nvgpu::FragmentElementInfo> regInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(regInfo))
    return rewriter.notifyMatchFailure(op, "not mma sync reg info");

  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
  Value laneId = rewriter.create<gpu::LaneIdOp>(loc);

  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
    Value logicalValueId = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIndexType(),
        rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
    FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
        rewriter, op.getLoc(), *warpMatrixInfo);
    if (failed(coords))
      return rewriter.notifyMatchFailure(op, "no coords");

    Value el =
        rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
    SmallVector<Value, 4> newIndices;
    getXferIndices<vector::TransferWriteOp>(
        rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
    rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
  }

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
  return success();
}

static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
                                       SmallVectorImpl<int64_t> &results) {
  for (auto attr : arrayAttr)
    results.push_back(cast<IntegerAttr>(attr).getInt());
}

static LogicalResult
convertExtractStridedSlice(RewriterBase &rewriter,
                           vector::ExtractStridedSliceOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  Location loc = op->getLoc();

  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
      nvgpu::getWarpMatrixInfo(op);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");

  FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(mmaSyncFragmentInfo))
    return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");

  // Find the vector.transfer_read whose result vector is being sliced.
  auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!transferReadOp)
    return rewriter.notifyMatchFailure(op, "no transfer read");

  warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
  if (failed(warpMatrixInfo))
    return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");

  FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
      nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
  if (failed(ldFragmentInfo))
    return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");

  assert(
      (mmaSyncFragmentInfo->elementsPerRegister ==
       ldFragmentInfo->elementsPerRegister) &&
      "Number of elements per register should be same for load and mma.sync");

  // Create vector.extract_strided_slice op for thread-owned fragments.
  std::array<int64_t, 2> strides = {1, 1}; // Stride for extract slice is
                                           // always 1.
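  // The per-thread slice keeps all elementsPerRegister elements of each
  // register and only narrows the number of registers to the size of the
  // smaller mma.sync fragment.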
  std::array<int64_t, 2> sliceShape = {
      mmaSyncFragmentInfo->numRegistersPerFragment,
      mmaSyncFragmentInfo->elementsPerRegister};
  auto it = valueMapping.find(transferReadOp);
  if (it == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  auto sourceVector = it->second;

  // Offsets and sizes at warp-level of ownership.
  SmallVector<int64_t> offsets;
  populateFromInt64AttrArray(op.getOffsets(), offsets);

  SmallVector<int64_t> sizes;
  populateFromInt64AttrArray(op.getSizes(), sizes);
  ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();

  // Compute offset in vector registers. Note that the mma.sync vector
  // registers are shaped as numberOfFragments x numberOfRegistersPerFragment.
  // The vector registers can only be sliced along numberOfFragments, i.e.,
  // sliceOffset[0].
  std::array<int64_t, 2> sliceOffset = {0, 0};

  if (offsets[0] && offsets[1])
    return op->emitError() << "Slicing fragments in 2D is not supported.";
  if (offsets[0])
    sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
  else if (offsets[1])
    sliceOffset[0] = (warpVectorShape[1] / offsets[1]);

  Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
      loc, sourceVector, sliceOffset, sliceShape, strides);

  valueMapping[op] = newOp;
  return success();
}

static LogicalResult
convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
                  llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  auto itA = valueMapping.find(op.getLhs());
  auto itB = valueMapping.find(op.getRhs());
  auto itC = valueMapping.find(op.getAcc());
  if (itA == valueMapping.end() || itB == valueMapping.end() ||
      itC == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value opA = itA->second, opB = itB->second, opC = itC->second;
  Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
      op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
      /*b_transpose=*/UnitAttr());
  valueMapping[op.getResult()] = matmul;
  return success();
}

static LogicalResult
convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
                           llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  auto itA = valueMapping.find(op.getLhs());
  auto itB = valueMapping.find(op.getRhs());
  auto itC = valueMapping.find(op.getAcc());
  if (itA == valueMapping.end() || itB == valueMapping.end() ||
      itC == valueMapping.end())
    return rewriter.notifyMatchFailure(op, "no mapping");
  Value opA = itA->second, opB = itB->second, opC = itC->second;
  int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
  int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
  int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
  Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
      op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
  valueMapping[op.getResult()] = matmul;
  return success();
}

/// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
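/// The splat value is first materialized as a scalar arith.constant and then
/// wrapped in a gpu.subgroup_mma_constant_matrix of the inferred fragment
/// type (AOp/BOp/COp).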
static LogicalResult
convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
                  llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  assert(constantSupportsMMAMatrixType(op));

  auto splat =
      cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
  auto scalarConstant =
      rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
  const char *fragType = inferFragType(op);
  auto vecType = cast<VectorType>(op.getType());
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
      op.getLoc(), type, scalarConstant);
  valueMapping[op.getResult()] = matrix;
  return success();
}

/// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static LogicalResult
convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
                   llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  assert(broadcastSupportsMMAMatrixType(op));

  const char *fragType = inferFragType(op);
  auto vecType = op.getResultVectorType();
  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
      vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
      op.getLoc(), type, op.getSource());
  valueMapping[op.getResult()] = matrix;
  return success();
}

// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
// updated and needs to be updated separately for the loop to be correct.
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
                                               scf::ForOp loop,
                                               ValueRange newInitArgs) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(loop);

  // Create a new loop before the existing one, with the extra operands.
  rewriter.setInsertionPoint(loop);
  auto operands = llvm::to_vector<4>(loop.getInitArgs());
  llvm::append_range(operands, newInitArgs);
  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
      operands);
  rewriter.eraseBlock(newLoop.getBody());

  newLoop.getRegion().getBlocks().splice(
      newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
  for (Value operand : newInitArgs)
    newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());

  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                                                  loop.getNumResults())))
    rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));

  LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
  LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
  LLVM_DEBUG(DBGS() << "erase: " << loop);

  rewriter.eraseOp(loop);
  return newLoop;
}

static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
                                  llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  SmallVector<Value> newOperands;
  SmallVector<std::pair<size_t, size_t>> argMapping;
  for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end()) {
      LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
      continue;
    }
    argMapping.push_back(std::make_pair(
        operand.index(), op.getInitArgs().size() + newOperands.size()));
    newOperands.push_back(it->second);
  }

  scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
  Block &loopBody = *newForOp.getBody();
  for (auto mapping : argMapping) {
    valueMapping[newForOp.getResult(mapping.first)] =
        newForOp.getResult(mapping.second);
    valueMapping[loopBody.getArgument(mapping.first +
                                      newForOp.getNumInductionVars())] =
        loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
  }

  LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
  return success();
}

static LogicalResult
convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
               llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  auto loop = cast<scf::ForOp>(op->getParentOp());
  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
  for (const auto &operand : llvm::enumerate(op.getOperands())) {
    auto it = valueMapping.find(operand.value());
    if (it == valueMapping.end())
      continue;
    // Replace the yield of old value with the for op argument to make it
    // easier to remove the dead code.
    yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
    yieldOperands.push_back(it->second);
  }
  rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);

  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
  rewriter.eraseOp(op);
  return success();
}

/// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static LogicalResult
convertElementwiseOp(RewriterBase &rewriter, Operation *op,
                     gpu::MMAElementwiseOp opType,
                     llvm::DenseMap<Value, Value> &valueMapping) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);

  SmallVector<Value> matrixOperands;
  for (Value operand : op->getOperands()) {
    auto it = valueMapping.find(operand);
    if (it == valueMapping.end())
      return rewriter.notifyMatchFailure(op, "no mapping");
    matrixOperands.push_back(it->second);
  }
  auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
  if (opType == gpu::MMAElementwiseOp::EXTF) {
    // The floating point extension case has a different result type.
    auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
    resultType = gpu::MMAMatrixType::get(resultType.getShape(),
                                         vectorType.getElementType(),
                                         resultType.getOperand());
  }

  Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
      op->getLoc(), resultType, matrixOperands, opType);
  valueMapping[op->getResult(0)] = newOp;
  return success();
}

void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
                                              bool useNvGpu) {
  if (!useNvGpu) {
    patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
        patterns.getContext());
    return;
  }
  vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
  patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
}

LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
                                          Operation *rootOp) {
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
  llvm::DenseMap<Value, Value> valueMapping;

  auto globalRes = LogicalResult::success();
  for (Operation *op : ops) {
    LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
    // Apparently callers do not want to early exit on failure here.
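    // Instead, individual failures are folded into `globalRes` after the
    // dispatch below and reported once at the end.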
    auto res = LogicalResult::success();
    if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
      res = convertTransferReadOp(rewriter, transferRead, valueMapping);
    } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
      res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
    } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
      res = convertContractOp(rewriter, contractOp, valueMapping);
    } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
      res = convertConstantOp(rewriter, constantOp, valueMapping);
    } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
      res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
    } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      res = convertForOp(rewriter, forOp, valueMapping);
    } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
      res = convertYieldOp(rewriter, yieldOp, valueMapping);
    } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
      res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
    }
    if (failed(res))
      globalRes = failure();
  }
  return globalRes;
}

LogicalResult
mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
                                           Operation *rootOp) {
  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
  llvm::DenseMap<Value, Value> valueMapping;
  for (Operation *op : ops) {
    if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
            .Case([&](vector::TransferReadOp transferReadOp) {
              return convertTransferReadToLoads(rewriter, transferReadOp,
                                                valueMapping);
            })
            .Case([&](vector::TransferWriteOp transferWriteOp) {
              return convertTransferWriteToStores(rewriter, transferWriteOp,
                                                  valueMapping);
            })
            .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
              return convertExtractStridedSlice(rewriter,
                                                extractStridedSliceOp,
                                                valueMapping);
            })
            .Case([&](vector::ContractionOp contractionOp) {
              return convertContractOpToMmaSync(rewriter, contractionOp,
                                                valueMapping);
            })
            .Case([&](scf::ForOp forOp) {
              return convertForOp(rewriter, forOp, valueMapping);
            })
            .Case([&](scf::YieldOp yieldOp) {
              return convertYieldOp(rewriter, yieldOp, valueMapping);
            })
            .Case([&](arith::ConstantOp constOp) {
              return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
            })
            .Default([&](Operation *op) {
              return op->emitError()
                     << "unhandled vector to mma type: " << *op;
            })
            .failed()) {
      return op->emitOpError()
             << "failed to convert op during vector-to-nvgpu conversion";
    }
  }
  return success();
}

namespace {

struct ConvertVectorToGPUPass
    : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {

  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
    useNvGpu.setValue(useNvGpu_);
  }

  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();

    IRRewriter rewriter(&getContext());
    if (useNvGpu) {
      if (failed(
              convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
        return signalPassFailure();
      return;
    }
    (void)convertVectorToMMAOps(rewriter, getOperation());
  }
};

} // namespace

std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
}