//===- SubgroupReduceLowering.cpp - subgroup_reduce lowering patterns -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements gradual lowering of `gpu.subgroup_reduce` ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/GPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <cstdint>

using namespace mlir;

namespace {

/// Example, assumes `maxShuffleBitwidth` equal to 32:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<3xf16>) -> vector<3xf16>
///  ==>
/// %v0 = arith.constant dense<0.0> : vector<3xf16>
/// %e0 = vector.extract_strided_slice %x
///   {offsets = [0], sizes = [2], strides = [1]} : vector<3xf16> to vector<2xf16>
/// %r0 = gpu.subgroup_reduce add %e0 : (vector<2xf16>) -> vector<2xf16>
/// %v1 = vector.insert_strided_slice %r0, %v0
///   {offsets = [0], strides = [1]} : vector<2xf16> into vector<3xf16>
/// %e1 = vector.extract %x[2] : f16 from vector<3xf16>
/// %r1 = gpu.subgroup_reduce add %e1 : (f16) -> f16
/// %a = vector.insert %r1, %v1[2] : f16 into vector<3xf16>
/// ```
struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
  BreakDownSubgroupReduce(MLIRContext *ctx, unsigned maxShuffleBitwidth,
                          PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), maxShuffleBitwidth(maxShuffleBitwidth) {
  }

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() < 2)
      return rewriter.notifyMatchFailure(op, "not a multi-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");

    Type elemTy = vecTy.getElementType();
    unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
    if (elemBitwidth >= maxShuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, llvm::formatv("element type too large ({0}), cannot break down "
                            "into vectors of bitwidth {1} or less",
                            elemBitwidth, maxShuffleBitwidth));

    unsigned elementsPerShuffle = maxShuffleBitwidth / elemBitwidth;
    assert(elementsPerShuffle >= 1);

    unsigned numNewReductions =
        llvm::divideCeil(vecTy.getNumElements(), elementsPerShuffle);
    assert(numNewReductions >= 1);
    if (numNewReductions == 1)
      return rewriter.notifyMatchFailure(op, "nothing to break down");

    Location loc = op.getLoc();
    Value res =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(vecTy));

    for (unsigned i = 0; i != numNewReductions; ++i) {
      int64_t startIdx = i * elementsPerShuffle;
      int64_t endIdx =
          std::min(startIdx + elementsPerShuffle, vecTy.getNumElements());
      int64_t numElems = endIdx - startIdx;

      Value extracted;
      if (numElems == 1) {
        extracted =
            rewriter.create<vector::ExtractOp>(loc, op.getValue(), startIdx);
      } else {
        extracted = rewriter.create<vector::ExtractStridedSliceOp>(
            loc, op.getValue(), /*offsets=*/startIdx, /*sizes=*/numElems,
            /*strides=*/1);
      }

      Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
          loc, extracted, op.getOp(), op.getUniform());
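      // Insert the reduced scalar or sub-vector back into the result at the
      // offset it was extracted from.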
      if (numElems == 1) {
        res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
        continue;
      }

      res = rewriter.create<vector::InsertStridedSliceOp>(
          loc, reduce, res, /*offsets=*/startIdx, /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned maxShuffleBitwidth = 0;
};

/// Example:
/// ```
/// %a = gpu.subgroup_reduce add %x : (vector<1xf32>) -> vector<1xf32>
///  ==>
/// %e0 = vector.extract %x[0] : f32 from vector<1xf32>
/// %r0 = gpu.subgroup_reduce add %e0 : (f32) -> f32
/// %a = vector.broadcast %r0 : f32 to vector<1xf32>
/// ```
struct ScalarizeSingleElementReduce final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy || vecTy.getNumElements() != 1)
      return rewriter.notifyMatchFailure(op, "not a single-element reduction");

    assert(vecTy.getRank() == 1 && "Unexpected vector type");
    assert(!vecTy.isScalable() && "Unexpected vector type");
    Location loc = op.getLoc();
    Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
    Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
        loc, extracted, op.getOp(), op.getUniform());
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
    return success();
  }
};

/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
/// and `unpackFn` to convert to the native shuffle type and back to the
/// reduction type, respectively. For example, with an `input` of type `f16`,
/// `packFn` could build ops to cast the value to `i32` to perform shuffles,
/// while `unpackFn` would cast it back to `f16` to perform the arithmetic
/// reduction on. Assumes that the subgroup is `subgroupSize` lanes wide and
/// reduces across all of them.
static Value createSubgroupShuffleReduction(
    OpBuilder &builder, Location loc, Value input,
    gpu::AllReduceOperation mode, unsigned subgroupSize,
    function_ref<Value(Value)> packFn, function_ref<Value(Value)> unpackFn) {
  assert(llvm::isPowerOf2_32(subgroupSize));
  // Lane value always stays in the original type. We use it to perform arith
  // reductions.
  Value laneVal = input;
  // Parallel reduction using butterfly shuffles.
  for (unsigned i = 1; i < subgroupSize; i <<= 1) {
    Value shuffled = builder
                         .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
                                                 /*width=*/subgroupSize,
                                                 /*mode=*/gpu::ShuffleMode::XOR)
                         .getShuffleResult();
    laneVal = vector::makeArithReduction(builder, loc,
                                         gpu::convertReductionKind(mode),
                                         laneVal, unpackFn(shuffled));
    assert(laneVal.getType() == input.getType());
  }

  return laneVal;
}

/// Lowers scalar gpu subgroup reductions to a series of shuffles.
struct ScalarSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    Type valueTy = op.getType();
    unsigned elemBitwidth =
        getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
    if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "value type is not a compatible scalar");

    Location loc = op.getLoc();
    // Since this is already a native shuffle scalar, no packing is necessary.
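    // For example, an i32 or f32 value with a 32-bit shuffle width can be fed
    // to the shuffle ops directly.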
    if (elemBitwidth == shuffleBitwidth) {
      auto identityFn = [](Value v) { return v; };
      rewriter.replaceOp(op, createSubgroupShuffleReduction(
                                 rewriter, loc, op.getValue(), op.getOp(),
                                 subgroupSize, identityFn, identityFn));
      return success();
    }

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto equivIntType = rewriter.getIntegerType(elemBitwidth);
    auto packFn = [loc, &rewriter, equivIntType,
                   shuffleIntType](Value unpackedVal) -> Value {
      auto asInt =
          rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
      return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
    };
    auto unpackFn = [loc, &rewriter, equivIntType,
                     valueTy](Value packedVal) -> Value {
      auto asInt =
          rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
      return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
    };

    rewriter.replaceOp(op, createSubgroupShuffleReduction(
                               rewriter, loc, op.getValue(), op.getOp(),
                               subgroupSize, packFn, unpackFn));
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
};

/// Lowers vector gpu subgroup reductions to a series of shuffles.
struct VectorSubgroupReduceToShuffles final
    : OpRewritePattern<gpu::SubgroupReduceOp> {
  VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
                                 unsigned shuffleBitwidth,
                                 PatternBenefit benefit)
      : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
        shuffleBitwidth(shuffleBitwidth) {}

  LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
                                PatternRewriter &rewriter) const override {
    auto vecTy = dyn_cast<VectorType>(op.getType());
    if (!vecTy)
      return rewriter.notifyMatchFailure(op, "value type is not a vector");

    unsigned vecBitwidth =
        vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
    if (vecBitwidth > shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op,
          llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
                        "to shuffles of size {1}",
                        vecBitwidth, shuffleBitwidth));

    unsigned elementsPerShuffle =
        shuffleBitwidth / vecTy.getElementTypeBitWidth();
    if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
      return rewriter.notifyMatchFailure(
          op, "shuffle bitwidth is not a multiple of the element bitwidth");

    Location loc = op.getLoc();

    // If the reduced type is smaller than the native shuffle size, extend it,
    // perform the shuffles, and extract at the end.
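    // For example, assuming a 32-bit shuffle width, a `vector<3xi8>` input
    // (24 bits) is first widened to `vector<4xi8>` so it can be bitcast to the
    // native `i32` shuffle type below.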
    auto extendedVecTy = VectorType::get(
        static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
    Value extendedInput = op.getValue();
    if (vecBitwidth < shuffleBitwidth) {
      auto zero = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getZeroAttr(extendedVecTy));
      extendedInput = rewriter.create<vector::InsertStridedSliceOp>(
          loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
    }

    auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
    auto shuffleVecType = VectorType::get(1, shuffleIntType);

    auto packFn = [loc, &rewriter,
                   shuffleVecType](Value unpackedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal);
      return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0);
    };
    auto unpackFn = [loc, &rewriter, shuffleVecType,
                     extendedVecTy](Value packedVal) -> Value {
      auto asIntVec =
          rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal);
      return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
    };

    Value res = createSubgroupShuffleReduction(rewriter, loc, extendedInput,
                                               op.getOp(), subgroupSize,
                                               packFn, unpackFn);

    if (vecBitwidth < shuffleBitwidth) {
      res = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
          /*strides=*/1);
    }

    rewriter.replaceOp(op, res);
    return success();
  }

private:
  unsigned subgroupSize = 0;
  unsigned shuffleBitwidth = 0;
};
} // namespace

void mlir::populateGpuBreakDownSubgroupReducePatterns(
    RewritePatternSet &patterns, unsigned maxShuffleBitwidth,
    PatternBenefit benefit) {
  patterns.add<BreakDownSubgroupReduce>(patterns.getContext(),
                                        maxShuffleBitwidth, benefit);
  patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
}

void mlir::populateGpuLowerSubgroupReduceToShufflePatterns(
    RewritePatternSet &patterns, unsigned subgroupSize,
    unsigned shuffleBitwidth, PatternBenefit benefit) {
  patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
      patterns.getContext(), subgroupSize, shuffleBitwidth, benefit);
}
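
// Illustrative usage (a sketch only; the helper name and the constants below
// are assumptions, not part of this file): a client pass could register both
// sets of patterns so that wide vector reductions are first broken down and
// the resulting pieces are then lowered to shuffles.
//
//   void addSubgroupReduceLoweringPatterns(RewritePatternSet &patterns) {
//     populateGpuBreakDownSubgroupReducePatterns(
//         patterns, /*maxShuffleBitwidth=*/32, /*benefit=*/2);
//     populateGpuLowerSubgroupReduceToShufflePatterns(
//         patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32, /*benefit=*/1);
//   }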