//===- ROCDLAttachTarget.cpp - Attach an ROCDL target ---------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements the `GpuROCDLAttachTarget` pass, attaching // `#rocdl.target` attributes to GPU modules. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" #include "mlir/Target/LLVM/ROCDL/Target.h" #include "llvm/Support/Regex.h" namespace mlir { #define GEN_PASS_DEF_GPUROCDLATTACHTARGET #include "mlir/Dialect/GPU/Transforms/Passes.h.inc" } // namespace mlir using namespace mlir; using namespace mlir::ROCDL; namespace { struct ROCDLAttachTarget : public impl::GpuROCDLAttachTargetBase { using Base::Base; DictionaryAttr getFlags(OpBuilder &builder) const; void runOnOperation() override; void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } }; } // namespace DictionaryAttr ROCDLAttachTarget::getFlags(OpBuilder &builder) const { UnitAttr unitAttr = builder.getUnitAttr(); SmallVector flags; auto addFlag = [&](StringRef flag) { flags.push_back(builder.getNamedAttr(flag, unitAttr)); }; if (!wave64Flag) addFlag("no_wave64"); if (fastFlag) addFlag("fast"); if (dazFlag) addFlag("daz"); if (finiteOnlyFlag) addFlag("finite_only"); if (unsafeMathFlag) addFlag("unsafe_math"); if (!correctSqrtFlag) addFlag("unsafe_sqrt"); if (!flags.empty()) return builder.getDictionaryAttr(flags); return nullptr; } void ROCDLAttachTarget::runOnOperation() { OpBuilder builder(&getContext()); ArrayRef libs(linkLibs); SmallVector filesToLink(libs.begin(), libs.end()); auto target = builder.getAttr( optLevel, triple, chip, features, abiVersion, getFlags(builder), filesToLink.empty() ? nullptr : builder.getStrArrayAttr(filesToLink)); llvm::Regex matcher(moduleMatcher); for (Region ®ion : getOperation()->getRegions()) for (Block &block : region.getBlocks()) for (auto module : block.getOps()) { // Check if the name of the module matches. if (!moduleMatcher.empty() && !matcher.match(module.getName())) continue; // Create the target array. SmallVector targets; if (std::optional attrs = module.getTargets()) targets.append(attrs->getValue().begin(), attrs->getValue().end()); targets.push_back(target); // Remove any duplicate targets. targets.erase(std::unique(targets.begin(), targets.end()), targets.end()); // Update the target attribute array. module.setTargetsAttr(builder.getArrayAttr(targets)); } }