396 lines
15 KiB
C++
396 lines
15 KiB
C++
//===- TestDialectInterfaces.cpp - Test dialect interface definitions -----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "TestDialect.h"
|
|
#include "mlir/Interfaces/FoldInterfaces.h"
|
|
#include "mlir/Reducer/ReductionPatternInterface.h"
|
|
#include "mlir/Transforms/InliningUtils.h"
|
|
|
|
using namespace mlir;
|
|
using namespace test;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TestDialect Interfaces
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
|
|
/// Testing the correctness of some traits.
|
|
static_assert(
|
|
llvm::is_detected<OpTrait::has_implicit_terminator_t,
|
|
SingleBlockImplicitTerminatorOp>::value,
|
|
"has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
|
|
static_assert(OpTrait::hasSingleBlockImplicitTerminator<
|
|
SingleBlockImplicitTerminatorOp>::value,
|
|
"hasSingleBlockImplicitTerminator does not match "
|
|
"SingleBlockImplicitTerminatorOp");
|
|
|
|
struct TestResourceBlobManagerInterface
|
|
: public ResourceBlobManagerDialectInterfaceBase<
|
|
TestDialectResourceBlobHandle> {
|
|
using ResourceBlobManagerDialectInterfaceBase<
|
|
TestDialectResourceBlobHandle>::ResourceBlobManagerDialectInterfaceBase;
|
|
};
|
|
|
|
namespace {
/// Tags that the bytecode interface writes in front of each encoded test
/// dialect type or attribute.
enum test_encoding {
  k_attr_params = 0,
  k_test_i32 = 99,
};
} // namespace
|
|
|
|
// Test support for interacting with the Bytecode reader/writer.
|
|
struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
|
|
using BytecodeDialectInterface::BytecodeDialectInterface;
|
|
TestBytecodeDialectInterface(Dialect *dialect)
|
|
: BytecodeDialectInterface(dialect) {}
|
|
|
|
LogicalResult writeType(Type type,
|
|
DialectBytecodeWriter &writer) const final {
|
|
if (auto concreteType = llvm::dyn_cast<TestI32Type>(type)) {
|
|
writer.writeVarInt(test_encoding::k_test_i32);
|
|
return success();
|
|
}
|
|
return failure();
|
|
}
|
|
|
|
Type readType(DialectBytecodeReader &reader) const final {
|
|
uint64_t encoding;
|
|
if (failed(reader.readVarInt(encoding)))
|
|
return Type();
|
|
if (encoding == test_encoding::k_test_i32)
|
|
return TestI32Type::get(getContext());
|
|
return Type();
|
|
}
|
|
|
|
LogicalResult writeAttribute(Attribute attr,
|
|
DialectBytecodeWriter &writer) const final {
|
|
if (auto concreteAttr = llvm::dyn_cast<TestAttrParamsAttr>(attr)) {
|
|
writer.writeVarInt(test_encoding::k_attr_params);
|
|
writer.writeVarInt(concreteAttr.getV0());
|
|
writer.writeVarInt(concreteAttr.getV1());
|
|
return success();
|
|
}
|
|
return failure();
|
|
}
|
|
|
|
Attribute readAttribute(DialectBytecodeReader &reader) const final {
|
|
auto versionOr = reader.getDialectVersion<test::TestDialect>();
|
|
// Assume current version if not available through the reader.
|
|
const auto version =
|
|
(succeeded(versionOr))
|
|
? *reinterpret_cast<const TestDialectVersion *>(*versionOr)
|
|
: TestDialectVersion();
|
|
if (version.major_ < 2)
|
|
return readAttrOldEncoding(reader);
|
|
if (version.major_ == 2 && version.minor_ == 0)
|
|
return readAttrNewEncoding(reader);
|
|
// Forbid reading future versions by returning nullptr.
|
|
return Attribute();
|
|
}
|
|
|
|
// Emit a specific version of the dialect.
|
|
void writeVersion(DialectBytecodeWriter &writer) const final {
|
|
// Construct the current dialect version.
|
|
test::TestDialectVersion versionToEmit;
|
|
|
|
// Check if a target version to emit was specified on the writer configs.
|
|
auto versionOr = writer.getDialectVersion<test::TestDialect>();
|
|
if (succeeded(versionOr))
|
|
versionToEmit =
|
|
*reinterpret_cast<const test::TestDialectVersion *>(*versionOr);
|
|
writer.writeVarInt(versionToEmit.major_); // major
|
|
writer.writeVarInt(versionToEmit.minor_); // minor
|
|
}
|
|
|
|
std::unique_ptr<DialectVersion>
|
|
readVersion(DialectBytecodeReader &reader) const final {
|
|
uint64_t major_, minor_;
|
|
if (failed(reader.readVarInt(major_)) || failed(reader.readVarInt(minor_)))
|
|
return nullptr;
|
|
auto version = std::make_unique<TestDialectVersion>();
|
|
version->major_ = major_;
|
|
version->minor_ = minor_;
|
|
return version;
|
|
}
|
|
|
|
LogicalResult upgradeFromVersion(Operation *topLevelOp,
|
|
const DialectVersion &version_) const final {
|
|
const auto &version = static_cast<const TestDialectVersion &>(version_);
|
|
if ((version.major_ == 2) && (version.minor_ == 0))
|
|
return success();
|
|
if (version.major_ > 2 || (version.major_ == 2 && version.minor_ > 0)) {
|
|
return topLevelOp->emitError()
|
|
<< "current test dialect version is 2.0, can't parse version: "
|
|
<< version.major_ << "." << version.minor_;
|
|
}
|
|
// Prior version 2.0, the old op supported only a single attribute called
|
|
// "dimensions". We can perform the upgrade.
|
|
topLevelOp->walk([](TestVersionedOpA op) {
|
|
// Prior version 2.0, `readProperties` did not process the modifier
|
|
// attribute. Handle that according to the version here.
|
|
auto &prop = op.getProperties();
|
|
prop.modifier = BoolAttr::get(op->getContext(), false);
|
|
});
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
Attribute readAttrNewEncoding(DialectBytecodeReader &reader) const {
|
|
uint64_t encoding;
|
|
if (failed(reader.readVarInt(encoding)) ||
|
|
encoding != test_encoding::k_attr_params)
|
|
return Attribute();
|
|
// The new encoding has v0 first, v1 second.
|
|
uint64_t v0, v1;
|
|
if (failed(reader.readVarInt(v0)) || failed(reader.readVarInt(v1)))
|
|
return Attribute();
|
|
return TestAttrParamsAttr::get(getContext(), static_cast<int>(v0),
|
|
static_cast<int>(v1));
|
|
}
|
|
|
|
Attribute readAttrOldEncoding(DialectBytecodeReader &reader) const {
|
|
uint64_t encoding;
|
|
if (failed(reader.readVarInt(encoding)) ||
|
|
encoding != test_encoding::k_attr_params)
|
|
return Attribute();
|
|
// The old encoding has v1 first, v0 second.
|
|
uint64_t v0, v1;
|
|
if (failed(reader.readVarInt(v1)) || failed(reader.readVarInt(v0)))
|
|
return Attribute();
|
|
return TestAttrParamsAttr::get(getContext(), static_cast<int>(v0),
|
|
static_cast<int>(v1));
|
|
}
|
|
};
|
|
|
|
// Test support for interacting with the AsmPrinter.
|
|
struct TestOpAsmInterface : public OpAsmDialectInterface {
|
|
using OpAsmDialectInterface::OpAsmDialectInterface;
|
|
TestOpAsmInterface(Dialect *dialect, TestResourceBlobManagerInterface &mgr)
|
|
: OpAsmDialectInterface(dialect), blobManager(mgr) {}
|
|
|
|
//===------------------------------------------------------------------===//
|
|
// Aliases
|
|
//===------------------------------------------------------------------===//
|
|
|
|
AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
|
|
StringAttr strAttr = dyn_cast<StringAttr>(attr);
|
|
if (!strAttr)
|
|
return AliasResult::NoAlias;
|
|
|
|
// Check the contents of the string attribute to see what the test alias
|
|
// should be named.
|
|
std::optional<StringRef> aliasName =
|
|
StringSwitch<std::optional<StringRef>>(strAttr.getValue())
|
|
.Case("alias_test:dot_in_name", StringRef("test.alias"))
|
|
.Case("alias_test:trailing_digit", StringRef("test_alias0"))
|
|
.Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
|
|
.Case("alias_test:sanitize_conflict_a",
|
|
StringRef("test_alias_conflict0"))
|
|
.Case("alias_test:sanitize_conflict_b",
|
|
StringRef("test_alias_conflict0_"))
|
|
.Case("alias_test:tensor_encoding", StringRef("test_encoding"))
|
|
.Default(std::nullopt);
|
|
if (!aliasName)
|
|
return AliasResult::NoAlias;
|
|
|
|
os << *aliasName;
|
|
return AliasResult::FinalAlias;
|
|
}
|
|
|
|
AliasResult getAlias(Type type, raw_ostream &os) const final {
|
|
if (auto tupleType = dyn_cast<TupleType>(type)) {
|
|
if (tupleType.size() > 0 &&
|
|
llvm::all_of(tupleType.getTypes(), [](Type elemType) {
|
|
return isa<SimpleAType>(elemType);
|
|
})) {
|
|
os << "test_tuple";
|
|
return AliasResult::FinalAlias;
|
|
}
|
|
}
|
|
if (auto intType = dyn_cast<TestIntegerType>(type)) {
|
|
if (intType.getSignedness() ==
|
|
TestIntegerType::SignednessSemantics::Unsigned &&
|
|
intType.getWidth() == 8) {
|
|
os << "test_ui8";
|
|
return AliasResult::FinalAlias;
|
|
}
|
|
}
|
|
if (auto recType = dyn_cast<TestRecursiveType>(type)) {
|
|
if (recType.getName() == "type_to_alias") {
|
|
// We only make alias for a specific recursive type.
|
|
os << "testrec";
|
|
return AliasResult::FinalAlias;
|
|
}
|
|
}
|
|
if (auto recAliasType = dyn_cast<TestRecursiveAliasType>(type)) {
|
|
os << recAliasType.getName();
|
|
return AliasResult::FinalAlias;
|
|
}
|
|
return AliasResult::NoAlias;
|
|
}
|
|
|
|
//===------------------------------------------------------------------===//
|
|
// Resources
|
|
//===------------------------------------------------------------------===//
|
|
|
|
std::string
|
|
getResourceKey(const AsmDialectResourceHandle &handle) const override {
|
|
return cast<TestDialectResourceBlobHandle>(handle).getKey().str();
|
|
}
|
|
|
|
FailureOr<AsmDialectResourceHandle>
|
|
declareResource(StringRef key) const final {
|
|
return blobManager.insert(key);
|
|
}
|
|
|
|
LogicalResult parseResource(AsmParsedResourceEntry &entry) const final {
|
|
FailureOr<AsmResourceBlob> blob = entry.parseAsBlob();
|
|
if (failed(blob))
|
|
return failure();
|
|
|
|
// Update the blob for this entry.
|
|
blobManager.update(entry.getKey(), std::move(*blob));
|
|
return success();
|
|
}
|
|
|
|
void
|
|
buildResources(Operation *op,
|
|
const SetVector<AsmDialectResourceHandle> &referencedResources,
|
|
AsmResourceBuilder &provider) const final {
|
|
blobManager.buildResources(provider, referencedResources.getArrayRef());
|
|
}
|
|
|
|
private:
|
|
/// The blob manager for the dialect.
|
|
TestResourceBlobManagerInterface &blobManager;
|
|
};
|
|
|
|
struct TestDialectFoldInterface : public DialectFoldInterface {
|
|
using DialectFoldInterface::DialectFoldInterface;
|
|
|
|
/// Registered hook to check if the given region, which is attached to an
|
|
/// operation that is *not* isolated from above, should be used when
|
|
/// materializing constants.
|
|
bool shouldMaterializeInto(Region *region) const final {
|
|
// If this is a one region operation, then insert into it.
|
|
return isa<OneRegionOp>(region->getParentOp());
|
|
}
|
|
};
|
|
|
|
/// This class defines the interface for handling inlining with standard
|
|
/// operations.
|
|
struct TestInlinerInterface : public DialectInlinerInterface {
|
|
using DialectInlinerInterface::DialectInlinerInterface;
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Analysis Hooks
|
|
//===--------------------------------------------------------------------===//
|
|
|
|
bool isLegalToInline(Operation *call, Operation *callable,
|
|
bool wouldBeCloned) const final {
|
|
// Don't allow inlining calls that are marked `noinline`.
|
|
return !call->hasAttr("noinline");
|
|
}
|
|
bool isLegalToInline(Region *, Region *, bool, IRMapping &) const final {
|
|
// Inlining into test dialect regions is legal.
|
|
return true;
|
|
}
|
|
bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
|
|
return true;
|
|
}
|
|
|
|
bool shouldAnalyzeRecursively(Operation *op) const final {
|
|
// Analyze recursively if this is not a functional region operation, it
|
|
// froms a separate functional scope.
|
|
return !isa<FunctionalRegionOp>(op);
|
|
}
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Transformation Hooks
|
|
//===--------------------------------------------------------------------===//
|
|
|
|
/// Handle the given inlined terminator by replacing it with a new operation
|
|
/// as necessary.
|
|
void handleTerminator(Operation *op, ValueRange valuesToRepl) const final {
|
|
// Only handle "test.return" here.
|
|
auto returnOp = dyn_cast<TestReturnOp>(op);
|
|
if (!returnOp)
|
|
return;
|
|
|
|
// Replace the values directly with the return operands.
|
|
assert(returnOp.getNumOperands() == valuesToRepl.size());
|
|
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
|
|
valuesToRepl[it.index()].replaceAllUsesWith(it.value());
|
|
}
|
|
|
|
/// Attempt to materialize a conversion for a type mismatch between a call
|
|
/// from this dialect, and a callable region. This method should generate an
|
|
/// operation that takes 'input' as the only operand, and produces a single
|
|
/// result of 'resultType'. If a conversion can not be generated, nullptr
|
|
/// should be returned.
|
|
Operation *materializeCallConversion(OpBuilder &builder, Value input,
|
|
Type resultType,
|
|
Location conversionLoc) const final {
|
|
// Only allow conversion for i16/i32 types.
|
|
if (!(resultType.isSignlessInteger(16) ||
|
|
resultType.isSignlessInteger(32)) ||
|
|
!(input.getType().isSignlessInteger(16) ||
|
|
input.getType().isSignlessInteger(32)))
|
|
return nullptr;
|
|
return builder.create<TestCastOp>(conversionLoc, resultType, input);
|
|
}
|
|
|
|
Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable,
|
|
Value argument,
|
|
DictionaryAttr argumentAttrs) const final {
|
|
if (!argumentAttrs.contains("test.handle_argument"))
|
|
return argument;
|
|
return builder.create<TestTypeChangerOp>(call->getLoc(), argument.getType(),
|
|
argument);
|
|
}
|
|
|
|
Value handleResult(OpBuilder &builder, Operation *call, Operation *callable,
|
|
Value result, DictionaryAttr resultAttrs) const final {
|
|
if (!resultAttrs.contains("test.handle_result"))
|
|
return result;
|
|
return builder.create<TestTypeChangerOp>(call->getLoc(), result.getType(),
|
|
result);
|
|
}
|
|
|
|
void processInlinedCallBlocks(
|
|
Operation *call,
|
|
iterator_range<Region::iterator> inlinedBlocks) const final {
|
|
if (!isa<ConversionCallOp>(call))
|
|
return;
|
|
|
|
// Set attributed on all ops in the inlined blocks.
|
|
for (Block &block : inlinedBlocks) {
|
|
block.walk([&](Operation *op) {
|
|
op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
|
|
});
|
|
}
|
|
}
|
|
};
|
|
|
|
struct TestReductionPatternInterface : public DialectReductionPatternInterface {
|
|
public:
|
|
TestReductionPatternInterface(Dialect *dialect)
|
|
: DialectReductionPatternInterface(dialect) {}
|
|
|
|
void populateReductionPatterns(RewritePatternSet &patterns) const final {
|
|
populateTestReductionPatterns(patterns);
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
void TestDialect::registerInterfaces() {
|
|
auto &blobInterface = addInterface<TestResourceBlobManagerInterface>();
|
|
addInterface<TestOpAsmInterface>(blobInterface);
|
|
|
|
addInterfaces<TestDialectFoldInterface, TestInlinerInterface,
|
|
TestReductionPatternInterface, TestBytecodeDialectInterface>();
|
|
}
|