233 lines
8.8 KiB
Python
233 lines
8.8 KiB
Python
# RUN: %PYTHON %s | FileCheck %s
|
|
|
|
from typing import Callable
|
|
from mlir import ir
|
|
from mlir.dialects import scf, pdl
|
|
from mlir.dialects.transform import (
|
|
structured,
|
|
get_parent_op,
|
|
apply_patterns_canonicalization,
|
|
apply_cse,
|
|
any_op_t,
|
|
)
|
|
from mlir.dialects.transform import FailurePropagationMode
|
|
from mlir.dialects.transform.structured import structured_match
|
|
from mlir.dialects.transform.loop import loop_unroll
|
|
from mlir.dialects.transform.extras import (
|
|
constant_param,
|
|
OpHandle,
|
|
insert_transform_script,
|
|
sequence,
|
|
apply_patterns,
|
|
)
|
|
from mlir.extras import types as T
|
|
|
|
|
|
def construct_and_print_in_module(f):
    """Decorator: run ``f`` inside a fresh MLIR module, then print the module.

    A ``TEST: <name>`` banner is printed first so FileCheck labels can anchor
    on it. ``f`` is then called with no arguments while a new ``ir.Context``,
    an unknown default location and an insertion point on the module body are
    active. Once ``f`` returns, the whole module is printed.

    Returns ``f`` unchanged, so the decorated name keeps referring to the
    original function.
    """
    print("\nTEST:", f.__name__)
    with ir.Context(), ir.Location.unknown():
        mod = ir.Module.create()
        # Ops built by ``f`` land in the module body.
        body_ip = ir.InsertionPoint(mod.body)
        with body_ip:
            f()
        print(mod)
    return f
|
|
|
|
|
|
def build_transform_script(script: Callable[[OpHandle], None]):
    """Decorator: build, dump and verify a named transform sequence.

    Prints a ``TEST: <name>`` banner. It then creates, in a fresh context with
    an unknown default location, a module tagged with the
    ``transform.with_named_sequence`` unit attribute. ``insert_transform_script``
    places a named sequence in the module body and hands ``script`` the
    sequence's root ``OpHandle`` to populate it. ``dump_script=True`` asks the
    helper to print the generated script, which is what the FileCheck lines
    following each decorated test match against. Finally the module is
    verified.

    Args:
        script: callback receiving the root ``OpHandle`` of the sequence.

    Returns:
        ``script`` itself. The original returned ``None``, which silently
        rebound every decorated test name to ``None``. Returning the function
        matches ``construct_and_print_in_module``.
    """
    print("\nTEST:", script.__name__)
    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        module.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
        insert_transform_script(module.body, script=script, dump_script=True)
        module.operation.verify()
    return script
|
|
|
|
|
|
def build_transform_script_at_insertion_point(script: Callable[[OpHandle], None]):
    """Decorator: build, dump and verify a transform script via an InsertionPoint.

    Prints a ``TEST: <name>`` banner. It then creates, in a fresh context with
    an unknown default location, a module tagged with the
    ``transform.with_named_sequence`` unit attribute. Unlike passing the block
    directly, ``insert_transform_script`` here receives an
    ``ir.InsertionPoint`` at the *beginning* of the module body. This
    exercises the insertion-point form of that helper's first argument.
    ``script`` is called to populate the sequence, ``dump_script=True`` prints
    it for FileCheck, and the module is then verified.

    Args:
        script: callback receiving the root ``OpHandle`` of the sequence.

    Returns:
        ``script`` itself, so the decorated test name is not rebound to
        ``None``. This matches ``construct_and_print_in_module``.
    """
    print("\nTEST:", script.__name__)
    with ir.Context(), ir.Location.unknown():
        module = ir.Module.create()
        module.operation.attributes["transform.with_named_sequence"] = ir.UnitAttr.get()
        insert_transform_script(
            ir.InsertionPoint.at_block_begin(module.body),
            script=script,
            dump_script=True,
        )
        module.operation.verify()
    return script
|
|
|
|
|
|
# CHECK-LABEL: TEST: test_build_script_at_insertion_point
|
|
@build_transform_script_at_insertion_point
def test_build_script_at_insertion_point(op: OpHandle):
    """Empty script: the sequence should contain nothing but ``transform.yield``."""
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: transform.yield
    # CHECK-NEXT: }
|
|
|
|
|
|
# CHECK-LABEL: TEST: test_constant_param_int
|
|
@build_transform_script
def test_constant_param_int(_: OpHandle):
    """An explicit i32 ``IntegerAttr`` becomes a ``!transform.param<i32>`` constant."""
    i32_attr = ir.IntegerAttr.get(T.i32(), 42)
    constant_param(i32_attr)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant 42 : i32
    # CHECK-SAME: !transform.param<i32>
|
|
|
|
|
|
# CHECK-LABEL: TEST: test_constant_param_py_int
|
|
@build_transform_script
def test_constant_param_py_int(_: OpHandle):
    """A plain Python ``int`` is expected to become an i64 parameter constant."""
    value = 42
    constant_param(value)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant 42 : i64
    # CHECK-SAME: !transform.param<i64>
|
|
|
|
|
|
# CHECK-LABEL: TEST: test_constant_param_symbol_attr
|
|
@build_transform_script
def test_constant_param_symbol_attr(_: OpHandle):
    """A symbol reference attribute is expected to produce ``!transform.any_param``."""
    symbol_ref = ir.SymbolRefAttr.get(["symbol"])
    constant_param(symbol_ref)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant @symbol
    # CHECK-SAME: !transform.any_param
|
|
|
|
|
|
# CHECK-LABEL: TEST: test_constant_param_type
|
|
@build_transform_script
def test_constant_param_type(_: OpHandle):
    """A ``TypeAttr`` wrapping ``i32`` is expected to produce ``!transform.any_param``."""
    i32_type = T.i32()
    constant_param(ir.TypeAttr.get(i32_type))
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.param.constant i32
    # CHECK-SAME: !transform.any_param
|
|
|
|
|
|
# CHECK-LABEL: TEST: test_get_defining_op
|
|
@build_transform_script
def test_get_defining_op(op: OpHandle):
    """Feed the root handle's result 0 back into ``get_defining_op``."""
    first_result = op.get_result()
    first_result.get_defining_op()
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.get_result %[[VAL_0]][0]
    # CHECK-SAME: !transform.any_value
    # CHECK-NEXT: %[[VAL_2:.*]] = transform.get_defining_op %[[VAL_1]]
|
|
|
|
|
|
# CHECK-LABEL: TEST: test_get_result
|
|
@build_transform_script
def test_get_result(op: OpHandle):
    """Calling ``get_result()`` without an index is expected to select result 0."""
    op.get_result()
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.get_result %[[VAL_0]][0]
|
|
|
|
|
|
# CHECK-LABEL: TEST: test_match_ops_single
|
|
@build_transform_script
def test_match_ops_single(op: OpHandle):
    """Matching one op class should yield a handle typed to that specific op."""
    loop_cls = scf.ForOp
    op.match_ops(loop_cls)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match ops{["scf.for"]}
    # CHECK-SAME: in %[[VAL_0]]
    # CHECK-SAME: -> !transform.op<"scf.for">
|
|
|
|
|
|
# CHECK-LABEL: TEST: test_match_ops_string_name
|
|
@build_transform_script
def test_match_ops_string_name(op: OpHandle):
    """An op name passed as a string should end up in the ``ops{...}`` list."""
    op_name = "linalg.matmul"
    op.match_ops(op_name)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME: ops{["linalg.matmul"]} in %[[VAL_0]]
|
|
|
|
|
|
# CHECK-LABEL: TEST: test_match_ops_string_iface
|
|
@build_transform_script
def test_match_ops_string_iface(op: OpHandle):
    """A string naming an interface should be emitted as ``interface{...}``."""
    iface_name = "LinalgOp"
    op.match_ops(iface_name)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME: interface{LinalgOp} in %[[VAL_0]]
|
|
|
|
|
|
# CHECK-LABEL: TEST: test_match_ops_iface
|
|
@build_transform_script
def test_match_ops_iface(op: OpHandle):
    """A ``MatchInterfaceEnum`` member should also be emitted as ``interface{...}``."""
    linalg_iface = structured.MatchInterfaceEnum.LinalgOp
    op.match_ops(linalg_iface)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME: interface{LinalgOp} in %[[VAL_0]]
|
|
|
|
|
|
# CHECK-LABEL: TEST: test_match_ops_multiple
|
|
@build_transform_script
def test_match_ops_multiple(op: OpHandle):
    """Several op classes should share one match whose result is ``!transform.any_op``."""
    loop_classes = [scf.ForOp, scf.ForallOp]
    op.match_ops(loop_classes)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME: ops{["scf.for", "scf.forall"]} in %[[VAL_0]]
    # CHECK-SAME: -> !transform.any_op
|
|
|
|
|
|
# CHECK-LABEL: TEST: test_match_ops_mixed
|
|
@build_transform_script
def test_match_ops_mixed(op: OpHandle):
    """Op classes and op-name strings may be mixed; their order is kept in ``ops{...}``."""
    candidates = [scf.ForOp, "linalg.matmul", scf.ForallOp]
    op.match_ops(candidates)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: %[[VAL_1:.*]] = transform.structured.match
    # CHECK-SAME: ops{["scf.for", "linalg.matmul", "scf.forall"]} in %[[VAL_0]]
    # CHECK-SAME: -> !transform.any_op
|
|
|
|
|
|
# CHECK-LABEL: TEST: test_print_message
|
|
@build_transform_script
def test_print_message(op: OpHandle):
    """A message given to ``print`` should appear as the op's ``name`` attribute."""
    label = "message"
    op.print(label)
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: transform.print %[[VAL_0]] {name = "message"}
|
|
|
|
|
|
# CHECK-LABEL: TEST: test_print_plain
|
|
@build_transform_script
def test_print_plain(op: OpHandle):
    """``print`` with no message should still emit ``transform.print`` on the root handle."""
    op.print()
    # CHECK: transform.named_sequence {{.*}}(%[[VAL_0:.*]]: !transform.any_op) {
    # CHECK-NEXT: transform.print %[[VAL_0]]
|
|
|
|
|
|
# CHECK-LABEL: TEST: test_sequence_region
|
|
@construct_and_print_in_module
def test_sequence_region():
    """Build a ``transform.sequence`` through the ``sequence`` decorator.

    The sequence matches ``arith.addi`` ops under its root handle, takes their
    parent ``scf.for`` and unrolls that loop by a factor of 4.
    """
    # CHECK: transform.sequence failures(propagate) {
    # CHECK: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
    # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["arith.addi"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK: %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "scf.for"} : (!transform.any_op) -> !pdl.operation
    # CHECK: transform.loop.unroll %[[VAL_2]] {factor = 4 : i64} : !pdl.operation
    # CHECK: }
    @sequence([], FailurePropagationMode.Propagate, [])
    def basic(target: any_op_t()):
        addi_ops = structured_match(any_op_t(), target, ops=["arith.addi"])
        enclosing_for = get_parent_op(pdl.op_t(), addi_ops, op_name="scf.for")
        loop_unroll(enclosing_for, 4)
|
|
|
|
|
|
# CHECK-LABEL: TEST: test_apply_patterns
|
|
@construct_and_print_in_module
def test_apply_patterns():
    """Build a sequence that canonicalizes and then CSEs around ``linalg.matmul``.

    Canonicalization patterns are applied to the ``func.func`` that contains
    the matched matmul. CSE is then applied to every ``func.func`` matched
    under the sequence's root handle.
    """
    # CHECK: transform.sequence failures(propagate) {
    # CHECK: ^{{.*}}(%[[VAL_0:.*]]: !transform.any_op):
    # CHECK: %[[VAL_1:.*]] = transform.structured.match ops{["linalg.matmul"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK: %[[VAL_2:.*]] = get_parent_op %[[VAL_1]] {op_name = "func.func"} : (!transform.any_op) -> !pdl.operation
    # CHECK: apply_patterns to %[[VAL_2]] {
    # CHECK: transform.apply_patterns.canonicalization
    # CHECK: } : !pdl.operation
    # CHECK: %[[VAL_3:.*]] = transform.structured.match ops{["func.func"]} in %[[VAL_0]] : (!transform.any_op) -> !transform.any_op
    # CHECK: apply_cse to %[[VAL_3]] : !transform.any_op
    # CHECK: }
    @sequence([], FailurePropagationMode.Propagate, [])
    def basic(variant_op: any_op_t()):
        matmul_ops = structured_match(any_op_t(), variant_op, ops=["linalg.matmul"])
        parent_func = get_parent_op(pdl.op_t(), matmul_ops, op_name="func.func")

        # The decorated body fills the pattern region of apply_patterns.
        @apply_patterns(parent_func)
        def pats():
            apply_patterns_canonicalization()

        # A separate handle (not the parent above) is used for the CSE step.
        all_funcs = structured_match(any_op_t(), variant_op, ops=["func.func"])
        apply_cse(all_funcs)
|