109 lines
3.8 KiB
C++
109 lines
3.8 KiB
C++
|
//===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===//
|
||
|
//
|
||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
//
|
||
|
// This file implements the AMX dialect and its operations.
|
||
|
//
|
||
|
//===----------------------------------------------------------------------===//
|
||
|
|
||
|
#include "mlir/Dialect/AMX/AMXDialect.h"
|
||
|
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
||
|
#include "mlir/IR/Builders.h"
|
||
|
#include "mlir/IR/OpImplementation.h"
|
||
|
#include "mlir/IR/TypeUtilities.h"
|
||
|
|
||
|
using namespace mlir;
|
||
|
|
||
|
#include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
|
||
|
|
||
|
/// Dialect initialization hook: registers the AMX operations.
///
/// The template argument list of `addOperations` is not spelled out here; it
/// is whatever "mlir/Dialect/AMX/AMX.cpp.inc" expands to while `GET_OP_LIST`
/// is defined.
void mlir::amx::AMXDialect::initialize() {
  this->addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMX/AMX.cpp.inc"
      >();
}
|
||
|
|
||
|
/// Verify that AMX supports the implied tile shape.
|
||
|
static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
|
||
|
const unsigned kMaxRows = 16;
|
||
|
const unsigned kBitsPerRow = 64 * 8;
|
||
|
unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
|
||
|
if (tp.getDimSize(0) > kMaxRows)
|
||
|
return op->emitOpError("bad row height: ") << tp.getDimSize(0);
|
||
|
if (col > kBitsPerRow || col & 0x1f)
|
||
|
return op->emitOpError("bad column width: ") << (col >> 3);
|
||
|
return success();
|
||
|
}
|
||
|
|
||
|
/// Verify that AMX supports the multiplication.
|
||
|
static LogicalResult verifyMultShape(Operation *op, VectorType atp,
|
||
|
VectorType btp, VectorType ctp,
|
||
|
unsigned scale) {
|
||
|
unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
|
||
|
unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
|
||
|
unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1);
|
||
|
if (cm != am || cn != bn || ak != bk)
|
||
|
return op->emitOpError("bad mult shape: ")
|
||
|
<< cm << " x " << cn << " x " << ak;
|
||
|
return success();
|
||
|
}
|
||
|
|
||
|
/// Verifier for the tile-zero op: the only requirement checked here is that
/// the op's vector type passes verifyTileSize.
LogicalResult amx::TileZeroOp::verify() {
  VectorType tileType = getVectorType();
  return verifyTileSize(*this, tileType);
}
|
||
|
|
||
|
/// Verifier for the tile-load op.
///
/// Requires exactly one index per memref dimension, then checks the op's
/// vector type with verifyTileSize.
LogicalResult amx::TileLoadOp::verify() {
  const unsigned memRefRank = getMemRefType().getRank();
  const bool indexCountMatches = getIndices().size() == memRefRank;
  if (!indexCountMatches)
    return emitOpError("requires ") << memRefRank << " indices";

  VectorType tileType = getVectorType();
  return verifyTileSize(*this, tileType);
}
|
||
|
|
||
|
/// Verifier for the tile-store op.
///
/// The number of indices must equal the rank of the memref type; once that
/// holds, the op's vector type is checked with verifyTileSize.
LogicalResult amx::TileStoreOp::verify() {
  const unsigned expectedIndices = getMemRefType().getRank();
  if (getIndices().size() == expectedIndices)
    return verifyTileSize(*this, getVectorType());

  return emitOpError("requires ") << expectedIndices << " indices";
}
|
||
|
|
||
|
/// Verifier for the floating-point tile multiplication op.
///
/// Checks, in order and stopping at the first failure: the tile size of the
/// lhs, rhs and result vector types, the multiplication shape with a scale of
/// 1, and finally that the element types are bf16 (lhs), bf16 (rhs) and f32
/// (result).
LogicalResult amx::TileMulFOp::verify() {
  VectorType lhsType = getLhsVectorType();
  VectorType rhsType = getRhsVectorType();
  VectorType resType = getVectorType();

  if (failed(verifyTileSize(*this, lhsType)))
    return failure();
  if (failed(verifyTileSize(*this, rhsType)))
    return failure();
  if (failed(verifyTileSize(*this, resType)))
    return failure();
  if (failed(verifyMultShape(*this, lhsType, rhsType, resType, /*scale=*/1)))
    return failure();

  const bool supportedTypes = lhsType.getElementType().isBF16() &&
                              rhsType.getElementType().isBF16() &&
                              resType.getElementType().isF32();
  if (!supportedTypes)
    return emitOpError("unsupported type combination");
  return success();
}
|
||
|
|
||
|
/// Verifier for the integer tile multiplication op.
///
/// Checks, in order and stopping at the first failure: the tile size of the
/// lhs, rhs and result vector types, the multiplication shape with a scale of
/// 2, and finally that the element types are 8-bit integers (lhs and rhs) and
/// a 32-bit integer (result).
LogicalResult amx::TileMulIOp::verify() {
  VectorType lhsType = getLhsVectorType();
  VectorType rhsType = getRhsVectorType();
  VectorType resType = getVectorType();

  if (failed(verifyTileSize(*this, lhsType)))
    return failure();
  if (failed(verifyTileSize(*this, rhsType)))
    return failure();
  if (failed(verifyTileSize(*this, resType)))
    return failure();
  if (failed(verifyMultShape(*this, lhsType, rhsType, resType, /*scale=*/2)))
    return failure();

  const bool supportedTypes = lhsType.getElementType().isInteger(8) &&
                              rhsType.getElementType().isInteger(8) &&
                              resType.getElementType().isInteger(32);
  if (!supportedTypes)
    return emitOpError("unsupported type combination");
  return success();
}
|
||
|
|
||
|
#define GET_OP_CLASSES
|
||
|
#include "mlir/Dialect/AMX/AMX.cpp.inc"
|