416 lines
15 KiB
C++
416 lines
15 KiB
C++
//===- TestOpProperties.cpp - Test all properties-related APIs ------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/IR/Attributes.h"
|
|
#include "mlir/IR/OpDefinition.h"
|
|
#include "mlir/IR/OperationSupport.h"
|
|
#include "mlir/Parser/Parser.h"
|
|
#include "gtest/gtest.h"
|
|
#include <optional>
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
/// Simple structure definining a struct to define "properties" for a given
|
|
/// operation. Default values are honored when creating an operation.
|
|
struct TestProperties {
|
|
int a = -1;
|
|
float b = -1.;
|
|
std::vector<int64_t> array = {-33};
|
|
/// A shared_ptr to a const object is safe: it is equivalent to a value-based
|
|
/// member. Here the label will be deallocated when the last operation
|
|
/// referring to it is destroyed. However there is no pool-allocation: this is
|
|
/// offloaded to the client.
|
|
std::shared_ptr<const std::string> label;
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestProperties)
|
|
};
|
|
|
|
bool operator==(const TestProperties &lhs, TestProperties &rhs) {
|
|
return lhs.a == rhs.a && lhs.b == rhs.b && lhs.array == rhs.array &&
|
|
lhs.label == rhs.label;
|
|
}
|
|
|
|
/// Convert a DictionaryAttr to a TestProperties struct, optionally emit errors
|
|
/// through the provided diagnostic if any. This is used for example during
|
|
/// parsing with the generic format.
|
|
static LogicalResult
|
|
setPropertiesFromAttribute(TestProperties &prop, Attribute attr,
|
|
function_ref<InFlightDiagnostic()> emitError) {
|
|
DictionaryAttr dict = dyn_cast<DictionaryAttr>(attr);
|
|
if (!dict) {
|
|
emitError() << "expected DictionaryAttr to set TestProperties";
|
|
return failure();
|
|
}
|
|
auto aAttr = dict.getAs<IntegerAttr>("a");
|
|
if (!aAttr) {
|
|
emitError() << "expected IntegerAttr for key `a`";
|
|
return failure();
|
|
}
|
|
auto bAttr = dict.getAs<FloatAttr>("b");
|
|
if (!bAttr ||
|
|
&bAttr.getValue().getSemantics() != &llvm::APFloatBase::IEEEsingle()) {
|
|
emitError() << "expected FloatAttr for key `b`";
|
|
return failure();
|
|
}
|
|
|
|
auto arrayAttr = dict.getAs<DenseI64ArrayAttr>("array");
|
|
if (!arrayAttr) {
|
|
emitError() << "expected DenseI64ArrayAttr for key `array`";
|
|
return failure();
|
|
}
|
|
|
|
auto label = dict.getAs<mlir::StringAttr>("label");
|
|
if (!label) {
|
|
emitError() << "expected StringAttr for key `label`";
|
|
return failure();
|
|
}
|
|
|
|
prop.a = aAttr.getValue().getSExtValue();
|
|
prop.b = bAttr.getValue().convertToFloat();
|
|
prop.array.assign(arrayAttr.asArrayRef().begin(),
|
|
arrayAttr.asArrayRef().end());
|
|
prop.label = std::make_shared<std::string>(label.getValue());
|
|
return success();
|
|
}
|
|
|
|
/// Convert a TestProperties struct to a DictionaryAttr, this is used for
|
|
/// example during printing with the generic format.
|
|
static Attribute getPropertiesAsAttribute(MLIRContext *ctx,
|
|
const TestProperties &prop) {
|
|
SmallVector<NamedAttribute> attrs;
|
|
Builder b{ctx};
|
|
attrs.push_back(b.getNamedAttr("a", b.getI32IntegerAttr(prop.a)));
|
|
attrs.push_back(b.getNamedAttr("b", b.getF32FloatAttr(prop.b)));
|
|
attrs.push_back(b.getNamedAttr("array", b.getDenseI64ArrayAttr(prop.array)));
|
|
attrs.push_back(b.getNamedAttr(
|
|
"label", b.getStringAttr(prop.label ? *prop.label : "<nullptr>")));
|
|
return b.getDictionaryAttr(attrs);
|
|
}
|
|
|
|
/// Compute a hash over every member of `prop`.
///
/// `label` may be null (it is for a default-constructed TestProperties), so it
/// is never dereferenced unconditionally: a null label contributes the hash of
/// an empty string. The hash of a non-null label is its content hash.
inline llvm::hash_code computeHash(const TestProperties &prop) {
  // We hash `b` which is a float using its underlying array of char:
  const auto *rawBytes = reinterpret_cast<const unsigned char *>(&prop.b);
  ArrayRef<unsigned char> bBytes{rawBytes, sizeof(prop.b)};
  StringRef labelText = prop.label ? StringRef(*prop.label) : StringRef();
  return llvm::hash_combine(
      prop.a, llvm::hash_combine_range(bBytes.begin(), bBytes.end()),
      llvm::hash_combine_range(prop.array.begin(), prop.array.end()),
      labelText);
}
|
|
|
|
/// A custom operation for the purpose of showcasing how to use "properties".
|
|
class OpWithProperties : public Op<OpWithProperties> {
|
|
public:
|
|
// Begin boilerplate
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWithProperties)
|
|
using Op::Op;
|
|
static ArrayRef<StringRef> getAttributeNames() { return {}; }
|
|
static StringRef getOperationName() {
|
|
return "test_op_properties.op_with_properties";
|
|
}
|
|
// End boilerplate
|
|
|
|
// This alias is the only definition needed for enabling "properties" for this
|
|
// operation.
|
|
using Properties = TestProperties;
|
|
static std::optional<mlir::Attribute> getInherentAttr(MLIRContext *context,
|
|
const Properties &prop,
|
|
StringRef name) {
|
|
return std::nullopt;
|
|
}
|
|
static void setInherentAttr(Properties &prop, StringRef name,
|
|
mlir::Attribute value) {}
|
|
static void populateInherentAttrs(MLIRContext *context,
|
|
const Properties &prop,
|
|
NamedAttrList &attrs) {}
|
|
static LogicalResult
|
|
verifyInherentAttrs(OperationName opName, NamedAttrList &attrs,
|
|
function_ref<InFlightDiagnostic()> emitError) {
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// A custom operation for the purpose of showcasing how discardable attributes
|
|
/// are handled in absence of properties.
|
|
class OpWithoutProperties : public Op<OpWithoutProperties> {
|
|
public:
|
|
// Begin boilerplate.
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWithoutProperties)
|
|
using Op::Op;
|
|
static ArrayRef<StringRef> getAttributeNames() {
|
|
static StringRef attributeNames[] = {StringRef("inherent_attr")};
|
|
return ArrayRef(attributeNames);
|
|
};
|
|
static StringRef getOperationName() {
|
|
return "test_op_properties.op_without_properties";
|
|
}
|
|
// End boilerplate.
|
|
};
|
|
|
|
// A trivial supporting dialect to register the above operation.
|
|
class TestOpPropertiesDialect : public Dialect {
|
|
public:
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOpPropertiesDialect)
|
|
static constexpr StringLiteral getDialectNamespace() {
|
|
return StringLiteral("test_op_properties");
|
|
}
|
|
explicit TestOpPropertiesDialect(MLIRContext *context)
|
|
: Dialect(getDialectNamespace(), context,
|
|
TypeID::get<TestOpPropertiesDialect>()) {
|
|
addOperations<OpWithProperties, OpWithoutProperties>();
|
|
}
|
|
};
|
|
|
|
// Generic-format source for one `op_with_properties` operation with all four
// properties (`a`, `b`, `array`, `label`) set explicitly.
constexpr StringLiteral mlirSrc = R"mlir(
    "test_op_properties.op_with_properties"()
      <{a = -42 : i32,
        b = -4.200000e+01 : f32,
        array = array<i64: 40, 41>,
        label = "bar foo"}> : () -> ()
)mlir";
|
|
|
|
// Parse an op with properties, check its exact printed form, then mutate the
// properties in place one member at a time and check the printed form follows.
TEST(OpPropertiesTest, Properties) {
  MLIRContext context;
  context.getOrLoadDialect<TestOpPropertiesDialect>();
  ParserConfig config(&context);
  // Parse the operation with some properties.
  OwningOpRef<Operation *> op = parseSourceString(mlirSrc, config);
  ASSERT_TRUE(op.get() != nullptr);
  auto opWithProp = dyn_cast<OpWithProperties>(op.get());
  ASSERT_TRUE(opWithProp);

  // Print the operation into a fresh string.
  auto render = [&]() {
    std::string text;
    {
      llvm::raw_string_ostream stream(text);
      opWithProp.print(stream);
    }
    return text;
  };
  // Print the operation and expect every fragment to appear in the output.
  auto expectPrinted = [&](std::initializer_list<StringRef> fragments) {
    const std::string printed = render();
    for (StringRef fragment : fragments)
      EXPECT_TRUE(StringRef(printed).contains(fragment))
          << "missing: " << fragment.str();
  };

  {
    const std::string printed = render();
    ASSERT_STREQ("\"test_op_properties.op_with_properties\"() "
                 "<{a = -42 : i32, "
                 "array = array<i64: 40, 41>, "
                 "b = -4.200000e+01 : f32, "
                 "label = \"bar foo\"}> : () -> ()\n",
                 printed.c_str());
  }

  // Get a mutable reference to the properties for this operation and modify it
  // in place one member at a time.
  TestProperties &prop = opWithProp.getProperties();

  prop.a = 42;
  expectPrinted({"a = 42", "b = -4.200000e+01", "array = array<i64: 40, 41>",
                 "label = \"bar foo\""});

  prop.b = 42.;
  expectPrinted({"a = 42", "b = 4.200000e+01", "array = array<i64: 40, 41>",
                 "label = \"bar foo\""});

  prop.array.push_back(42);
  expectPrinted({"a = 42", "b = 4.200000e+01",
                 "array = array<i64: 40, 41, 42>", "label = \"bar foo\""});

  prop.label = std::make_shared<std::string>("foo bar");
  expectPrinted({"a = 42", "b = 4.200000e+01",
                 "array = array<i64: 40, 41, 42>", "label = \"foo bar\""});
}
|
|
|
|
// Test diagnostic emission when using invalid dictionary.
|
|
TEST(OpPropertiesTest, FailedProperties) {
|
|
MLIRContext context;
|
|
context.getOrLoadDialect<TestOpPropertiesDialect>();
|
|
std::string diagnosticStr;
|
|
context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
|
|
diagnosticStr += diag.str();
|
|
return success();
|
|
});
|
|
|
|
// Parse the operation with some properties.
|
|
ParserConfig config(&context);
|
|
|
|
// Parse an operation with invalid (incomplete) properties.
|
|
OwningOpRef<Operation *> owningOp =
|
|
parseSourceString("\"test_op_properties.op_with_properties\"() "
|
|
"<{a = -42 : i32}> : () -> ()\n",
|
|
config);
|
|
ASSERT_EQ(owningOp.get(), nullptr);
|
|
EXPECT_STREQ(
|
|
"invalid properties {a = -42 : i32} for op "
|
|
"test_op_properties.op_with_properties: expected FloatAttr for key `b`",
|
|
diagnosticStr.c_str());
|
|
diagnosticStr.clear();
|
|
|
|
owningOp = parseSourceString(mlirSrc, config);
|
|
Operation *op = owningOp.get();
|
|
ASSERT_TRUE(op != nullptr);
|
|
Location loc = op->getLoc();
|
|
auto opWithProp = dyn_cast<OpWithProperties>(op);
|
|
ASSERT_TRUE(opWithProp);
|
|
|
|
OperationState state(loc, op->getName());
|
|
Builder b{&context};
|
|
NamedAttrList attrs;
|
|
attrs.push_back(b.getNamedAttr("a", b.getStringAttr("foo")));
|
|
state.propertiesAttr = attrs.getDictionary(&context);
|
|
{
|
|
auto emitError = [&]() {
|
|
return op->emitError("setting properties failed: ");
|
|
};
|
|
auto result = state.setProperties(op, emitError);
|
|
EXPECT_TRUE(result.failed());
|
|
}
|
|
EXPECT_STREQ("setting properties failed: expected IntegerAttr for key `a`",
|
|
diagnosticStr.c_str());
|
|
}
|
|
|
|
// An operation created without explicit properties must print the default
// member values of TestProperties.
TEST(OpPropertiesTest, DefaultValues) {
  MLIRContext context;
  context.getOrLoadDialect<TestOpPropertiesDialect>();
  OperationState state(UnknownLoc::get(&context),
                       "test_op_properties.op_with_properties");
  Operation *op = Operation::create(state);
  ASSERT_TRUE(op != nullptr);

  std::string printed;
  {
    llvm::raw_string_ostream stream(printed);
    op->print(stream);
  }
  StringRef text(printed);
  EXPECT_TRUE(text.contains("a = -1"));
  EXPECT_TRUE(text.contains("b = -1"));
  EXPECT_TRUE(text.contains("array = array<i64: -33>"));

  // The operation is not owned by anything: destroy it explicitly.
  op->erase();
}
|
|
|
|
// Cloning an operation must carry its properties over: the original and the
// clone are expected to print identically.
TEST(OpPropertiesTest, Cloning) {
  MLIRContext context;
  context.getOrLoadDialect<TestOpPropertiesDialect>();
  ParserConfig config(&context);
  // Parse the operation with some properties.
  OwningOpRef<Operation *> op = parseSourceString(mlirSrc, config);
  ASSERT_TRUE(op.get() != nullptr);
  auto opWithProp = dyn_cast<OpWithProperties>(op.get());
  ASSERT_TRUE(opWithProp);
  Operation *clone = opWithProp->clone();

  // Print an operation into a fresh string.
  auto render = [](Operation *operation) {
    std::string text;
    {
      llvm::raw_string_ostream stream(text);
      operation->print(stream);
    }
    return text;
  };

  // Check that op and its clone print equally. The clone is erased before the
  // comparison since only the captured strings are needed from here on.
  const std::string opStr = render(op.get());
  const std::string cloneStr = render(clone);
  clone->erase();
  EXPECT_STREQ(opStr.c_str(), cloneStr.c_str());
}
|
|
|
|
// The operation hash must change whenever a property changes, and come back
// to the reference value once the property is restored.
TEST(OpPropertiesTest, Equivalence) {
  MLIRContext context;
  context.getOrLoadDialect<TestOpPropertiesDialect>();
  ParserConfig config(&context);
  // Parse the operation with some properties.
  OwningOpRef<Operation *> op = parseSourceString(mlirSrc, config);
  ASSERT_TRUE(op.get() != nullptr);
  auto opWithProp = dyn_cast<OpWithProperties>(op.get());
  ASSERT_TRUE(opWithProp);

  // Hash of the operation in its current state.
  auto currentHash = [&]() {
    return OperationEquivalence::computeHash(opWithProp);
  };
  const llvm::hash_code reference = currentHash();
  TestProperties &prop = opWithProp.getProperties();

  // Integer member: perturb, then restore.
  prop.a = 42;
  EXPECT_NE(reference, currentHash());
  prop.a = -42;
  EXPECT_EQ(reference, currentHash());

  // Float member: perturb, then restore.
  prop.b = 42.;
  EXPECT_NE(reference, currentHash());
  prop.b = -42.;
  EXPECT_EQ(reference, currentHash());

  // Array member: grow, then shrink back.
  prop.array.push_back(42);
  EXPECT_NE(reference, currentHash());
  prop.array.pop_back();
  EXPECT_EQ(reference, currentHash());
}
|
|
|
|
// Properties set on an OperationState through `getOrAddProperties` must show
// up on the operation created from that state.
TEST(OpPropertiesTest, getOrAddProperties) {
  MLIRContext context;
  context.getOrLoadDialect<TestOpPropertiesDialect>();
  OperationState state(UnknownLoc::get(&context),
                       "test_op_properties.op_with_properties");
  // Test `getOrAddProperties` API on OperationState.
  TestProperties &prop = state.getOrAddProperties<TestProperties>();
  prop.a = 1;
  prop.b = 2;
  prop.array = {3, 4, 5};
  Operation *op = Operation::create(state);
  ASSERT_TRUE(op != nullptr);

  std::string printed;
  {
    llvm::raw_string_ostream stream(printed);
    op->print(stream);
  }
  StringRef text(printed);
  EXPECT_TRUE(text.contains("a = 1"));
  EXPECT_TRUE(text.contains("b = 2"));
  EXPECT_TRUE(text.contains("array = array<i64: 3, 4, 5>"));

  // The operation is not owned by anything: destroy it explicitly.
  op->erase();
}
|
|
|
|
// Generic-format source for one `op_without_properties` operation carrying
// both its declared attribute (`inherent_attr`) and an extra one
// (`other_attr`).
constexpr StringLiteral withoutPropertiesAttrsSrc = R"mlir(
    "test_op_properties.op_without_properties"()
      {inherent_attr = 42, other_attr = 56} : () -> ()
)mlir";
|
|
|
|
// For an operation without properties, check that the declared attribute is
// reported as inherent while the other one is discardable, and that printing
// then re-parsing yields an operation with the same hash.
TEST(OpPropertiesTest, withoutPropertiesDiscardableAttrs) {
  MLIRContext context;
  context.getOrLoadDialect<TestOpPropertiesDialect>();
  ParserConfig config(&context);
  OwningOpRef<Operation *> op =
      parseSourceString(withoutPropertiesAttrsSrc, config);
  // Stop here on a parse failure instead of dereferencing a null operation.
  ASSERT_TRUE(op.get() != nullptr);

  ASSERT_EQ(llvm::range_size(op->getDiscardableAttrs()), 1u);
  EXPECT_EQ(op->getDiscardableAttrs().begin()->getName().getValue(),
            "other_attr");

  EXPECT_EQ(op->getAttrs().size(), 2u);
  EXPECT_TRUE(op->getInherentAttr("inherent_attr") != std::nullopt);
  EXPECT_TRUE(op->getDiscardableAttr("other_attr") != Attribute());

  std::string output;
  llvm::raw_string_ostream os(output);
  op->print(os);
  EXPECT_TRUE(StringRef(os.str()).contains("inherent_attr = 42"));
  EXPECT_TRUE(StringRef(os.str()).contains("other_attr = 56"));

  // Round-trip: the printed form must parse back into an equivalent op.
  OwningOpRef<Operation *> reparsed = parseSourceString(os.str(), config);
  ASSERT_TRUE(reparsed.get() != nullptr);

  auto trivialHash = [](Value v) { return hash_value(v); };
  // Hash an operation while ignoring its location, since the original and the
  // re-parsed operation come from different source buffers.
  auto hash = [&](Operation *operation) {
    return OperationEquivalence::computeHash(
        operation, trivialHash, trivialHash,
        OperationEquivalence::Flags::IgnoreLocations);
  };
  EXPECT_TRUE(hash(op.get()) == hash(reparsed.get()));
}
|
|
|
|
} // namespace
|