bolt/deps/llvm-18.1.8/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
2025-02-14 19:21:04 +01:00

339 lines
14 KiB
C++

//===- CastOps.cpp - MLIR SPIR-V Cast Ops --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the cast and conversion operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "SPIRVOpUtils.h"
#include "SPIRVParsingUtils.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir::spirv::AttrNames;
namespace mlir::spirv {
// Shared verifier for cast ops: checks how the bit width of operand 0's
// element type relates to the bit width of result 0's element type.
//
//   requireSameBitWidth - when true the two widths must be equal; when false
//                         they must differ.
//   skipBitWidthCheck   - when true no type is inspected and the op is
//                         accepted unconditionally.
//
// Emits an op error and returns failure when the requested relation does not
// hold, or when operand and result are not the same kind of composite type.
static LogicalResult verifyCastOp(Operation *op,
                                  bool requireSameBitWidth = true,
                                  bool skipBitWidthCheck = false) {
  // Some cast ops place no restriction on the widths at all.
  if (skipBitWidthCheck)
    return success();

  Type srcTy = op->getOperand(0).getType();
  Type dstTy = op->getResult(0).getType();

  // ODS already guarantees matching shapes. Here we only make sure both sides
  // are the same composite kind and peel off their element types; scalars map
  // to themselves. A default-constructed (null) pair signals a kind mismatch.
  using TypePair = std::pair<Type, Type>;
  TypePair elemTypes =
      TypeSwitch<Type, TypePair>(srcTy)
          .Case<VectorType, spirv::CooperativeMatrixType,
                spirv::JointMatrixINTELType>(
              [dstTy](auto concreteSrcTy) -> TypePair {
                auto concreteDstTy =
                    dyn_cast<decltype(concreteSrcTy)>(dstTy);
                if (!concreteDstTy)
                  return {};
                return {concreteSrcTy.getElementType(),
                        concreteDstTy.getElementType()};
              })
          .Default([dstTy](Type scalarSrcTy) -> TypePair {
            return {scalarSrcTy, dstTy};
          });

  Type srcElemTy = elemTypes.first;
  Type dstElemTy = elemTypes.second;
  if (!srcElemTy || !dstElemTy)
    return op->emitOpError("incompatible operand and result types");

  unsigned srcWidth = srcElemTy.getIntOrFloatBitWidth();
  unsigned dstWidth = dstElemTy.getIntOrFloatBitWidth();
  bool widthsMatch = srcWidth == dstWidth;

  // The op is valid exactly when the observed relation equals the requested
  // one.
  if (widthsMatch == requireSameBitWidth)
    return success();

  if (requireSameBitWidth) {
    return op->emitOpError(
               "expected the same bit widths for operand type and result "
               "type, but provided ")
           << srcElemTy << " and " << dstElemTy;
  }
  return op->emitOpError(
             "expected the different bit widths for operand type and result "
             "type, but provided ")
         << srcElemTy << " and " << dstElemTy;
}
//===----------------------------------------------------------------------===//
// spirv.BitcastOp
//===----------------------------------------------------------------------===//
// Verifies spirv.Bitcast: operand and result types must differ, must either
// both be pointers or both be non-pointers, and must have equal bit widths as
// reported by getBitWidth().
LogicalResult BitcastOp::verify() {
  // TODO: The SPIR-V spec validation rules are different for different
  // versions.
  Type srcTy = getOperand().getType();
  Type dstTy = getResult().getType();

  if (srcTy == dstTy)
    return emitError("result type must be different from operand type");

  // Mixing pointer and non-pointer types is not handled in either direction.
  bool srcIsPtr = llvm::isa<spirv::PointerType>(srcTy);
  bool dstIsPtr = llvm::isa<spirv::PointerType>(dstTy);
  if (srcIsPtr && !dstIsPtr)
    return emitError(
        "unhandled bit cast conversion from pointer type to non-pointer type");
  if (!srcIsPtr && dstIsPtr)
    return emitError(
        "unhandled bit cast conversion from non-pointer type to pointer type");

  auto srcWidth = getBitWidth(srcTy);
  auto dstWidth = getBitWidth(dstTy);
  if (srcWidth == dstWidth)
    return success();

  return emitOpError("mismatch in result type bitwidth ")
         << dstWidth << " and operand type bitwidth " << srcWidth;
}
//===----------------------------------------------------------------------===//
// spirv.ConvertPtrToUOp
//===----------------------------------------------------------------------===//
// Verifies spirv.ConvertPtrToU.
//
// The result must be a signless-integer scalar. When the op is nested in a
// spirv.module, the module's addressing model additionally constrains the
// operand: Logical addressing is always rejected, and under
// PhysicalStorageBuffer64 the pointer must use the PhysicalStorageBuffer
// storage class. Ops without an enclosing spirv.module skip that check.
LogicalResult ConvertPtrToUOp::verify() {
  auto operandType = llvm::cast<spirv::PointerType>(getPointer().getType());
  // dyn_cast (not cast) so that a non-scalar result yields a null type and is
  // reported through the diagnostic below instead of tripping cast's
  // assertion.
  auto resultType = llvm::dyn_cast<spirv::ScalarType>(getResult().getType());
  if (!resultType || !resultType.isSignlessInteger())
    return emitError("result must be a scalar type of unsigned integer");

  auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
  if (!spirvModule)
    return success();

  auto addressingModel = spirvModule.getAddressingModel();
  if ((addressingModel == spirv::AddressingModel::Logical) ||
      (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
       operandType.getStorageClass() !=
           spirv::StorageClass::PhysicalStorageBuffer))
    return emitError("operand must be a physical pointer");
  return success();
}
//===----------------------------------------------------------------------===//
// spirv.ConvertUToPtrOp
//===----------------------------------------------------------------------===//
// Verifies spirv.ConvertUToPtr.
//
// The operand must be a signless-integer scalar. When the op is nested in a
// spirv.module, the module's addressing model additionally constrains the
// result: Logical addressing is always rejected, and under
// PhysicalStorageBuffer64 the result pointer must use the
// PhysicalStorageBuffer storage class. Ops without an enclosing spirv.module
// skip that check.
LogicalResult ConvertUToPtrOp::verify() {
  // dyn_cast (not cast) so that a non-scalar operand yields a null type and
  // is reported through the diagnostic below instead of tripping cast's
  // assertion.
  auto operandType = llvm::dyn_cast<spirv::ScalarType>(getOperand().getType());
  auto resultType = llvm::cast<spirv::PointerType>(getResult().getType());
  // This check is about the operand, so the diagnostic names the operand.
  if (!operandType || !operandType.isSignlessInteger())
    return emitError("operand must be a scalar type of unsigned integer");

  auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>();
  if (!spirvModule)
    return success();

  auto addressingModel = spirvModule.getAddressingModel();
  if ((addressingModel == spirv::AddressingModel::Logical) ||
      (addressingModel == spirv::AddressingModel::PhysicalStorageBuffer64 &&
       resultType.getStorageClass() !=
           spirv::StorageClass::PhysicalStorageBuffer))
    return emitError("result must be a physical pointer");
  return success();
}
//===----------------------------------------------------------------------===//
// spirv.PtrCastToGenericOp
//===----------------------------------------------------------------------===//
// Verifies spirv.PtrCastToGeneric: the source pointer must live in the
// Workgroup, CrossWorkgroup or Function storage class, the result pointer must
// be in the Generic storage class, and both must share the same pointee type.
LogicalResult PtrCastToGenericOp::verify() {
  auto srcPtrTy = llvm::cast<spirv::PointerType>(getPointer().getType());
  auto dstPtrTy = llvm::cast<spirv::PointerType>(getResult().getType());

  // Only these three source storage classes are accepted.
  switch (srcPtrTy.getStorageClass()) {
  case spirv::StorageClass::Workgroup:
  case spirv::StorageClass::CrossWorkgroup:
  case spirv::StorageClass::Function:
    break;
  default:
    return emitError("pointer must point to the Workgroup, CrossWorkgroup"
                     ", or Function Storage Class");
  }

  if (dstPtrTy.getStorageClass() != spirv::StorageClass::Generic)
    return emitError("result type must be of storage class Generic");

  Type srcPointee = srcPtrTy.getPointeeType();
  Type dstPointee = dstPtrTy.getPointeeType();
  if (srcPointee == dstPointee)
    return success();

  return emitOpError("pointer operand's pointee type must have the same "
                     "as the op result type, but found ")
         << srcPointee << " vs " << dstPointee;
}
//===----------------------------------------------------------------------===//
// spirv.GenericCastToPtrOp
//===----------------------------------------------------------------------===//
// Verifies spirv.GenericCastToPtr: the source pointer must be in the Generic
// storage class, the result pointer must be in Workgroup, CrossWorkgroup or
// Function, and both must share the same pointee type.
LogicalResult GenericCastToPtrOp::verify() {
  auto srcPtrTy = llvm::cast<spirv::PointerType>(getPointer().getType());
  auto dstPtrTy = llvm::cast<spirv::PointerType>(getResult().getType());

  if (srcPtrTy.getStorageClass() != spirv::StorageClass::Generic)
    return emitError("pointer type must be of storage class Generic");

  // Storage classes a Generic pointer may be cast back to.
  auto isAllowedTarget = [](spirv::StorageClass sc) {
    return sc == spirv::StorageClass::Workgroup ||
           sc == spirv::StorageClass::CrossWorkgroup ||
           sc == spirv::StorageClass::Function;
  };
  if (!isAllowedTarget(dstPtrTy.getStorageClass()))
    return emitError("result must point to the Workgroup, CrossWorkgroup, "
                     "or Function Storage Class");

  Type srcPointee = srcPtrTy.getPointeeType();
  Type dstPointee = dstPtrTy.getPointeeType();
  if (srcPointee == dstPointee)
    return success();

  return emitOpError("pointer operand's pointee type must have the same "
                     "as the op result type, but found ")
         << srcPointee << " vs " << dstPointee;
}
//===----------------------------------------------------------------------===//
// spirv.GenericCastToPtrExplicitOp
//===----------------------------------------------------------------------===//
// Verifies spirv.GenericCastToPtrExplicit: the source pointer must be in the
// Generic storage class, the result pointer must be in Workgroup,
// CrossWorkgroup or Function, and both must share the same pointee type.
LogicalResult GenericCastToPtrExplicitOp::verify() {
  auto srcPtrTy = llvm::cast<spirv::PointerType>(getPointer().getType());
  auto dstPtrTy = llvm::cast<spirv::PointerType>(getResult().getType());

  if (srcPtrTy.getStorageClass() != spirv::StorageClass::Generic)
    return emitError("pointer type must be of storage class Generic");

  // Storage classes a Generic pointer may be cast back to.
  auto isAllowedTarget = [](spirv::StorageClass sc) {
    return sc == spirv::StorageClass::Workgroup ||
           sc == spirv::StorageClass::CrossWorkgroup ||
           sc == spirv::StorageClass::Function;
  };
  if (!isAllowedTarget(dstPtrTy.getStorageClass()))
    return emitError("result must point to the Workgroup, CrossWorkgroup, "
                     "or Function Storage Class");

  Type srcPointee = srcPtrTy.getPointeeType();
  Type dstPointee = dstPtrTy.getPointeeType();
  if (srcPointee == dstPointee)
    return success();

  return emitOpError("pointer operand's pointee type must have the same "
                     "as the op result type, but found ")
         << srcPointee << " vs " << dstPointee;
}
//===----------------------------------------------------------------------===//
// spirv.ConvertFToSOp
//===----------------------------------------------------------------------===//
// Verifies spirv.ConvertFToS. The bit-width check is skipped, so
// verifyCastOp accepts the op without inspecting its types.
LogicalResult ConvertFToSOp::verify() {
  const bool requireSameBitWidth = false;
  const bool skipBitWidthCheck = true;
  return verifyCastOp(*this, requireSameBitWidth, skipBitWidthCheck);
}
//===----------------------------------------------------------------------===//
// spirv.ConvertFToUOp
//===----------------------------------------------------------------------===//
// Verifies spirv.ConvertFToU. The bit-width check is skipped, so
// verifyCastOp accepts the op without inspecting its types.
LogicalResult ConvertFToUOp::verify() {
  const bool requireSameBitWidth = false;
  const bool skipBitWidthCheck = true;
  return verifyCastOp(*this, requireSameBitWidth, skipBitWidthCheck);
}
//===----------------------------------------------------------------------===//
// spirv.ConvertSToFOp
//===----------------------------------------------------------------------===//
// Verifies spirv.ConvertSToF. The bit-width check is skipped, so
// verifyCastOp accepts the op without inspecting its types.
LogicalResult ConvertSToFOp::verify() {
  const bool requireSameBitWidth = false;
  const bool skipBitWidthCheck = true;
  return verifyCastOp(*this, requireSameBitWidth, skipBitWidthCheck);
}
//===----------------------------------------------------------------------===//
// spirv.ConvertUToFOp
//===----------------------------------------------------------------------===//
// Verifies spirv.ConvertUToF. The bit-width check is skipped, so
// verifyCastOp accepts the op without inspecting its types.
LogicalResult ConvertUToFOp::verify() {
  const bool requireSameBitWidth = false;
  const bool skipBitWidthCheck = true;
  return verifyCastOp(*this, requireSameBitWidth, skipBitWidthCheck);
}
//===----------------------------------------------------------------------===//
// spirv.INTELConvertBF16ToFOp
//===----------------------------------------------------------------------===//
// Verifies spirv.INTEL.ConvertBF16ToF: operand and result must have the same
// number of components. Two scalars trivially match; two vectors must have
// the same element count; a scalar paired with a vector is rejected.
LogicalResult INTELConvertBF16ToFOp::verify() {
  auto operandType = getOperand().getType();
  auto resultType = getResult().getType();

  // dyn_cast on both sides: a vector operand paired with a scalar result must
  // produce a diagnostic rather than trip the assertion inside llvm::cast.
  auto operandVecTy = llvm::dyn_cast<VectorType>(operandType);
  auto resultVecTy = llvm::dyn_cast<VectorType>(resultType);

  // Exactly one side being a vector means the component counts differ.
  bool operandIsVector = static_cast<bool>(operandVecTy);
  bool resultIsVector = static_cast<bool>(resultVecTy);
  if (operandIsVector != resultIsVector)
    return emitOpError("operand and result must have same number of elements");

  if (operandIsVector &&
      operandVecTy.getNumElements() != resultVecTy.getNumElements())
    return emitOpError("operand and result must have same number of elements");

  return success();
}
//===----------------------------------------------------------------------===//
// spirv.INTELConvertFToBF16Op
//===----------------------------------------------------------------------===//
// Verifies spirv.INTEL.ConvertFToBF16: operand and result must have the same
// number of components. Two scalars trivially match; two vectors must have
// the same element count; a scalar paired with a vector is rejected.
LogicalResult INTELConvertFToBF16Op::verify() {
  auto operandType = getOperand().getType();
  auto resultType = getResult().getType();

  // dyn_cast on both sides: a vector operand paired with a scalar result must
  // produce a diagnostic rather than trip the assertion inside llvm::cast.
  auto operandVecTy = llvm::dyn_cast<VectorType>(operandType);
  auto resultVecTy = llvm::dyn_cast<VectorType>(resultType);

  // Exactly one side being a vector means the component counts differ.
  bool operandIsVector = static_cast<bool>(operandVecTy);
  bool resultIsVector = static_cast<bool>(resultVecTy);
  if (operandIsVector != resultIsVector)
    return emitOpError("operand and result must have same number of elements");

  if (operandIsVector &&
      operandVecTy.getNumElements() != resultVecTy.getNumElements())
    return emitOpError("operand and result must have same number of elements");

  return success();
}
//===----------------------------------------------------------------------===//
// spirv.FConvertOp
//===----------------------------------------------------------------------===//
// Verifies spirv.FConvert: operand and result element types must have
// different bit widths.
LogicalResult spirv::FConvertOp::verify() {
  const bool requireSameBitWidth = false;
  return verifyCastOp(*this, requireSameBitWidth);
}
//===----------------------------------------------------------------------===//
// spirv.SConvertOp
//===----------------------------------------------------------------------===//
// Verifies spirv.SConvert: operand and result element types must have
// different bit widths.
LogicalResult spirv::SConvertOp::verify() {
  const bool requireSameBitWidth = false;
  return verifyCastOp(*this, requireSameBitWidth);
}
//===----------------------------------------------------------------------===//
// spirv.UConvertOp
//===----------------------------------------------------------------------===//
// Verifies spirv.UConvert: operand and result element types must have
// different bit widths.
LogicalResult spirv::UConvertOp::verify() {
  const bool requireSameBitWidth = false;
  return verifyCastOp(*this, requireSameBitWidth);
}
} // namespace mlir::spirv