bolt/deps/llvm-18.1.8/mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRV.cpp

112 lines
4.1 KiB
C++
Raw Normal View History

2025-02-14 19:21:04 +01:00
//===- ComplexToSPIRV.cpp - Complex to SPIR-V Patterns --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert Complex dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRV.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "complex-to-spirv-pattern"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
namespace {
struct ConstantOpPattern final : OpConversionPattern<complex::ConstantOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(complex::ConstantOp constOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto spirvType =
getTypeConverter()->convertType<ShapedType>(constOp.getType());
if (!spirvType)
return rewriter.notifyMatchFailure(constOp,
"unable to convert result type");
rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
constOp, spirvType,
DenseElementsAttr::get(spirvType, constOp.getValue().getValue()));
return success();
}
};
struct CreateOpPattern final : OpConversionPattern<complex::CreateOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(complex::CreateOp createOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type spirvType = getTypeConverter()->convertType(createOp.getType());
if (!spirvType)
return rewriter.notifyMatchFailure(createOp,
"unable to convert result type");
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
createOp, spirvType, adaptor.getOperands());
return success();
}
};
struct ReOpPattern final : OpConversionPattern<complex::ReOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(complex::ReOp reOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type spirvType = getTypeConverter()->convertType(reOp.getType());
if (!spirvType)
return rewriter.notifyMatchFailure(reOp, "unable to convert result type");
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
reOp, adaptor.getComplex(), llvm::ArrayRef(0));
return success();
}
};
struct ImOpPattern final : OpConversionPattern<complex::ImOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(complex::ImOp imOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type spirvType = getTypeConverter()->convertType(imOp.getType());
if (!spirvType)
return rewriter.notifyMatchFailure(imOp, "unable to convert result type");
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
imOp, adaptor.getComplex(), llvm::ArrayRef(1));
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
void mlir::populateComplexToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<ConstantOpPattern, CreateOpPattern, ReOpPattern, ImOpPattern>(
typeConverter, context);
}