//===- MemRefToSPIRV.cpp - MemRef to SPIR-V Patterns ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert MemRef dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Debug.h"

#include <cassert>
#include <optional>
#include <string>

#define DEBUG_TYPE "memref-to-spirv-pattern"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns the offset of the value in `targetBits` representation.
///
/// `srcIdx` is an index into a 1-D array with each element having `sourceBits`.
/// It's assumed to be non-negative.
///
/// When accessing an element in the array treating as having elements of
/// `targetBits`, multiple values are loaded in the same time. The method
/// returns the offset where the `srcIdx` locates in the value. For example, if
/// `sourceBits` equals to 8 and `targetBits` equals to 32, the x-th element is
/// located at (x % 4) * 8. Because there are four elements in one i32, and one
/// element has 8 bits.
static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
                                  int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  Type type = srcIdx.getType();
  // Number of source elements packed into one target element.
  IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.create<spirv::ConstantOp>(loc, type, idxAttr);
  IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
  auto srcBitsValue =
      builder.create<spirv::ConstantOp>(loc, type, srcBitsAttr);
  // offset = (srcIdx % elementsPerTarget) * sourceBits.
  auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
  return builder.create<spirv::IMulOp>(loc, type, m, srcBitsValue);
}

/// Returns an adjusted spirv::AccessChainOp. Based on the
/// extension/capabilities, certain integer bitwidths `sourceBits` might not be
/// supported. During conversion if a memref of an unsupported type is used,
/// load/stores to this memref need to be modified to use a supported higher
/// bitwidth `targetBits` and extracting the required bits. For an accessing a
/// 1D array (spirv.array or spirv.rtarray), the last index is modified to load
/// the bits needed. The extraction of the actual bits needed are handled
/// separately. Note that this only works for a 1-D tensor.
static Value
adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
                             spirv::AccessChainOp op, int sourceBits,
                             int targetBits, OpBuilder &builder) {
  assert(targetBits % sourceBits == 0);
  const auto loc = op.getLoc();
  Value lastDim = op->getOperand(op.getNumOperands() - 1);
  Type type = lastDim.getType();
  IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
  auto idx = builder.create<spirv::ConstantOp>(loc, type, attr);
  auto indices = llvm::to_vector<4>(op.getIndices());
  // There are two elements if this is a 1-D tensor.
  assert(indices.size() == 2);
  // Scale the last index down: several source elements live in one target
  // element, so index i of the source maps to i / (targetBits / sourceBits).
  indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
  Type t = typeConverter.convertType(op.getComponentPtr().getType());
  return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(),
                                              indices);
}

/// Casts the given `srcBool` into an integer of `dstType`.
static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, OpBuilder &builder) { assert(srcBool.getType().isInteger(1)); if (dstType.isInteger(1)) return srcBool; Value zero = spirv::ConstantOp::getZero(dstType, loc, builder); Value one = spirv::ConstantOp::getOne(dstType, loc, builder); return builder.create(loc, dstType, srcBool, one, zero); } /// Returns the `targetBits`-bit value shifted by the given `offset`, and cast /// to the type destination type, and masked. static Value shiftValue(Location loc, Value value, Value offset, Value mask, OpBuilder &builder) { IntegerType dstType = cast(mask.getType()); int targetBits = static_cast(dstType.getWidth()); int valueBits = value.getType().getIntOrFloatBitWidth(); assert(valueBits <= targetBits); if (valueBits == 1) { value = castBoolToIntN(loc, value, dstType, builder); } else { if (valueBits < targetBits) { value = builder.create( loc, builder.getIntegerType(targetBits), value); } value = builder.create(loc, value, mask); } return builder.create(loc, value.getType(), value, offset); } /// Returns true if the allocations of memref `type` generated from `allocOp` /// can be lowered to SPIR-V. static bool isAllocationSupported(Operation *allocOp, MemRefType type) { if (isa(allocOp)) { auto sc = dyn_cast_or_null(type.getMemorySpace()); if (!sc || sc.getValue() != spirv::StorageClass::Workgroup) return false; } else if (isa(allocOp)) { auto sc = dyn_cast_or_null(type.getMemorySpace()); if (!sc || sc.getValue() != spirv::StorageClass::Function) return false; } else { return false; } // Currently only support static shape and int or float or vector of int or // float element type. 
if (!type.hasStaticShape()) return false; Type elementType = type.getElementType(); if (auto vecType = dyn_cast(elementType)) elementType = vecType.getElementType(); return elementType.isIntOrFloat(); } /// Returns the scope to use for atomic operations use for emulating store /// operations of unsupported integer bitwidths, based on the memref /// type. Returns std::nullopt on failure. static std::optional getAtomicOpScope(MemRefType type) { auto sc = dyn_cast_or_null(type.getMemorySpace()); switch (sc.getValue()) { case spirv::StorageClass::StorageBuffer: return spirv::Scope::Device; case spirv::StorageClass::Workgroup: return spirv::Scope::Workgroup; default: break; } return {}; } /// Casts the given `srcInt` into a boolean value. static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) { if (srcInt.getType().isInteger(1)) return srcInt; auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder); return builder.create(loc, srcInt, one); } //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// // Note that DRR cannot be used for the patterns in this file: we may need to // convert type along the way, which requires ConversionPattern. DRR generates // normal RewritePattern. namespace { /// Converts memref.alloca to SPIR-V Function variables. class AllocaOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Converts an allocation operation to SPIR-V. Currently only supports lowering /// to Workgroup memory when the size is constant. Note that this pattern needs /// to be applied in a pass that runs at least at spirv.module scope since it /// wil ladd global variables into the spirv.module. 
class AllocOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Converts memref.automic_rmw operations to SPIR-V atomic operations. class AtomicRMWOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Removed a deallocation if it is a supported allocation. Currently only /// removes deallocation if the memory space is workgroup memory. class DeallocOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Converts memref.load to spirv.Load + spirv.AccessChain on integers. class IntLoadOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Converts memref.load to spirv.Load + spirv.AccessChain. class LoadOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Converts memref.store to spirv.Store on integers. class IntStoreOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Converts memref.memory_space_cast to the appropriate spirv cast operations. 
class MemorySpaceCastOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; /// Converts memref.store to spirv.Store. class StoreOpPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; class ReinterpretCastPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; class CastPattern final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(memref::CastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value src = adaptor.getSource(); Type srcType = src.getType(); const TypeConverter *converter = getTypeConverter(); Type dstType = converter->convertType(op.getType()); if (srcType != dstType) return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { diag << "types doesn't match: " << srcType << " and " << dstType; }); rewriter.replaceOp(op, src); return success(); } }; } // namespace //===----------------------------------------------------------------------===// // AllocaOp //===----------------------------------------------------------------------===// LogicalResult AllocaOpPattern::matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { MemRefType allocType = allocaOp.getType(); if (!isAllocationSupported(allocaOp, allocType)) return rewriter.notifyMatchFailure(allocaOp, "unhandled allocation type"); // Get the SPIR-V type for the allocation. 
Type spirvType = getTypeConverter()->convertType(allocType); if (!spirvType) return rewriter.notifyMatchFailure(allocaOp, "type conversion failed"); rewriter.replaceOpWithNewOp(allocaOp, spirvType, spirv::StorageClass::Function, /*initializer=*/nullptr); return success(); } //===----------------------------------------------------------------------===// // AllocOp //===----------------------------------------------------------------------===// LogicalResult AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { MemRefType allocType = operation.getType(); if (!isAllocationSupported(operation, allocType)) return rewriter.notifyMatchFailure(operation, "unhandled allocation type"); // Get the SPIR-V type for the allocation. Type spirvType = getTypeConverter()->convertType(allocType); if (!spirvType) return rewriter.notifyMatchFailure(operation, "type conversion failed"); // Insert spirv.GlobalVariable for this allocation. Operation *parent = SymbolTable::getNearestSymbolTable(operation->getParentOp()); if (!parent) return failure(); Location loc = operation.getLoc(); spirv::GlobalVariableOp varOp; { OpBuilder::InsertionGuard guard(rewriter); Block &entryBlock = *parent->getRegion(0).begin(); rewriter.setInsertionPointToStart(&entryBlock); auto varOps = entryBlock.getOps(); std::string varName = std::string("__workgroup_mem__") + std::to_string(std::distance(varOps.begin(), varOps.end())); varOp = rewriter.create(loc, spirvType, varName, /*initializer=*/nullptr); } // Get pointer to global variable at the current scope. 
rewriter.replaceOpWithNewOp(operation, varOp); return success(); } //===----------------------------------------------------------------------===// // AllocOp //===----------------------------------------------------------------------===// LogicalResult AtomicRMWOpPattern::matchAndRewrite(memref::AtomicRMWOp atomicOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { if (isa(atomicOp.getType())) return rewriter.notifyMatchFailure(atomicOp, "unimplemented floating-point case"); auto memrefType = cast(atomicOp.getMemref().getType()); std::optional scope = getAtomicOpScope(memrefType); if (!scope) return rewriter.notifyMatchFailure(atomicOp, "unsupported memref memory space"); auto &typeConverter = *getTypeConverter(); Type resultType = typeConverter.convertType(atomicOp.getType()); if (!resultType) return rewriter.notifyMatchFailure(atomicOp, "failed to convert result type"); auto loc = atomicOp.getLoc(); Value ptr = spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(), adaptor.getIndices(), loc, rewriter); if (!ptr) return failure(); #define ATOMIC_CASE(kind, spirvOp) \ case arith::AtomicRMWKind::kind: \ rewriter.replaceOpWithNewOp( \ atomicOp, resultType, ptr, *scope, \ spirv::MemorySemantics::AcquireRelease, adaptor.getValue()); \ break switch (atomicOp.getKind()) { ATOMIC_CASE(addi, AtomicIAddOp); ATOMIC_CASE(maxs, AtomicSMaxOp); ATOMIC_CASE(maxu, AtomicUMaxOp); ATOMIC_CASE(mins, AtomicSMinOp); ATOMIC_CASE(minu, AtomicUMinOp); ATOMIC_CASE(ori, AtomicOrOp); ATOMIC_CASE(andi, AtomicAndOp); default: return rewriter.notifyMatchFailure(atomicOp, "unimplemented atomic kind"); } #undef ATOMIC_CASE return success(); } //===----------------------------------------------------------------------===// // DeallocOp //===----------------------------------------------------------------------===// LogicalResult DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { MemRefType 
deallocType = cast(operation.getMemref().getType()); if (!isAllocationSupported(operation, deallocType)) return rewriter.notifyMatchFailure(operation, "unhandled allocation type"); rewriter.eraseOp(operation); return success(); } //===----------------------------------------------------------------------===// // LoadOp //===----------------------------------------------------------------------===// LogicalResult IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto loc = loadOp.getLoc(); auto memrefType = cast(loadOp.getMemref().getType()); if (!memrefType.getElementType().isSignlessInteger()) return failure(); const auto &typeConverter = *getTypeConverter(); Value accessChain = spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(), adaptor.getIndices(), loc, rewriter); if (!accessChain) return failure(); int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); bool isBool = srcBits == 1; if (isBool) srcBits = typeConverter.getOptions().boolNumBits; auto pointerType = typeConverter.convertType(memrefType); if (!pointerType) return rewriter.notifyMatchFailure(loadOp, "failed to convert memref type"); Type pointeeType = pointerType.getPointeeType(); Type dstType; if (typeConverter.allows(spirv::Capability::Kernel)) { if (auto arrayType = dyn_cast(pointeeType)) dstType = arrayType.getElementType(); else dstType = pointeeType; } else { // For Vulkan we need to extract element from wrapping struct and array. Type structElemType = cast(pointeeType).getElementType(0); if (auto arrayType = dyn_cast(structElemType)) dstType = arrayType.getElementType(); else dstType = cast(structElemType).getElementType(); } int dstBits = dstType.getIntOrFloatBitWidth(); assert(dstBits % srcBits == 0); // If the rewritten load op has the same bit width, use the loading value // directly. 
if (srcBits == dstBits) { Value loadVal = rewriter.create(loc, accessChain); if (isBool) loadVal = castIntNToBool(loc, loadVal, rewriter); rewriter.replaceOp(loadOp, loadVal); return success(); } // Bitcasting is currently unsupported for Kernel capability / // spirv.PtrAccessChain. if (typeConverter.allows(spirv::Capability::Kernel)) return failure(); auto accessChainOp = accessChain.getDefiningOp(); if (!accessChainOp) return failure(); // Assume that getElementPtr() works linearizely. If it's a scalar, the method // still returns a linearized accessing. If the accessing is not linearized, // there will be offset issues. assert(accessChainOp.getIndices().size() == 2); Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, srcBits, dstBits, rewriter); Value spvLoadOp = rewriter.create( loc, dstType, adjustedPtr, loadOp->getAttrOfType( spirv::attributeName()), loadOp->getAttrOfType("alignment")); // Shift the bits to the rightmost. // ____XXXX________ -> ____________XXXX Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1); Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter); Value result = rewriter.create( loc, spvLoadOp.getType(), spvLoadOp, offset); // Apply the mask to extract corresponding bits. Value mask = rewriter.create( loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1)); result = rewriter.create(loc, dstType, result, mask); // Apply sign extension on the loading value unconditionally. The signedness // semantic is carried in the operator itself, we relies other pattern to // handle the casting. 
IntegerAttr shiftValueAttr = rewriter.getIntegerAttr(dstType, dstBits - srcBits); Value shiftValue = rewriter.create(loc, dstType, shiftValueAttr); result = rewriter.create(loc, dstType, result, shiftValue); result = rewriter.create(loc, dstType, result, shiftValue); rewriter.replaceOp(loadOp, result); assert(accessChainOp.use_empty()); rewriter.eraseOp(accessChainOp); return success(); } LogicalResult LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto memrefType = cast(loadOp.getMemref().getType()); if (memrefType.getElementType().isSignlessInteger()) return failure(); auto loadPtr = spirv::getElementPtr( *getTypeConverter(), memrefType, adaptor.getMemref(), adaptor.getIndices(), loadOp.getLoc(), rewriter); if (!loadPtr) return failure(); rewriter.replaceOpWithNewOp(loadOp, loadPtr); return success(); } LogicalResult IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto memrefType = cast(storeOp.getMemref().getType()); if (!memrefType.getElementType().isSignlessInteger()) return rewriter.notifyMatchFailure(storeOp, "element type is not a signless int"); auto loc = storeOp.getLoc(); auto &typeConverter = *getTypeConverter(); Value accessChain = spirv::getElementPtr(typeConverter, memrefType, adaptor.getMemref(), adaptor.getIndices(), loc, rewriter); if (!accessChain) return rewriter.notifyMatchFailure( storeOp, "failed to convert element pointer type"); int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); bool isBool = srcBits == 1; if (isBool) srcBits = typeConverter.getOptions().boolNumBits; auto pointerType = typeConverter.convertType(memrefType); if (!pointerType) return rewriter.notifyMatchFailure(storeOp, "failed to convert memref type"); Type pointeeType = pointerType.getPointeeType(); IntegerType dstType; if (typeConverter.allows(spirv::Capability::Kernel)) { if (auto arrayType = 
dyn_cast(pointeeType)) dstType = dyn_cast(arrayType.getElementType()); else dstType = dyn_cast(pointeeType); } else { // For Vulkan we need to extract element from wrapping struct and array. Type structElemType = cast(pointeeType).getElementType(0); if (auto arrayType = dyn_cast(structElemType)) dstType = dyn_cast(arrayType.getElementType()); else dstType = dyn_cast( cast(structElemType).getElementType()); } if (!dstType) return rewriter.notifyMatchFailure( storeOp, "failed to determine destination element type"); int dstBits = static_cast(dstType.getWidth()); assert(dstBits % srcBits == 0); if (srcBits == dstBits) { Value storeVal = adaptor.getValue(); if (isBool) storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter); rewriter.replaceOpWithNewOp(storeOp, accessChain, storeVal); return success(); } // Bitcasting is currently unsupported for Kernel capability / // spirv.PtrAccessChain. if (typeConverter.allows(spirv::Capability::Kernel)) return failure(); auto accessChainOp = accessChain.getDefiningOp(); if (!accessChainOp) return failure(); // Since there are multiple threads in the processing, the emulation will be // done with atomic operations. E.g., if the stored value is i8, rewrite the // StoreOp to: // 1) load a 32-bit integer // 2) clear 8 bits in the loaded value // 3) set 8 bits in the loaded value // 4) store 32-bit value back // // Step 2 is done with AtomicAnd, and step 3 is done with AtomicOr (of the // loaded 32-bit value and the shifted 8-bit store value) as another atomic // step. assert(accessChainOp.getIndices().size() == 2); Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1); Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter); // Create a mask to clear the destination. E.g., if it is the second i8 in // i32, 0xFFFF00FF is created. 
Value mask = rewriter.create( loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1)); Value clearBitsMask = rewriter.create(loc, dstType, mask, offset); clearBitsMask = rewriter.create(loc, dstType, clearBitsMask); Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter); Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, srcBits, dstBits, rewriter); std::optional scope = getAtomicOpScope(memrefType); if (!scope) return rewriter.notifyMatchFailure(storeOp, "atomic scope not available"); Value result = rewriter.create( loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, clearBitsMask); result = rewriter.create( loc, dstType, adjustedPtr, *scope, spirv::MemorySemantics::AcquireRelease, storeVal); // The AtomicOrOp has no side effect. Since it is already inserted, we can // just remove the original StoreOp. Note that rewriter.replaceOp() // doesn't work because it only accepts that the numbers of result are the // same. 
rewriter.eraseOp(storeOp); assert(accessChainOp.use_empty()); rewriter.eraseOp(accessChainOp); return success(); } //===----------------------------------------------------------------------===// // MemorySpaceCastOp //===----------------------------------------------------------------------===// LogicalResult MemorySpaceCastOpPattern::matchAndRewrite( memref::MemorySpaceCastOp addrCastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = addrCastOp.getLoc(); auto &typeConverter = *getTypeConverter(); if (!typeConverter.allows(spirv::Capability::Kernel)) return rewriter.notifyMatchFailure( loc, "address space casts require kernel capability"); auto sourceType = dyn_cast(addrCastOp.getSource().getType()); if (!sourceType) return rewriter.notifyMatchFailure( loc, "SPIR-V lowering requires ranked memref types"); auto resultType = cast(addrCastOp.getResult().getType()); auto sourceStorageClassAttr = dyn_cast_or_null(sourceType.getMemorySpace()); if (!sourceStorageClassAttr) return rewriter.notifyMatchFailure(loc, [sourceType](Diagnostic &diag) { diag << "source address space " << sourceType.getMemorySpace() << " must be a SPIR-V storage class"; }); auto resultStorageClassAttr = dyn_cast_or_null(resultType.getMemorySpace()); if (!resultStorageClassAttr) return rewriter.notifyMatchFailure(loc, [resultType](Diagnostic &diag) { diag << "result address space " << resultType.getMemorySpace() << " must be a SPIR-V storage class"; }); spirv::StorageClass sourceSc = sourceStorageClassAttr.getValue(); spirv::StorageClass resultSc = resultStorageClassAttr.getValue(); Value result = adaptor.getSource(); Type resultPtrType = typeConverter.convertType(resultType); if (!resultPtrType) return rewriter.notifyMatchFailure(addrCastOp, "failed to convert memref type"); Type genericPtrType = resultPtrType; // SPIR-V doesn't have a general address space cast operation. Instead, it has // conversions to and from generic pointers. 
To implement the general case, // we use specific-to-generic conversions when the source class is not // generic. Then when the result storage class is not generic, we convert the // generic pointer (either the input on ar intermediate result) to that // class. This also means that we'll need the intermediate generic pointer // type if neither the source or destination have it. if (sourceSc != spirv::StorageClass::Generic && resultSc != spirv::StorageClass::Generic) { Type intermediateType = MemRefType::get(sourceType.getShape(), sourceType.getElementType(), sourceType.getLayout(), rewriter.getAttr( spirv::StorageClass::Generic)); genericPtrType = typeConverter.convertType(intermediateType); } if (sourceSc != spirv::StorageClass::Generic) { result = rewriter.create(loc, genericPtrType, result); } if (resultSc != spirv::StorageClass::Generic) { result = rewriter.create(loc, resultPtrType, result); } rewriter.replaceOp(addrCastOp, result); return success(); } LogicalResult StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto memrefType = cast(storeOp.getMemref().getType()); if (memrefType.getElementType().isSignlessInteger()) return rewriter.notifyMatchFailure(storeOp, "signless int"); auto storePtr = spirv::getElementPtr( *getTypeConverter(), memrefType, adaptor.getMemref(), adaptor.getIndices(), storeOp.getLoc(), rewriter); if (!storePtr) return rewriter.notifyMatchFailure(storeOp, "type conversion failed"); rewriter.replaceOpWithNewOp(storeOp, storePtr, adaptor.getValue()); return success(); } LogicalResult ReinterpretCastPattern::matchAndRewrite( memref::ReinterpretCastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value src = adaptor.getSource(); auto srcType = dyn_cast(src.getType()); if (!srcType) return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { diag << "invalid src type " << src.getType(); }); const TypeConverter *converter = getTypeConverter(); auto 
dstType = converter->convertType(op.getType()); if (dstType != srcType) return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { diag << "invalid dst type " << op.getType(); }); OpFoldResult offset = getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter) .front(); if (isConstantIntValue(offset, 0)) { rewriter.replaceOp(op, src); return success(); } Type intType = converter->convertType(rewriter.getIndexType()); if (!intType) return rewriter.notifyMatchFailure(op, "failed to convert index type"); Location loc = op.getLoc(); auto offsetValue = [&]() -> Value { if (auto val = dyn_cast(offset)) return val; int64_t attrVal = cast(offset.get()).getInt(); Attribute attr = rewriter.getIntegerAttr(intType, attrVal); return rewriter.create(loc, intType, attr); }(); rewriter.replaceOpWithNewOp( op, src, offsetValue, std::nullopt); return success(); } //===----------------------------------------------------------------------===// // Pattern population //===----------------------------------------------------------------------===// namespace mlir { void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add(typeConverter, patterns.getContext()); } } // namespace mlir