//===- ConvertConv2DToImg2Col.cpp - im2col implementation -----------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Affine/Utils.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Complex/IR/Complex.h"
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
|
|
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
|
|
#include "mlir/IR/AffineExpr.h"
|
|
#include "mlir/IR/AffineMap.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinAttributes.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include <utility>
|
|
|
|
namespace mlir {
|
|
namespace linalg {
|
|
// Returns true iff every integer element of `attr` equals one.
static bool hasAllOneValues(DenseIntElementsAttr attr) {
  for (const APInt &element : attr)
    if (element.getSExtValue() != 1)
      return false;
  return true;
}
|
|
|
|
// Emits the addition op matching the operand type: integer addition for
// IntegerType, complex addition for ComplexType, floating-point addition
// otherwise.
static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) {
  Type elemType = x.getType();
  if (isa<IntegerType>(elemType))
    return builder.create<arith::AddIOp>(loc, x, y);
  if (isa<ComplexType>(elemType))
    return builder.create<complex::AddOp>(loc, x, y);
  return builder.create<arith::AddFOp>(loc, x, y);
}
|
|
|
|
// Multiplies `x` and `y` in the accumulator type `accType`, first converting
// both operands to that type. Linalg named ops specify signed extension, so
// the casts are signed.
static Value createMul(Location loc, Value x, Value y, Type accType,
                       OpBuilder &builder) {
  Value lhs =
      convertScalarToDtype(builder, loc, x, accType, /*isUnsignedCast=*/false);
  Value rhs =
      convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false);
  if (isa<ComplexType>(accType))
    return builder.create<complex::MulOp>(loc, lhs, rhs);
  if (isa<IntegerType>(accType))
    return builder.create<arith::MulIOp>(loc, lhs, rhs);
  return builder.create<arith::MulFOp>(loc, lhs, rhs);
}
|
|
|
|
// Delinearizes the given composite `index` by the basis specified in `factors`.
|
|
static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index,
|
|
ArrayRef<int64_t> factors) {
|
|
assert(!factors.empty() && "empty factor list");
|
|
SmallVector<Value> basis;
|
|
for (int64_t f : factors)
|
|
basis.push_back(b.create<arith::ConstantOp>(loc, b.getIndexAttr(f)));
|
|
FailureOr<SmallVector<Value>> multiIndex =
|
|
affine::delinearizeIndex(b, loc, index, basis);
|
|
assert(!failed(multiIndex) && "Failed to linearize img2col index");
|
|
return *multiIndex;
|
|
}
|
|
|
|
// Given indices corresponding to iterators in the output (oIndex) and filter
|
|
// (fIndex) for a convolution, compute the convolved index for the
|
|
// input as `oIndex * stride + fIndex`.
|
|
static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex,
|
|
Value fIndex, int64_t stride) {
|
|
AffineExpr oExpr, fExpr;
|
|
bindSymbols(b.getContext(), oExpr, fExpr);
|
|
AffineMap convMap = AffineMap::get(0, 2, stride * oExpr + fExpr);
|
|
return affine::makeComposedAffineApply(b, loc, convMap, {oIndex, fIndex});
|
|
}
|
|
|
|
// Rewrites linalg.conv_2d_nhwc_hwcf as an im2col gather of the input followed
// by a batched-matmul-shaped linalg.generic:
//   1. Collapse the HWCF filter into a (fh*fw*ic) x oc matrix.
//   2. Materialize im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh+fh, sw*ow+fw, ic].
//   3. Contract im2col with the collapsed filter into the collapsed output,
//      then expand back to the NHWC output shape.
// Requires static input/filter shapes and unit dilations. Returns the
// (im2col op, final reshape op) pair on success.
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  MLIRContext *context = rewriter.getContext();
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  ArrayRef<int64_t> filterShape = filterType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  // NHWC output dims and HWCF filter dims.
  int64_t n = outputShape[0];
  int64_t oh = outputShape[1];
  int64_t ow = outputShape[2];
  int64_t oc = outputShape[3];
  int64_t fh = filterShape[0];
  int64_t fw = filterShape[1];
  int64_t ic = filterShape[2];

  Location loc = convOp.getLoc();

  // Reshape output and filter to the LHS and result of a (B)MNK matmul.
  SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}};
  auto reshapedFilterType =
      RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType());
  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedFilterType, filter, filterReassocIndices);

  // Collapse (oh, ow) into the matmul M dimension.
  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
  RankedTensorType reshapedOutputType =
      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedOutputType, output, outputReassocIndices);

  // Destination tensor for the im2col result: (B, M, K).
  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
  Value colTensor = rewriter.create<tensor::EmptyOp>(
      loc, colTensorShape, inputType.getElementType());

  // Convert the input to a (BMK) column tensor.
  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

  SmallVector<AffineMap> img2colIndexingMaps = {
      AffineMap::getMultiDimIdentityMap(nloops, context)};

  // The im2col generic has no inputs; the body gathers input elements via
  // tensor.extract using indices recovered from the loop induction variables.
  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
      loc, colTensor.getType(),
      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        // Get the iterators named based on the matmul (batch, m, k).
        Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
        Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
        Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);

        // Recover the original iteration indices from the problem/input sizes.
        SmallVector<Value> mIndices = unrollIndex(
            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
        auto ohIndex = mIndices[0];
        auto owIndex = mIndices[1];

        SmallVector<Value> kIndices = unrollIndex(
            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
        auto fhIndex = kIndices[0];
        auto fwIndex = kIndices[1];
        auto icIndex = kIndices[2];

        // Extract the input element corresponding to the expanded indices.
        Value hIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
                              convOp.getStrides().getValues<int64_t>()[0]);
        Value wIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
                              convOp.getStrides().getValues<int64_t>()[1]);

        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
            loc, input, extractionIndices);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
      });

  // Because the filter does not share the same batch dimension,
  // the batch dimension is only used in indexing the input and output. Thus
  // we cannot use existing linalg named ops like linalg.batch_matmul.
  // i.e. (B x) M x K * K x N = (B x) M x N
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {kDim, nDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};

  // Multiply-accumulate in the output element type (signed casts; see
  // createMul).
  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, reshapedOutputType,
      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        Value mul =
            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
        Value add = createAdd(loc, mul, args[2], nestedBuilder);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  // Restore the original NHWC output shape.
  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
      loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}
|
|
|
|
// Rewrites linalg.depthwise_conv_2d_nhwc_hwc via im2col. Because each channel
// is convolved independently, the contraction is expressed as a
// linalg.batch_matvec over a batch of n*c independent (oh*ow) x (fh*fw)
// systems:
//   1. Transpose input/filter/output so the channel dimension is outermost.
//   2. Gather im2col[n, c, oh, ow, fh, fw] = input[n, c, sh*oh+fh, sw*ow+fw].
//   3. Collapse to (n*c, oh*ow, fh*fw) and batch-matvec against the collapsed
//      (c, fh*fw) filter.
//   4. Expand and transpose the result back to NHWC.
// Requires static input/filter shapes and unit dilations. Returns the
// (im2col op, final transpose op) pair on success.
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter,
                linalg::DepthwiseConv2DNhwcHwcOp convOp) {
  auto inputType = cast<RankedTensorType>(convOp.getInputs()[0].getType());
  auto filterType = cast<RankedTensorType>(convOp.getInputs()[1].getType());
  auto outputType = cast<RankedTensorType>(convOp.getOutputs()[0].getType());

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  Location loc = convOp.getLoc();

  // Emits a rank-preserving transpose of `operand` as a linalg.generic that
  // permutes dimensions according to `indices`.
  auto transposeOperand = [&](Value operand, ArrayRef<int64_t> indices) {
    auto operandTensorType = cast<RankedTensorType>(operand.getType());
    auto nloops = indices.size();
    ArrayRef<int64_t> inputShape = operandTensorType.getShape();

    SmallVector<AffineExpr> exprs = llvm::to_vector<4>(
        llvm::map_range(indices, [&](int64_t index) -> AffineExpr {
          return rewriter.getAffineDimExpr(index);
        }));

    SmallVector<int64_t> targetShape = llvm::to_vector<4>(llvm::map_range(
        indices, [&](int64_t index) -> int64_t { return inputShape[index]; }));

    Value outputTensor = rewriter.create<tensor::EmptyOp>(
        loc, targetShape, operandTensorType.getElementType());

    SmallVector<utils::IteratorType> loopAttributeTypes(
        nloops, utils::IteratorType::parallel);

    SmallVector<AffineMap> indexingMaps = {
        inversePermutation(
            AffineMap::get(nloops, 0, exprs, rewriter.getContext())),
        AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};

    auto transposedOp = rewriter.create<linalg::GenericOp>(
        loc, outputTensor.getType(),
        /*inputs=*/operand, /*outputs=*/outputTensor, indexingMaps,
        loopAttributeTypes,
        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
          nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
        });

    return transposedOp.getResult(0);
  };

  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  // Transpose input, filter so channels are outermost
  Value inputT = transposeOperand(input, {0, 3, 1, 2});
  Value filterT = transposeOperand(filter, {2, 0, 1});
  ArrayRef<int64_t> filterTShape =
      cast<RankedTensorType>(filterT.getType()).getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  // Use int64_t for the extents: tensor dims are int64_t, and narrowing to
  // `int` (as the original code did) could truncate large static shapes.
  int64_t n = outputShape[0];
  int64_t oh = outputShape[1];
  int64_t ow = outputShape[2];
  int64_t c = outputShape[3];
  int64_t fh = filterTShape[1];
  int64_t fw = filterTShape[2];

  SmallVector<int64_t> colTensorShape = {n, c, oh, ow, fh, fw};
  Value transposedOutputTensor = transposeOperand(output, {0, 3, 1, 2});

  AffineExpr nDim, cDim, ohDim, owDim, khDim, kwDim;
  bindDims(rewriter.getContext(), nDim, cDim, ohDim, owDim, khDim, kwDim);

  // Stride constants folded into the input indexing map.
  AffineExpr shSym = rewriter.getAffineConstantExpr(
      convOp.getStrides().getValues<int64_t>()[0]);
  AffineExpr swSym = rewriter.getAffineConstantExpr(
      convOp.getStrides().getValues<int64_t>()[1]);

  // input[n, c, sh*oh + kh, sw*ow + kw]
  SmallVector<AffineExpr> inputExprs = {nDim, cDim, ohDim * shSym + khDim,
                                        owDim * swSym + kwDim};

  auto nloops = colTensorShape.size();

  SmallVector<utils::IteratorType> loopAttributeTypes(
      nloops, utils::IteratorType::parallel);

  SmallVector<AffineMap> indexingMaps = {
      AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()),
      AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())};

  Value colTensor = rewriter.create<tensor::EmptyOp>(
      loc, colTensorShape, inputType.getElementType());

  // The gather is a pure copy through the strided input indexing map.
  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
      loc, colTensor.getType(),
      /*inputs=*/inputT, /*outputs=*/colTensor, indexingMaps,
      loopAttributeTypes,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]);
      });

  // Collapse to batch-matvec operand shapes: (n*c, oh*ow, fh*fw),
  // (c, fh*fw), and (n*c, oh*ow).
  SmallVector<ReassociationIndices> img2ColTensorReassocIndices = {
      {0, 1}, {2, 3}, {4, 5}};
  SmallVector<ReassociationIndices> filterReassociationIndice = {{0}, {1, 2}};
  SmallVector<ReassociationIndices> outputReassociationIndice = {{0, 1},
                                                                 {2, 3}};

  auto reshapedImg2ColTensorType = RankedTensorType::get(
      {n * c, oh * ow, fh * fw}, inputType.getElementType());
  auto reshapedFilterTensorType =
      RankedTensorType::get({c, fh * fw}, filterType.getElementType());
  auto reshapedOutputTensorType =
      RankedTensorType::get({n * c, oh * ow}, outputType.getElementType());

  Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0),
      img2ColTensorReassocIndices);
  Value reshapedFilterTensor = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedFilterTensorType, filterT, filterReassociationIndice);
  Value reshapedoutputTensor = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedOutputTensorType, transposedOutputTensor,
      outputReassociationIndice);

  // Note: the filter batch (c) broadcasts over n via batch_matvec semantics
  // on the collapsed (n*c) batch — the filter operand has batch size c.
  auto batchMatVecResult = rewriter.create<linalg::BatchMatvecOp>(
      loc, TypeRange{reshapedoutputTensor.getType()},
      ValueRange{reshapedImg2ColTensor, reshapedFilterTensor},
      ValueRange{reshapedoutputTensor});

  SmallVector<ReassociationIndices> batchMatVecReassociationIndice = {{0, 1},
                                                                      {2, 3}};

  // Expand (n*c, oh*ow) back to (n, c, oh, ow), then transpose to NHWC.
  Value batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>(
      loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0),
      batchMatVecReassociationIndice);

  Value transposedResult =
      transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1});

  rewriter.replaceOp(convOp, ArrayRef<Value>{transposedResult});
  return std::make_pair(img2ColTensor.getOperation(),
                        transposedResult.getDefiningOp());
}
|
|
|
|
// Rewrites linalg.conv_2d_nchw_fchw via im2col. The filter is collapsed to an
// oc x (ic*fh*fw) matrix (LHS), the input gathered into a
// (n, ic*fh*fw, oh*ow) column tensor (RHS), and the two contracted into the
// collapsed (n, oc, oh*ow) output, i.e. M x K * (B x) K x N = (B x) M x N.
// Requires static input/filter shapes and unit dilations. Returns the
// (im2col op, final reshape op) pair on success.
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  auto filterShape = filterType.getShape();
  auto outputShape = outputType.getShape();

  // NCHW output dims and FCHW filter dims.
  int64_t n = outputShape[0];
  int64_t oc = outputShape[1];
  int64_t oh = outputShape[2];
  int64_t ow = outputShape[3];
  int64_t ic = filterShape[1];
  int64_t fh = filterShape[2];
  int64_t fw = filterShape[3];

  auto loc = convOp.getLoc();
  MLIRContext *context = rewriter.getContext();

  // Collapse the filter to oc x (ic*fh*fw). The element type must come from
  // the filter (not the input, as before): tensor.collapse_shape requires the
  // result element type to match its source, and input/filter element types
  // need not coincide.
  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
  auto reshapedFilterType =
      RankedTensorType::get({oc, ic * fh * fw}, filterType.getElementType());
  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedFilterType, filter, filterReassocIndices);

  // Collapse (oh, ow) into the matmul N dimension.
  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1}, {2, 3}};
  auto reshapedOutputType =
      RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType());
  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedOutputType, output, outputReassocIndices);

  // Convert the input to a (BKN) tensor.
  SmallVector<int64_t, 4> colTensorShape = {n, ic * fh * fw, oh * ow};
  Value colTensor = rewriter.create<tensor::EmptyOp>(
      loc, colTensorShape, inputType.getElementType());

  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel);

  SmallVector<AffineMap, 4> img2colIndexingMaps = {
      AffineMap::getMultiDimIdentityMap(nloops, context)};

  // The im2col generic has no inputs; the body gathers input elements via
  // tensor.extract using indices recovered from the loop induction variables.
  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
      loc, colTensor.getType(),
      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        // Get the iterators named based on the matmul (batch, m, k).
        Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
        Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
        Value nIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);

        // Recover the original iteration indices from the problem/input sizes.
        SmallVector<Value> kIndices = unrollIndex(
            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw});
        auto icIndex = kIndices[0];
        auto fhIndex = kIndices[1];
        auto fwIndex = kIndices[2];

        SmallVector<Value> nIndices = unrollIndex(
            nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow});
        auto ohIndex = nIndices[0];
        auto owIndex = nIndices[1];

        // Extract the input element corresponding to the expanded indices.
        Value hIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
                              convOp.getStrides().getValues<int64_t>()[0]);
        Value wIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
                              convOp.getStrides().getValues<int64_t>()[1]);

        // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw]
        SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex};
        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
            loc, input, extractionIndices);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
      });

  // Because the filter does not share the same batch dimension,
  // the batch dimension is only used in indexing the input and output. Thus
  // we cannot use existing linalg named ops like linalg.batch_matmul.
  // i.e. M x K * (B x) K x N = (B x) M x N
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {bDim, kDim, nDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};
  // Multiply-accumulate in the output element type (signed casts; see
  // createMul).
  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, reshapedOutputType,
      /*inputs=*/ValueRange{reshapedFilter, img2ColTensor.getResult(0)},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        Value mul =
            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
        Value add = createAdd(loc, mul, args[2], nestedBuilder);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  // Restore the original NCHW output shape.
  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
      loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}
|
|
|
|
// Rewrites linalg.conv_2d_nhwc_fhwc via im2col. The FHWC filter is collapsed
// to oc x (fh*fw*ic) without a transpose, so the contraction is not a plain
// batched matmul: both operands index K innermost and the generic computes
// "row-wise" dot products (rhs map is (n, k) rather than (k, n)).
// Requires static input/filter shapes and unit dilations. Returns the
// (im2col op, final reshape op) pair on success.
FailureOr<std::pair<Operation *, Operation *>>
rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) {
  auto inputType = cast<ShapedType>(convOp.getInputs()[0].getType());
  auto filterType = cast<ShapedType>(convOp.getInputs()[1].getType());
  auto outputType = cast<ShapedType>(convOp.getOutputs()[0].getType());

  if (!filterType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        convOp, "expected a static shape for the filter");

  if (!inputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected a static shape for the input");

  // TODO: Support dilation.
  if (!hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected all ones for dilations");

  MLIRContext *context = rewriter.getContext();
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  ArrayRef<int64_t> filterShape = filterType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();

  // NHWC output dims and FHWC filter dims.
  int64_t n = outputShape[0];
  int64_t oh = outputShape[1];
  int64_t ow = outputShape[2];
  int64_t oc = outputShape[3];
  int64_t fh = filterShape[1];
  int64_t fw = filterShape[2];
  int64_t ic = filterShape[3];

  Location loc = convOp.getLoc();

  // Reshape output and filter to the LHS and result of a "row-wise" matrix
  // multiplication.
  SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}};
  auto reshapedFilterType =
      RankedTensorType::get({oc, fh * fw * ic}, filterType.getElementType());
  Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedFilterType, filter, filterReassocIndices);

  // Collapse (oh, ow) into the M dimension.
  SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}};
  RankedTensorType reshapedOutputType =
      RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType());
  Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>(
      loc, reshapedOutputType, output, outputReassocIndices);

  // Destination tensor for the im2col result: (B, M, K).
  SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic};
  Value colTensor = rewriter.create<tensor::EmptyOp>(
      loc, colTensorShape, inputType.getElementType());

  // Convert the input to a (BMK) column tensor.
  auto nloops = colTensorShape.size();

  auto parallel = utils::IteratorType::parallel;
  auto reduction = utils::IteratorType::reduction;
  SmallVector<utils::IteratorType> img2colIterators(nloops, parallel);

  SmallVector<AffineMap> img2colIndexingMaps = {
      AffineMap::getMultiDimIdentityMap(nloops, context)};

  // The im2col generic has no inputs; the body gathers input elements via
  // tensor.extract using indices recovered from the loop induction variables.
  auto img2ColTensor = rewriter.create<linalg::GenericOp>(
      loc, colTensor.getType(),
      /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps,
      img2colIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        // Get the iterators named based on the matmul (batch, m, k).
        Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0);
        Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1);
        Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2);

        // Recover the original iteration indices from the problem/input sizes.
        SmallVector<Value> mIndices = unrollIndex(
            nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow});
        auto ohIndex = mIndices[0];
        auto owIndex = mIndices[1];

        SmallVector<Value> kIndices = unrollIndex(
            nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic});
        auto fhIndex = kIndices[0];
        auto fwIndex = kIndices[1];
        auto icIndex = kIndices[2];

        // Extract the input element corresponding to the expanded indices.
        Value hIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex,
                              convOp.getStrides().getValues<int64_t>()[0]);
        Value wIndex =
            getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex,
                              convOp.getStrides().getValues<int64_t>()[1]);

        // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic]
        SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex};
        Value inputVal = nestedBuilder.create<tensor::ExtractOp>(
            loc, input, extractionIndices);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal);
      });

  // Because we didn't transpose the filters we don't actually have a batched
  // matrix multiply. Instead, we have an operation consisting of "row-wise" dot
  // products.
  AffineExpr bDim, mDim, nDim, kDim;
  bindDims(context, bDim, mDim, nDim, kDim);
  auto lhsMap = AffineMap::get(4, 0, {bDim, mDim, kDim}, context);
  auto rhsMap = AffineMap::get(4, 0, {nDim, kDim}, context);
  auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context);
  SmallVector<utils::IteratorType> genericIterators = {parallel, parallel,
                                                       parallel, reduction};

  // Multiply-accumulate in the output element type (signed casts; see
  // createMul).
  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, reshapedOutputType,
      /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter},
      /*outputs=*/ValueRange{reshapedOutput},
      ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators,
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
        Value mul =
            createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder);
        Value add = createAdd(loc, mul, args[2], nestedBuilder);
        nestedBuilder.create<linalg::YieldOp>(nestedLoc, add);
      });
  Value result = genericOp.getResults().front();

  // Restore the original NHWC output shape.
  auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>(
      loc, outputType, result, outputReassocIndices);

  rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult});

  return std::make_pair(img2ColTensor.getOperation(),
                        reshapedResult.getOperation());
}
|
|
|
|
namespace {
|
|
|
|
class ConvertConv2DNhwcHwcf final
|
|
: public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
|
|
public:
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
|
|
PatternRewriter &rewriter) const override {
|
|
if (failed(rewriteInIm2Col(rewriter, convOp)))
|
|
return failure();
|
|
return success();
|
|
}
|
|
};
|
|
|
|
class ConvertDepthwiseConv2DNhwcHwc final
|
|
: public OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp> {
|
|
public:
|
|
using OpRewritePattern<linalg::DepthwiseConv2DNhwcHwcOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(linalg::DepthwiseConv2DNhwcHwcOp convOp,
|
|
PatternRewriter &rewriter) const override {
|
|
if (failed(rewriteInIm2Col(rewriter, convOp)))
|
|
return failure();
|
|
return success();
|
|
}
|
|
};
|
|
|
|
class ConvertConv2DNchwFchw final
|
|
: public OpRewritePattern<linalg::Conv2DNchwFchwOp> {
|
|
public:
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
|
|
PatternRewriter &rewriter) const override {
|
|
if (failed(rewriteInIm2Col(rewriter, convOp)))
|
|
return failure();
|
|
return success();
|
|
}
|
|
};
|
|
|
|
class ConvertConv2DNhwcFhwc final
|
|
: public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
|
|
public:
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
|
|
PatternRewriter &rewriter) const override {
|
|
if (failed(rewriteInIm2Col(rewriter, convOp)))
|
|
return failure();
|
|
return success();
|
|
}
|
|
};
|
|
} // end anonymous namespace
|
|
|
|
// Registers one im2col conversion pattern for each supported convolution
// named op.
void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns) {
  patterns.insert<ConvertConv2DNhwcHwcf, ConvertDepthwiseConv2DNhwcHwc,
                  ConvertConv2DNchwFchw, ConvertConv2DNhwcFhwc>(
      patterns.getContext());
}
|
|
} // end namespace linalg
|
|
} // end namespace mlir
|