// Source path: bolt/deps/llvm-18.1.8/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
// Snapshot: 2025-02-14 19:21:04 +01:00
// Original listing: 377 lines, 12 KiB, C++
//===- ComplexOps.cpp - MLIR Complex Operations ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace mlir::complex;
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
/// A complex constant folds to its own `[re, im]` array attribute.
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
  ArrayAttr cst = getValue();
  return cst;
}
/// Suggests `cst` as the SSA name for the result of a complex constant.
void ConstantOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  Value result = getResult();
  setNameFn(result, "cst");
}
/// Returns true if `value` can materialize a complex constant of `type`.
/// This requires `type` to be a ComplexType and `value` to be a two-element
/// ArrayAttr. Both entries must be FloatAttr or both must be IntegerAttr,
/// each typed with the complex element type. Anything else returns false.
bool ConstantOp::isBuildableWith(Attribute value, Type type) {
  auto arrAttr = llvm::dyn_cast<ArrayAttr>(value);
  if (!arrAttr)
    return false;

  auto complexTy = llvm::dyn_cast<ComplexType>(type);
  if (!complexTy || arrAttr.size() != 2)
    return false;

  Type complexEltTy = complexTy.getElementType();
  Attribute reAttr = arrAttr[0];
  Attribute imAttr = arrAttr[1];

  // The real part's attribute kind decides which kind the imaginary part
  // must also have.
  if (auto re = llvm::dyn_cast<FloatAttr>(reAttr)) {
    auto im = llvm::dyn_cast<FloatAttr>(imAttr);
    return im && re.getType() == complexEltTy && im.getType() == complexEltTy;
  }
  if (auto re = llvm::dyn_cast<IntegerAttr>(reAttr)) {
    auto im = llvm::dyn_cast<IntegerAttr>(imAttr);
    return im && re.getType() == complexEltTy && im.getType() == complexEltTy;
  }
  return false;
}
/// Verifies that the `value` attribute is a well-formed complex constant.
/// It must have exactly two entries, each a FloatAttr or IntegerAttr, and
/// both must carry the element type of the op's complex result type.
LogicalResult ConstantOp::verify() {
  ArrayAttr arrayAttr = getValue();
  if (arrayAttr.size() != 2) {
    return emitOpError(
        "requires 'value' to be a complex constant, represented as array of "
        "two values");
  }

  Type complexEltTy = getType().getElementType();
  if (!isa<FloatAttr, IntegerAttr>(arrayAttr[0]) ||
      !isa<FloatAttr, IntegerAttr>(arrayAttr[1]))
    return emitOpError(
        "requires attribute's elements to be float or integer attributes");

  // Both entries were just checked to be FloatAttr/IntegerAttr, which are
  // typed attributes. An unchecked `cast` states that invariant and asserts
  // in debug builds if it is ever violated. The previous code used
  // `dyn_cast` and then dereferenced the result without a null check.
  auto re = llvm::cast<TypedAttr>(arrayAttr[0]);
  auto im = llvm::cast<TypedAttr>(arrayAttr[1]);
  if (complexEltTy != re.getType() || complexEltTy != im.getType()) {
    return emitOpError()
           << "requires attribute's element types (" << re.getType() << ", "
           << im.getType()
           << ") to match the element type of the op's return type ("
           << complexEltTy << ")";
  }
  return success();
}
//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//
/// A bitcast whose operand already has the result type is an identity and
/// folds to its operand.
OpFoldResult BitcastOp::fold(FoldAdaptor bitcast) {
  Value input = getOperand();
  if (input.getType() != getType())
    return {};
  return input;
}
/// Verifies a complex.bitcast.
///
/// Identity casts (operand type == result type) are accepted because they
/// fold away. Otherwise, each side must be an int, float or complex type, and
/// exactly one side must be complex. The complex side's width, twice its
/// element bit width, must equal the other side's bit width.
LogicalResult BitcastOp::verify() {
  auto operandType = getOperand().getType();
  auto resultType = getType();

  // We allow this to be legal as it can be folded away.
  if (operandType == resultType)
    return success();

  if (!operandType.isIntOrFloat() && !isa<ComplexType>(operandType)) {
    return emitOpError("operand must be int/float/complex");
  }

  if (!resultType.isIntOrFloat() && !isa<ComplexType>(resultType)) {
    return emitOpError("result must be int/float/complex");
  }

  if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
    return emitOpError(
        "requires that either input or output has a complex type");
  }

  // Normalize so that `operandType` names the complex side of the cast.
  if (isa<ComplexType>(resultType))
    std::swap(operandType, resultType);

  // After the swap `operandType` is known to be complex. `cast` states that
  // and avoids dereferencing an unchecked `dyn_cast`. Bit widths are kept
  // `unsigned`, which is what getIntOrFloatBitWidth() returns.
  unsigned elementBitwidth =
      cast<ComplexType>(operandType).getElementType().getIntOrFloatBitWidth();
  unsigned operandBitwidth = elementBitwidth * 2;
  unsigned resultBitwidth = resultType.getIntOrFloatBitWidth();

  if (operandBitwidth != resultBitwidth) {
    return emitOpError("casting bitwidths do not match");
  }

  return success();
}
namespace {
/// Canonicalizes `complex.bitcast(bitcast(x))` into a single bitcast of `x`.
/// The inner bitcast may be a complex.bitcast or an arith.bitcast.
///
/// The merged op is a complex.bitcast when either end of the chain is
/// complex-typed, and an arith.bitcast otherwise.
struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
  using OpRewritePattern<BitcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BitcastOp op,
                                PatternRewriter &rewriter) const override {
    // Look through the inner cast to the value it consumes.
    Value source;
    if (auto complexCast = op.getOperand().getDefiningOp<BitcastOp>())
      source = complexCast.getOperand();
    else if (auto arithCast =
                 op.getOperand().getDefiningOp<arith::BitcastOp>())
      source = arithCast.getOperand();
    else
      return failure();

    // complex.bitcast requires that its input or output is complex. That can
    // be false for the merged cast even when the outer op is a
    // complex.bitcast, because BitcastOp::verify also accepts identity casts
    // between equal non-complex types. An example is
    // `arith.bitcast : i64 -> f64` followed by `complex.bitcast : f64 -> f64`.
    // In that case, emit a plain arith.bitcast instead.
    if (isa<ComplexType>(op.getType()) || isa<ComplexType>(source.getType()))
      rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(), source);
    else
      rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(), source);
    return success();
  }
};
} // namespace
namespace {
/// Canonicalizes `arith.bitcast(complex.bitcast(x))` into a single bitcast of
/// `x`, skipping the intermediate value.
struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
  using OpRewritePattern<arith::BitcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::BitcastOp op,
                                PatternRewriter &rewriter) const override {
    auto defining = op.getOperand().getDefiningOp<complex::BitcastOp>();
    if (!defining)
      return failure();

    Value source = defining.getOperand();
    // The merged cast can only be a complex.bitcast if one of its ends is
    // complex, so check whether `source` is. The inner complex.bitcast may be
    // an identity cast between equal non-complex types, which
    // BitcastOp::verify accepts. In that case a complex.bitcast would be
    // invalid, so emit an arith.bitcast instead.
    if (isa<ComplexType>(source.getType()) || isa<ComplexType>(op.getType()))
      rewriter.replaceOpWithNewOp<complex::BitcastOp>(op, op.getType(), source);
    else
      rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(), source);
    return success();
  }
};
} // namespace
/// Registers the bitcast-chain merging canonicalizations. One is rooted at
/// complex.bitcast; the other is rooted at an arith.bitcast whose operand
/// comes from a complex.bitcast.
void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<MergeComplexBitcast>(context);
  results.add<MergeArithBitcast>(context);
}
//===----------------------------------------------------------------------===//
// CreateOp
//===----------------------------------------------------------------------===//
/// Folds `complex.create(complex.re(z), complex.im(z))` back to `z`. The
/// fold applies only when both parts are extracted from the same value.
OpFoldResult CreateOp::fold(FoldAdaptor adaptor) {
  auto reOp = getOperand(0).getDefiningOp<ReOp>();
  auto imOp = getOperand(1).getDefiningOp<ImOp>();
  if (!reOp || !imOp)
    return {};

  Value source = reOp.getOperand();
  if (source != imOp.getOperand())
    return {};
  return source;
}
//===----------------------------------------------------------------------===//
// ImOp
//===----------------------------------------------------------------------===//
/// Folds `complex.im` in two cases:
///  * a constant operand given as a two-element `[re, im]` array folds to
///    its second entry;
///  * `complex.im(complex.create(re, im))` folds to `im`.
OpFoldResult ImOp::fold(FoldAdaptor adaptor) {
  if (auto parts = llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex())) {
    if (parts.size() == 2)
      return parts[1];
  }

  auto createOp = getOperand().getDefiningOp<CreateOp>();
  if (!createOp)
    return {};
  return createOp.getOperand(1);
}
namespace {
/// Rewrites `OpKind(complex.neg(complex.create(a, b)))` as `arith.negf`
/// applied to create operand number `ComponentIndex`. For example, with
/// index 0 `re(neg(create(a, b)))` becomes `-a`; with index 1
/// `im(neg(create(a, b)))` becomes `-b`.
template <typename OpKind, int ComponentIndex>
struct FoldComponentNeg final : OpRewritePattern<OpKind> {
  using OpRewritePattern<OpKind>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpKind op,
                                PatternRewriter &rewriter) const override {
    auto negOp = op.getOperand().template getDefiningOp<NegOp>();
    if (!negOp)
      return failure();

    auto createOp = negOp.getComplex().template getDefiningOp<CreateOp>();
    if (!createOp)
      return failure();

    Value component = createOp.getOperand(ComponentIndex);
    Type elementType = createOp.getType().getElementType();
    // arith.negf needs a float operand. This assert records the expectation
    // that the complex element type here is always a FloatType.
    assert(isa<FloatType>(elementType));
    rewriter.replaceOpWithNewOp<arith::NegFOp>(op, elementType, component);
    return success();
  }
};
} // namespace
/// Registers `complex.im(complex.neg(complex.create(a, b))) -> arith.negf(b)`.
void ImOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                       MLIRContext *context) {
  using NegatedImaginaryPart = FoldComponentNeg<ImOp, /*ComponentIndex=*/1>;
  results.add<NegatedImaginaryPart>(context);
}
//===----------------------------------------------------------------------===//
// ReOp
//===----------------------------------------------------------------------===//
/// Folds `complex.re` in two cases:
///  * a constant operand given as a two-element `[re, im]` array folds to
///    its first entry;
///  * `complex.re(complex.create(re, im))` folds to `re`.
OpFoldResult ReOp::fold(FoldAdaptor adaptor) {
  if (auto parts = llvm::dyn_cast_if_present<ArrayAttr>(adaptor.getComplex())) {
    if (parts.size() == 2)
      return parts[0];
  }

  auto createOp = getOperand().getDefiningOp<CreateOp>();
  if (!createOp)
    return {};
  return createOp.getOperand(0);
}
/// Registers `complex.re(complex.neg(complex.create(a, b))) -> arith.negf(a)`.
void ReOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                       MLIRContext *context) {
  using NegatedRealPart = FoldComponentNeg<ReOp, /*ComponentIndex=*/0>;
  results.add<NegatedRealPart>(context);
}
//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//
/// Folds `complex.add` when its result is trivially one of the inputs:
///   add(sub(a, b), b)         -> a
///   add(b, sub(a, b))         -> a
///   add(a, constant<0.0, 0.0>) -> a
/// NOTE: APFloat::isZero() matches both +0.0 and -0.0, so either zero sign
/// in the constant triggers the last fold.
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
  Value lhs = getLhs();
  Value rhs = getRhs();

  // complex.add(complex.sub(a, b), b) -> a
  if (auto sub = lhs.getDefiningOp<SubOp>()) {
    if (sub.getRhs() == rhs)
      return sub.getLhs();
  }

  // complex.add(b, complex.sub(a, b)) -> a
  if (auto sub = rhs.getDefiningOp<SubOp>()) {
    if (sub.getRhs() == lhs)
      return sub.getLhs();
  }

  // complex.add(a, complex.constant<0.0, 0.0>) -> a
  auto constantOp = rhs.getDefiningOp<ConstantOp>();
  if (!constantOp)
    return {};
  ArrayAttr parts = constantOp.getValue();
  bool addsZero = llvm::cast<FloatAttr>(parts[0]).getValue().isZero() &&
                  llvm::cast<FloatAttr>(parts[1]).getValue().isZero();
  if (!addsZero)
    return {};
  return lhs;
}
//===----------------------------------------------------------------------===//
// SubOp
//===----------------------------------------------------------------------===//
/// Folds `complex.sub` when its result is trivially one of the inputs:
///   sub(add(a, b), b)         -> a
///   sub(a, constant<0.0, 0.0>) -> a
/// NOTE: APFloat::isZero() matches both +0.0 and -0.0, so either zero sign
/// in the constant triggers the second fold.
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
  Value lhs = getLhs();
  Value rhs = getRhs();

  // complex.sub(complex.add(a, b), b) -> a
  if (auto add = lhs.getDefiningOp<AddOp>()) {
    if (add.getRhs() == rhs)
      return add.getLhs();
  }

  // complex.sub(a, complex.constant<0.0, 0.0>) -> a
  auto constantOp = rhs.getDefiningOp<ConstantOp>();
  if (!constantOp)
    return {};
  ArrayAttr parts = constantOp.getValue();
  bool subtractsZero = llvm::cast<FloatAttr>(parts[0]).getValue().isZero() &&
                       llvm::cast<FloatAttr>(parts[1]).getValue().isZero();
  if (!subtractsZero)
    return {};
  return lhs;
}
//===----------------------------------------------------------------------===//
// NegOp
//===----------------------------------------------------------------------===//
/// Cancels a double negation: complex.neg(complex.neg(a)) -> a.
OpFoldResult NegOp::fold(FoldAdaptor adaptor) {
  auto inner = getOperand().getDefiningOp<NegOp>();
  if (!inner)
    return {};
  return inner.getOperand();
}
//===----------------------------------------------------------------------===//
// LogOp
//===----------------------------------------------------------------------===//
/// Cancels a log of an exp: complex.log(complex.exp(a)) -> a.
OpFoldResult LogOp::fold(FoldAdaptor adaptor) {
  auto inner = getOperand().getDefiningOp<ExpOp>();
  if (!inner)
    return {};
  return inner.getOperand();
}
//===----------------------------------------------------------------------===//
// ExpOp
//===----------------------------------------------------------------------===//
/// Cancels an exp of a log: complex.exp(complex.log(a)) -> a.
OpFoldResult ExpOp::fold(FoldAdaptor adaptor) {
  auto inner = getOperand().getDefiningOp<LogOp>();
  if (!inner)
    return {};
  return inner.getOperand();
}
//===----------------------------------------------------------------------===//
// ConjOp
//===----------------------------------------------------------------------===//
/// Cancels a double conjugation: complex.conj(complex.conj(a)) -> a.
OpFoldResult ConjOp::fold(FoldAdaptor adaptor) {
  auto inner = getOperand().getDefiningOp<ConjOp>();
  if (!inner)
    return {};
  return inner.getOperand();
}
//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//
/// Folds multiplication by the complex constant one on the right-hand side:
/// complex.mul(a, complex.constant<1.0, 0.0>) -> a.
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto constant = getRhs().getDefiningOp<ConstantOp>();
  if (!constant)
    return {};

  ArrayAttr parts = constant.getValue();
  APFloat re = cast<FloatAttr>(parts[0]).getValue();
  APFloat im = cast<FloatAttr>(parts[1]).getValue();

  // Only a purely real constant equal to 1 (in the constant's own float
  // semantics) is the multiplicative identity.
  bool isOne = im.isZero() && re == APFloat(re.getSemantics(), 1);
  if (!isOne)
    return {};
  return getLhs();
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"