//===- ComplexToSPIRV.cpp - Complex to SPIR-V Patterns --------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements patterns to convert Complex dialect to SPIR-V dialect. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRV.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "complex-to-spirv-pattern" using namespace mlir; //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// namespace { struct ConstantOpPattern final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto spirvType = getTypeConverter()->convertType(constOp.getType()); if (!spirvType) return rewriter.notifyMatchFailure(constOp, "unable to convert result type"); rewriter.replaceOpWithNewOp( constOp, spirvType, DenseElementsAttr::get(spirvType, constOp.getValue().getValue())); return success(); } }; struct CreateOpPattern final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::CreateOp createOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type spirvType = getTypeConverter()->convertType(createOp.getType()); if (!spirvType) return rewriter.notifyMatchFailure(createOp, "unable to convert result type"); rewriter.replaceOpWithNewOp( createOp, spirvType, adaptor.getOperands()); return success(); } }; struct ReOpPattern final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::ReOp reOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type spirvType = getTypeConverter()->convertType(reOp.getType()); if (!spirvType) return rewriter.notifyMatchFailure(reOp, "unable to convert result type"); rewriter.replaceOpWithNewOp( reOp, adaptor.getComplex(), llvm::ArrayRef(0)); return success(); } }; struct ImOpPattern final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(complex::ImOp imOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type spirvType = getTypeConverter()->convertType(imOp.getType()); if (!spirvType) return rewriter.notifyMatchFailure(imOp, "unable to convert result type"); rewriter.replaceOpWithNewOp( imOp, adaptor.getComplex(), llvm::ArrayRef(1)); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// void mlir::populateComplexToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); patterns.add( typeConverter, context); }