//===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities used to lower to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"

#include <functional>
#include <optional>

#define DEBUG_TYPE "mlir-spirv-conversion"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Checks that `candidates` extension requirements are possible to be satisfied
/// with the given `targetEnv`.
///
///  `candidates` is a vector of vector for extension requirements following
///      ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
///  convention.
template static LogicalResult checkExtensionRequirements( LabelT label, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::ExtensionArrayRefVector &candidates) { for (const auto &ors : candidates) { if (targetEnv.allows(ors)) continue; LLVM_DEBUG({ SmallVector extStrings; for (spirv::Extension ext : ors) extStrings.push_back(spirv::stringifyExtension(ext)); llvm::dbgs() << label << " illegal: requires at least one extension in [" << llvm::join(extStrings, ", ") << "] but none allowed in target environment\n"; }); return failure(); } return success(); } /// Checks that `candidates`capability requirements are possible to be satisfied /// with the given `isAllowedFn`. /// /// `candidates` is a vector of vector for capability requirements following /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D)) /// convention. template static LogicalResult checkCapabilityRequirements( LabelT label, const spirv::TargetEnv &targetEnv, const spirv::SPIRVType::CapabilityArrayRefVector &candidates) { for (const auto &ors : candidates) { if (targetEnv.allows(ors)) continue; LLVM_DEBUG({ SmallVector capStrings; for (spirv::Capability cap : ors) capStrings.push_back(spirv::stringifyCapability(cap)); llvm::dbgs() << label << " illegal: requires at least one capability in [" << llvm::join(capStrings, ", ") << "] but none allowed in target environment\n"; }); return failure(); } return success(); } /// Returns true if the given `storageClass` needs explicit layout when used in /// Shader environments. static bool needsExplicitLayout(spirv::StorageClass storageClass) { switch (storageClass) { case spirv::StorageClass::PhysicalStorageBuffer: case spirv::StorageClass::PushConstant: case spirv::StorageClass::StorageBuffer: case spirv::StorageClass::Uniform: return true; default: return false; } } /// Wraps the given `elementType` in a struct and gets the pointer to the /// struct. This is used to satisfy Vulkan interface requirements. 
static spirv::PointerType wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) { auto structType = needsExplicitLayout(storageClass) ? spirv::StructType::get(elementType, /*offsetInfo=*/0) : spirv::StructType::get(elementType); return spirv::PointerType::get(structType, storageClass); } //===----------------------------------------------------------------------===// // Type Conversion //===----------------------------------------------------------------------===// static spirv::ScalarType getIndexType(MLIRContext *ctx, const SPIRVConversionOptions &options) { return cast( IntegerType::get(ctx, options.use64bitIndex ? 64 : 32)); } Type SPIRVTypeConverter::getIndexType() const { return ::getIndexType(getContext(), options); } MLIRContext *SPIRVTypeConverter::getContext() const { return targetEnv.getAttr().getContext(); } bool SPIRVTypeConverter::allows(spirv::Capability capability) const { return targetEnv.allows(capability); } // TODO: This is a utility function that should probably be exposed by the // SPIR-V dialect. Keeping it local till the use case arises. static std::optional getTypeNumBytes(const SPIRVConversionOptions &options, Type type) { if (isa(type)) { auto bitWidth = type.getIntOrFloatBitWidth(); // According to the SPIR-V spec: // "There is no physical size or bit pattern defined for values with boolean // type. If they are stored (in conjunction with OpVariable), they can only // be used with logical addressing operations, not physical, and only with // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup, // Private, Function, Input, and Output." 
if (bitWidth == 1) return std::nullopt; return bitWidth / 8; } if (auto complexType = dyn_cast(type)) { auto elementSize = getTypeNumBytes(options, complexType.getElementType()); if (!elementSize) return std::nullopt; return 2 * *elementSize; } if (auto vecType = dyn_cast(type)) { auto elementSize = getTypeNumBytes(options, vecType.getElementType()); if (!elementSize) return std::nullopt; return vecType.getNumElements() * *elementSize; } if (auto memRefType = dyn_cast(type)) { // TODO: Layout should also be controlled by the ABI attributes. For now // using the layout from MemRef. int64_t offset; SmallVector strides; if (!memRefType.hasStaticShape() || failed(getStridesAndOffset(memRefType, strides, offset))) return std::nullopt; // To get the size of the memref object in memory, the total size is the // max(stride * dimension-size) computed for all dimensions times the size // of the element. auto elementSize = getTypeNumBytes(options, memRefType.getElementType()); if (!elementSize) return std::nullopt; if (memRefType.getRank() == 0) return elementSize; auto dims = memRefType.getShape(); if (llvm::is_contained(dims, ShapedType::kDynamic) || ShapedType::isDynamic(offset) || llvm::is_contained(strides, ShapedType::kDynamic)) return std::nullopt; int64_t memrefSize = -1; for (const auto &shape : enumerate(dims)) memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]); return (offset + memrefSize) * *elementSize; } if (auto tensorType = dyn_cast(type)) { if (!tensorType.hasStaticShape()) return std::nullopt; auto elementSize = getTypeNumBytes(options, tensorType.getElementType()); if (!elementSize) return std::nullopt; int64_t size = *elementSize; for (auto shape : tensorType.getShape()) size *= shape; return size; } // TODO: Add size computation for other types. return std::nullopt; } /// Converts a scalar `type` to a suitable type under the given `targetEnv`. 
///
/// Returns `type` unchanged when `targetEnv` satisfies all of its capability
/// and extension requirements (for the optional `storageClass`). Otherwise,
/// when `options.emulateLT32BitScalarTypes` is set and the type is at most 32
/// bits wide, returns the 32-bit float type or the 32-bit integer type with the
/// same signedness. Returns a null type in every other case.
static Type
convertScalarType(const spirv::TargetEnv &targetEnv,
                  const SPIRVConversionOptions &options, spirv::ScalarType type,
                  std::optional<spirv::StorageClass> storageClass = {}) {
  // Get extension and capability requirements for the given type.
  SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
  SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
  type.getExtensions(extensions, storageClass);
  type.getCapabilities(capabilities, storageClass);

  // If all requirements are met, then we can accept this type as-is.
  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
      succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
    return type;

  // Otherwise we need to adjust the type, which really means adjusting the
  // bitwidth given this is a scalar type.
  if (!options.emulateLT32BitScalarTypes)
    return nullptr;

  // We only emulate narrower scalar types here and do not truncate results.
  if (type.getIntOrFloatBitWidth() > 32) {
    LLVM_DEBUG(llvm::dbgs()
               << type
               << " not converted to 32-bit for SPIR-V to avoid truncation\n");
    return nullptr;
  }

  if (isa<FloatType>(type)) {
    LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
    return Builder(targetEnv.getContext()).getF32Type();
  }

  auto intType = cast<IntegerType>(type);
  LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
  return IntegerType::get(targetEnv.getContext(), /*width=*/32,
                          intType.getSignedness());
}

/// Converts a sub-byte integer `type` to i32 regardless of target environment.
///
/// Note that we don't recognize sub-byte types in `spirv::ScalarType` and use
/// the above given that these sub-byte types are not supported at all in
/// SPIR-V; there are no compute/storage capability for them like other
/// supported integer types.
static Type convertSubByteIntegerType(const SPIRVConversionOptions &options, IntegerType type) { if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) { LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n"); return nullptr; } if (!llvm::isPowerOf2_32(type.getWidth())) { LLVM_DEBUG(llvm::dbgs() << "unsupported non-power-of-two bitwidth in sub-byte" << type << "\n"); return nullptr; } LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n"); return IntegerType::get(type.getContext(), /*width=*/32, type.getSignedness()); } /// Returns a type with the same shape but with any index element type converted /// to the matching integer type. This is a noop when the element type is not /// the index type. static ShapedType convertIndexElementType(ShapedType type, const SPIRVConversionOptions &options) { Type indexType = dyn_cast(type.getElementType()); if (!indexType) return type; return type.clone(getIndexType(type.getContext(), options)); } /// Converts a vector `type` to a suitable type under the given `targetEnv`. static Type convertVectorType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, VectorType type, std::optional storageClass = {}) { type = cast(convertIndexElementType(type, options)); auto scalarType = dyn_cast_or_null(type.getElementType()); if (!scalarType) { // If this is not a spec allowed scalar type, try to handle sub-byte integer // types. 
auto intType = dyn_cast(type.getElementType()); if (!intType) { LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot convert non-scalar element type\n"); return nullptr; } Type elementType = convertSubByteIntegerType(options, intType); if (type.getRank() <= 1 && type.getNumElements() == 1) return elementType; if (type.getNumElements() > 4) { LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n"); return nullptr; } return VectorType::get(type.getShape(), elementType); } if (type.getRank() <= 1 && type.getNumElements() == 1) return convertScalarType(targetEnv, options, scalarType, storageClass); if (!spirv::CompositeType::isValid(type)) { LLVM_DEBUG(llvm::dbgs() << type << " illegal: not a valid composite type\n"); return nullptr; } // Get extension and capability requirements for the given type. SmallVector, 1> extensions; SmallVector, 2> capabilities; cast(type).getExtensions(extensions, storageClass); cast(type).getCapabilities(capabilities, storageClass); // If all requirements are met, then we can accept this type as-is. 
if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) && succeeded(checkExtensionRequirements(type, targetEnv, extensions))) return type; auto elementType = convertScalarType(targetEnv, options, scalarType, storageClass); if (elementType) return VectorType::get(type.getShape(), elementType); return nullptr; } static Type convertComplexType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, ComplexType type, std::optional storageClass = {}) { auto scalarType = dyn_cast_or_null(type.getElementType()); if (!scalarType) { LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot convert non-scalar element type\n"); return nullptr; } auto elementType = convertScalarType(targetEnv, options, scalarType, storageClass); if (!elementType) return nullptr; if (elementType != type.getElementType()) { LLVM_DEBUG(llvm::dbgs() << type << " illegal: complex type emulation unsupported\n"); return nullptr; } return VectorType::get(2, elementType); } /// Converts a tensor `type` to a suitable type under the given `targetEnv`. /// /// Note that this is mainly for lowering constant tensors. In SPIR-V one can /// create composite constants with OpConstantComposite to embed relative large /// constant values and use OpCompositeExtract and OpCompositeInsert to /// manipulate, like what we do for vectors. static Type convertTensorType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, TensorType type) { // TODO: Handle dynamic shapes. 
if (!type.hasStaticShape()) { LLVM_DEBUG(llvm::dbgs() << type << " illegal: dynamic shape unimplemented\n"); return nullptr; } type = cast(convertIndexElementType(type, options)); auto scalarType = dyn_cast_or_null(type.getElementType()); if (!scalarType) { LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot convert non-scalar element type\n"); return nullptr; } std::optional scalarSize = getTypeNumBytes(options, scalarType); std::optional tensorSize = getTypeNumBytes(options, type); if (!scalarSize || !tensorSize) { LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot deduce element count\n"); return nullptr; } int64_t arrayElemCount = *tensorSize / *scalarSize; if (arrayElemCount == 0) { LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot handle zero-element tensors\n"); return nullptr; } Type arrayElemType = convertScalarType(targetEnv, options, scalarType); if (!arrayElemType) return nullptr; std::optional arrayElemSize = getTypeNumBytes(options, arrayElemType); if (!arrayElemSize) { LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot deduce converted element size\n"); return nullptr; } return spirv::ArrayType::get(arrayElemType, arrayElemCount); } static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, MemRefType type, spirv::StorageClass storageClass) { unsigned numBoolBits = options.boolNumBits; if (numBoolBits != 8) { LLVM_DEBUG(llvm::dbgs() << "using non-8-bit storage for bool types unimplemented"); return nullptr; } auto elementType = dyn_cast( IntegerType::get(type.getContext(), numBoolBits)); if (!elementType) return nullptr; Type arrayElemType = convertScalarType(targetEnv, options, elementType, storageClass); if (!arrayElemType) return nullptr; std::optional arrayElemSize = getTypeNumBytes(options, arrayElemType); if (!arrayElemSize) { LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot deduce converted element size\n"); return nullptr; } if (!type.hasStaticShape()) { // For OpenCL Kernel, 
dynamic shaped memrefs convert into a pointer pointing // to the element. if (targetEnv.allows(spirv::Capability::Kernel)) return spirv::PointerType::get(arrayElemType, storageClass); int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0; auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride); // For Vulkan we need extra wrapping struct and array to satisfy interface // needs. return wrapInStructAndGetPointer(arrayType, storageClass); } if (type.getNumElements() == 0) { LLVM_DEBUG(llvm::dbgs() << type << " illegal: zero-element memrefs are not supported\n"); return nullptr; } int64_t memrefSize = llvm::divideCeil(type.getNumElements() * numBoolBits, 8); int64_t arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize); int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0; auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride); if (targetEnv.allows(spirv::Capability::Kernel)) return spirv::PointerType::get(arrayType, storageClass); return wrapInStructAndGetPointer(arrayType, storageClass); } static Type convertSubByteMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, MemRefType type, spirv::StorageClass storageClass) { IntegerType elementType = cast(type.getElementType()); Type arrayElemType = convertSubByteIntegerType(options, elementType); if (!arrayElemType) return nullptr; int64_t arrayElemSize = *getTypeNumBytes(options, arrayElemType); if (!type.hasStaticShape()) { // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing // to the element. if (targetEnv.allows(spirv::Capability::Kernel)) return spirv::PointerType::get(arrayElemType, storageClass); int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0; auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride); // For Vulkan we need extra wrapping struct and array to satisfy interface // needs. 
return wrapInStructAndGetPointer(arrayType, storageClass); } if (type.getNumElements() == 0) { LLVM_DEBUG(llvm::dbgs() << type << " illegal: zero-element memrefs are not supported\n"); return nullptr; } int64_t memrefSize = llvm::divideCeil(type.getNumElements() * elementType.getWidth(), 8); int64_t arrayElemCount = llvm::divideCeil(memrefSize, arrayElemSize); int64_t stride = needsExplicitLayout(storageClass) ? arrayElemSize : 0; auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride); if (targetEnv.allows(spirv::Capability::Kernel)) return spirv::PointerType::get(arrayType, storageClass); return wrapInStructAndGetPointer(arrayType, storageClass); } static Type convertMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, MemRefType type) { auto attr = dyn_cast_or_null(type.getMemorySpace()); if (!attr) { LLVM_DEBUG( llvm::dbgs() << type << " illegal: expected memory space to be a SPIR-V storage class " "attribute; please use MemorySpaceToStorageClassConverter to map " "numeric memory spaces beforehand\n"); return nullptr; } spirv::StorageClass storageClass = attr.getValue(); if (isa(type.getElementType())) { if (type.getElementTypeBitWidth() == 1) return convertBoolMemrefType(targetEnv, options, type, storageClass); if (type.getElementTypeBitWidth() < 8) return convertSubByteMemrefType(targetEnv, options, type, storageClass); } Type arrayElemType; Type elementType = type.getElementType(); if (auto vecType = dyn_cast(elementType)) { arrayElemType = convertVectorType(targetEnv, options, vecType, storageClass); } else if (auto complexType = dyn_cast(elementType)) { arrayElemType = convertComplexType(targetEnv, options, complexType, storageClass); } else if (auto scalarType = dyn_cast(elementType)) { arrayElemType = convertScalarType(targetEnv, options, scalarType, storageClass); } else if (auto indexType = dyn_cast(elementType)) { type = cast(convertIndexElementType(type, options)); arrayElemType = 
type.getElementType(); } else { LLVM_DEBUG( llvm::dbgs() << type << " unhandled: can only convert scalar or vector element type\n"); return nullptr; } if (!arrayElemType) return nullptr; std::optional arrayElemSize = getTypeNumBytes(options, arrayElemType); if (!arrayElemSize) { LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot deduce converted element size\n"); return nullptr; } if (!type.hasStaticShape()) { // For OpenCL Kernel, dynamic shaped memrefs convert into a pointer pointing // to the element. if (targetEnv.allows(spirv::Capability::Kernel)) return spirv::PointerType::get(arrayElemType, storageClass); int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0; auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride); // For Vulkan we need extra wrapping struct and array to satisfy interface // needs. return wrapInStructAndGetPointer(arrayType, storageClass); } std::optional memrefSize = getTypeNumBytes(options, type); if (!memrefSize) { LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot deduce element count\n"); return nullptr; } if (*memrefSize == 0) { LLVM_DEBUG(llvm::dbgs() << type << " illegal: zero-element memrefs are not supported\n"); return nullptr; } int64_t arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize); int64_t stride = needsExplicitLayout(storageClass) ? *arrayElemSize : 0; auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride); if (targetEnv.allows(spirv::Capability::Kernel)) return spirv::PointerType::get(arrayType, storageClass); return wrapInStructAndGetPointer(arrayType, storageClass); } //===----------------------------------------------------------------------===// // Type casting materialization //===----------------------------------------------------------------------===// /// Converts the given `inputs` to the original source `type` considering the /// `targetEnv`'s capabilities. 
/// /// This function is meant to be used for source materialization in type /// converters. When the type converter needs to materialize a cast op back /// to some original source type, we need to check whether the original source /// type is supported in the target environment. If so, we can insert legal /// SPIR-V cast ops accordingly. /// /// Note that in SPIR-V the capabilities for storage and compute are separate. /// This function is meant to handle the **compute** side; so it does not /// involve storage classes in its logic. The storage side is expected to be /// handled by MemRef conversion logic. std::optional castToSourceType(const spirv::TargetEnv &targetEnv, OpBuilder &builder, Type type, ValueRange inputs, Location loc) { // We can only cast one value in SPIR-V. if (inputs.size() != 1) { auto castOp = builder.create(loc, type, inputs); return castOp.getResult(0); } Value input = inputs.front(); // Only support integer types for now. Floating point types to be implemented. if (!isa(type)) { auto castOp = builder.create(loc, type, inputs); return castOp.getResult(0); } auto inputType = cast(input.getType()); auto scalarType = dyn_cast(type); if (!scalarType) { auto castOp = builder.create(loc, type, inputs); return castOp.getResult(0); } // Only support source type with a smaller bitwidth. This would mean we are // truncating to go back so we don't need to worry about the signedness. // For extension, we cannot have enough signal here to decide which op to use. if (inputType.getIntOrFloatBitWidth() < scalarType.getIntOrFloatBitWidth()) { auto castOp = builder.create(loc, type, inputs); return castOp.getResult(0); } // Boolean values would need to use different ops than normal integer values. if (type.isInteger(1)) { Value one = spirv::ConstantOp::getOne(inputType, loc, builder); return builder.create(loc, input, one); } // Check that the source integer type is supported by the environment. 
SmallVector, 1> exts; SmallVector, 2> caps; scalarType.getExtensions(exts); scalarType.getCapabilities(caps); if (failed(checkCapabilityRequirements(type, targetEnv, caps)) || failed(checkExtensionRequirements(type, targetEnv, exts))) { auto castOp = builder.create(loc, type, inputs); return castOp.getResult(0); } // We've already made sure this is truncating previously, so we don't need to // care about signedness here. Still try to use a corresponding op for better // consistency though. if (type.isSignedInteger()) { return builder.create(loc, type, input); } return builder.create(loc, type, input); } //===----------------------------------------------------------------------===// // SPIRVTypeConverter //===----------------------------------------------------------------------===// SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, const SPIRVConversionOptions &options) : targetEnv(targetAttr), options(options) { // Add conversions. The order matters here: later ones will be tried earlier. // Allow all SPIR-V dialect specific types. This assumes all builtin types // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType) // were tried before. // // TODO: This assumes that the SPIR-V types are valid to use in the given // target environment, which should be the case if the whole pipeline is // driven by the same target environment. Still, we probably still want to // validate and convert to be safe. 
addConversion([](spirv::SPIRVType type) { return type; }); addConversion([this](IndexType /*indexType*/) { return getIndexType(); }); addConversion([this](IntegerType intType) -> std::optional { if (auto scalarType = dyn_cast(intType)) return convertScalarType(this->targetEnv, this->options, scalarType); if (intType.getWidth() < 8) return convertSubByteIntegerType(this->options, intType); return Type(); }); addConversion([this](FloatType floatType) -> std::optional { if (auto scalarType = dyn_cast(floatType)) return convertScalarType(this->targetEnv, this->options, scalarType); return Type(); }); addConversion([this](ComplexType complexType) { return convertComplexType(this->targetEnv, this->options, complexType); }); addConversion([this](VectorType vectorType) { return convertVectorType(this->targetEnv, this->options, vectorType); }); addConversion([this](TensorType tensorType) { return convertTensorType(this->targetEnv, this->options, tensorType); }); addConversion([this](MemRefType memRefType) { return convertMemrefType(this->targetEnv, this->options, memRefType); }); // Register some last line of defense casting logic. addSourceMaterialization( [this](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { return castToSourceType(this->targetEnv, builder, type, inputs, loc); }); addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs, Location loc) { auto cast = builder.create(loc, type, inputs); return std::optional(cast.getResult(0)); }); } //===----------------------------------------------------------------------===// // func::FuncOp Conversion Patterns //===----------------------------------------------------------------------===// namespace { /// A pattern for rewriting function signature to convert arguments of functions /// to be of valid SPIR-V types. 
class FuncOpConversion final : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace LogicalResult FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto fnType = funcOp.getFunctionType(); if (fnType.getNumResults() > 1) return failure(); TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); for (const auto &argType : enumerate(fnType.getInputs())) { auto convertedType = getTypeConverter()->convertType(argType.value()); if (!convertedType) return failure(); signatureConverter.addInputs(argType.index(), convertedType); } Type resultType; if (fnType.getNumResults() == 1) { resultType = getTypeConverter()->convertType(fnType.getResult(0)); if (!resultType) return failure(); } // Create the converted spirv.func op. auto newFuncOp = rewriter.create( funcOp.getLoc(), funcOp.getName(), rewriter.getFunctionType(signatureConverter.getConvertedTypes(), resultType ? TypeRange(resultType) : TypeRange())); // Copy over all attributes other than the function name and type. 
for (const auto &namedAttr : funcOp->getAttrs()) { if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() && namedAttr.getName() != SymbolTable::getSymbolAttrName()) newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); } rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); if (failed(rewriter.convertRegionTypes( &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter))) return failure(); rewriter.eraseOp(funcOp); return success(); } void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add(typeConverter, patterns.getContext()); } //===----------------------------------------------------------------------===// // Builtin Variables //===----------------------------------------------------------------------===// static spirv::GlobalVariableOp getBuiltinVariable(Block &body, spirv::BuiltIn builtin) { // Look through all global variables in the given `body` block and check if // there is a spirv.GlobalVariable that has the same `builtin` attribute. for (auto varOp : body.getOps()) { if (auto builtinAttr = varOp->getAttrOfType( spirv::SPIRVDialect::getAttributeName( spirv::Decoration::BuiltIn))) { auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue()); if (varBuiltIn && *varBuiltIn == builtin) { return varOp; } } } return nullptr; } /// Gets name of global variable for a builtin. static std::string getBuiltinVarName(spirv::BuiltIn builtin, StringRef prefix, StringRef suffix) { return Twine(prefix).concat(stringifyBuiltIn(builtin)).concat(suffix).str(); } /// Gets or inserts a global variable for a builtin within `body` block. 
///
/// An existing variable for `builtin` is reused. Otherwise a new one is created
/// at the start of `body`: a 3-element vector of `integerType` for the
/// workgroup/invocation builtins, a single `integerType` for the subgroup
/// builtins. For any other builtin an error is emitted and a null op returned.
static spirv::GlobalVariableOp
getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
                           Type integerType, OpBuilder &builder,
                           StringRef prefix, StringRef suffix) {
  if (auto varOp = getBuiltinVariable(body, builtin))
    return varOp;

  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(&body);

  spirv::GlobalVariableOp newVarOp;
  switch (builtin) {
  case spirv::BuiltIn::NumWorkgroups:
  case spirv::BuiltIn::WorkgroupSize:
  case spirv::BuiltIn::WorkgroupId:
  case spirv::BuiltIn::LocalInvocationId:
  case spirv::BuiltIn::GlobalInvocationId: {
    auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
                                           spirv::StorageClass::Input);
    std::string name = getBuiltinVarName(builtin, prefix, suffix);
    newVarOp =
        builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
    break;
  }
  case spirv::BuiltIn::SubgroupId:
  case spirv::BuiltIn::NumSubgroups:
  case spirv::BuiltIn::SubgroupSize: {
    auto ptrType =
        spirv::PointerType::get(integerType, spirv::StorageClass::Input);
    std::string name = getBuiltinVarName(builtin, prefix, suffix);
    newVarOp =
        builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
    break;
  }
  default:
    emitError(loc, "unimplemented builtin variable generation for ")
        << stringifyBuiltIn(builtin);
  }
  return newVarOp;
}

/// Loads the value of the global variable for `builtin`, creating the variable
/// in the nearest enclosing symbol table if needed. Returns a null value (after
/// emitting an error) when `op` is not nested in a module-like op or when the
/// builtin is not supported.
Value mlir::spirv::getBuiltinVariableValue(Operation *op,
                                           spirv::BuiltIn builtin,
                                           Type integerType, OpBuilder &builder,
                                           StringRef prefix, StringRef suffix) {
  Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
  if (!parent) {
    op->emitError("expected operation to be within a module-like op");
    return nullptr;
  }

  spirv::GlobalVariableOp varOp =
      getOrInsertBuiltinVariable(parent->getRegion(0).front(), op->getLoc(),
                                 builtin, integerType, builder, prefix, suffix);
  // The helper has already emitted an error for unsupported builtins; do not
  // build an address-of on a null variable.
  if (!varOp)
    return nullptr;

  Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
  return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
}

//===----------------------------------------------------------------------===//
// Push constant storage
//===----------------------------------------------------------------------===// /// Returns the pointer type for the push constant storage containing /// `elementCount` 32-bit integer values. static spirv::PointerType getPushConstantStorageType(unsigned elementCount, Builder &builder, Type indexType) { auto arrayType = spirv::ArrayType::get(indexType, elementCount, /*stride=*/4); auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0); return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant); } /// Returns the push constant varible containing `elementCount` 32-bit integer /// values in `body`. Returns null op if such an op does not exit. static spirv::GlobalVariableOp getPushConstantVariable(Block &body, unsigned elementCount) { for (auto varOp : body.getOps()) { auto ptrType = dyn_cast(varOp.getType()); if (!ptrType) continue; // Note that Vulkan requires "There must be no more than one push constant // block statically used per shader entry point." So we should always reuse // the existing one. if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) { auto numElements = cast( cast(ptrType.getPointeeType()) .getElementType(0)) .getNumElements(); if (numElements == elementCount) return varOp; } } return nullptr; } /// Gets or inserts a global variable for push constant storage containing /// `elementCount` 32-bit integer values in `block`. 
static spirv::GlobalVariableOp getOrInsertPushConstantVariable(Location loc, Block &block, unsigned elementCount, OpBuilder &b, Type indexType) { if (auto varOp = getPushConstantVariable(block, elementCount)) return varOp; auto builder = OpBuilder::atBlockBegin(&block, b.getListener()); auto type = getPushConstantStorageType(elementCount, builder, indexType); const char *name = "__push_constant_var__"; return builder.create(loc, type, name, /*initializer=*/nullptr); } Value spirv::getPushConstantValue(Operation *op, unsigned elementCount, unsigned offset, Type integerType, OpBuilder &builder) { Location loc = op->getLoc(); Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp()); if (!parent) { op->emitError("expected operation to be within a module-like op"); return nullptr; } spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable( loc, parent->getRegion(0).front(), elementCount, builder, integerType); Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder); Value offsetOp = builder.create( loc, integerType, builder.getI32IntegerAttr(offset)); auto addrOp = builder.create(loc, varOp); auto acOp = builder.create( loc, addrOp, llvm::ArrayRef({zeroOp, offsetOp})); return builder.create(loc, acOp); } //===----------------------------------------------------------------------===// // Index calculation //===----------------------------------------------------------------------===// Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef strides, int64_t offset, Type integerType, Location loc, OpBuilder &builder) { assert(indices.size() == strides.size() && "must provide indices for all dimensions"); // TODO: Consider moving to use affine.apply and patterns converting // affine.apply to standard ops. This needs converting to SPIR-V passes to be // broken down into progressive small steps so we can have intermediate steps // using other dialects. At the moment SPIR-V is the final sink. 
Value linearizedIndex = builder.create( loc, integerType, IntegerAttr::get(integerType, offset)); for (const auto &index : llvm::enumerate(indices)) { Value strideVal = builder.create( loc, integerType, IntegerAttr::get(integerType, strides[index.index()])); Value update = builder.create(loc, strideVal, index.value()); linearizedIndex = builder.create(loc, linearizedIndex, update); } return linearizedIndex; } Value mlir::spirv::getVulkanElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder) { // Get base and offset of the MemRefType and verify they are static. int64_t offset; SmallVector strides; if (failed(getStridesAndOffset(baseType, strides, offset)) || llvm::is_contained(strides, ShapedType::kDynamic) || ShapedType::isDynamic(offset)) { return nullptr; } auto indexType = typeConverter.getIndexType(); SmallVector linearizedIndices; auto zero = spirv::ConstantOp::getZero(indexType, loc, builder); // Add a '0' at the start to index into the struct. linearizedIndices.push_back(zero); if (baseType.getRank() == 0) { linearizedIndices.push_back(zero); } else { linearizedIndices.push_back( linearizeIndex(indices, strides, offset, indexType, loc, builder)); } return builder.create(loc, basePtr, linearizedIndices); } Value mlir::spirv::getOpenCLElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder) { // Get base and offset of the MemRefType and verify they are static. 
int64_t offset; SmallVector strides; if (failed(getStridesAndOffset(baseType, strides, offset)) || llvm::is_contained(strides, ShapedType::kDynamic) || ShapedType::isDynamic(offset)) { return nullptr; } auto indexType = typeConverter.getIndexType(); SmallVector linearizedIndices; Value linearIndex; if (baseType.getRank() == 0) { linearIndex = spirv::ConstantOp::getZero(indexType, loc, builder); } else { linearIndex = linearizeIndex(indices, strides, offset, indexType, loc, builder); } Type pointeeType = cast(basePtr.getType()).getPointeeType(); if (isa(pointeeType)) { linearizedIndices.push_back(linearIndex); return builder.create(loc, basePtr, linearizedIndices); } return builder.create(loc, basePtr, linearIndex, linearizedIndices); } Value mlir::spirv::getElementPtr(const SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr, ValueRange indices, Location loc, OpBuilder &builder) { if (typeConverter.allows(spirv::Capability::Kernel)) { return getOpenCLElementPtr(typeConverter, baseType, basePtr, indices, loc, builder); } return getVulkanElementPtr(typeConverter, baseType, basePtr, indices, loc, builder); } //===----------------------------------------------------------------------===// // SPIR-V ConversionTarget //===----------------------------------------------------------------------===// std::unique_ptr SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) { std::unique_ptr target( // std::make_unique does not work here because the constructor is private. new SPIRVConversionTarget(targetAttr)); SPIRVConversionTarget *targetPtr = target.get(); target->addDynamicallyLegalDialect( // We need to capture the raw pointer here because it is stable: // target will be destroyed once this function is returned. 
[targetPtr](Operation *op) { return targetPtr->isLegalOp(op); }); return target; } SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr) : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {} bool SPIRVConversionTarget::isLegalOp(Operation *op) { // Make sure this op is available at the given version. Ops not implementing // QueryMinVersionInterface/QueryMaxVersionInterface are available to all // SPIR-V versions. if (auto minVersionIfx = dyn_cast(op)) { std::optional minVersion = minVersionIfx.getMinVersion(); if (minVersion && *minVersion > this->targetEnv.getVersion()) { LLVM_DEBUG(llvm::dbgs() << op->getName() << " illegal: requiring min version " << spirv::stringifyVersion(*minVersion) << "\n"); return false; } } if (auto maxVersionIfx = dyn_cast(op)) { std::optional maxVersion = maxVersionIfx.getMaxVersion(); if (maxVersion && *maxVersion < this->targetEnv.getVersion()) { LLVM_DEBUG(llvm::dbgs() << op->getName() << " illegal: requiring max version " << spirv::stringifyVersion(*maxVersion) << "\n"); return false; } } // Make sure this op's required extensions are allowed to use. Ops not // implementing QueryExtensionInterface do not require extensions to be // available. if (auto extensions = dyn_cast(op)) if (failed(checkExtensionRequirements(op->getName(), this->targetEnv, extensions.getExtensions()))) return false; // Make sure this op's required extensions are allowed to use. Ops not // implementing QueryCapabilityInterface do not require capabilities to be // available. if (auto capabilities = dyn_cast(op)) if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv, capabilities.getCapabilities()))) return false; SmallVector valueTypes; valueTypes.append(op->operand_type_begin(), op->operand_type_end()); valueTypes.append(op->result_type_begin(), op->result_type_end()); // Ensure that all types have been converted to SPIRV types. 
if (llvm::any_of(valueTypes, [](Type t) { return !isa(t); })) return false; // Special treatment for global variables, whose type requirements are // conveyed by type attributes. if (auto globalVar = dyn_cast(op)) valueTypes.push_back(globalVar.getType()); // Make sure the op's operands/results use types that are allowed by the // target environment. SmallVector, 4> typeExtensions; SmallVector, 8> typeCapabilities; for (Type valueType : valueTypes) { typeExtensions.clear(); cast(valueType).getExtensions(typeExtensions); if (failed(checkExtensionRequirements(op->getName(), this->targetEnv, typeExtensions))) return false; typeCapabilities.clear(); cast(valueType).getCapabilities(typeCapabilities); if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv, typeCapabilities))) return false; } return true; }