390 lines
13 KiB
Python
390 lines
13 KiB
Python
# RUN: %PYTHON %s | FileCheck %s --enable-var-scope=false
|
|
|
|
import gc
|
|
from mlir.ir import *
|
|
from mlir.dialects import func
|
|
from mlir.dialects._ods_common import SubClassValueT
|
|
|
|
|
|
def run(f):
    """Run the decorated test immediately and return it unchanged.

    Prints a ``TEST: <name>`` banner so the output can be matched per test,
    calls the test, forces a garbage-collection pass, and then asserts that
    no MLIR ``Context`` objects are still alive (i.e. the test leaked none).
    """
    banner_name = f.__name__
    print("\nTEST:", banner_name)
    f()
    # Free objects kept alive only by reference cycles before counting
    # live contexts.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
    return f
|
|
|
|
|
|
# CHECK-LABEL: TEST: testCapsuleConversions
@run
def testCapsuleConversions():
    """Round-trip a Value through its C-API capsule and back."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        int_type = IntegerType.get_signless(32)
        original = Operation.create("custom.op1", results=[int_type]).result
        capsule = original._CAPIPtr
        # The capsule's repr must carry the Value capsule name.
        assert '"mlir.ir.Value._CAPIPtr"' in repr(capsule)
        restored = Value._CAPICreate(capsule)
        assert restored == original
|
|
|
|
|
|
# CHECK-LABEL: TEST: testOpResultOwner
@run
def testOpResultOwner():
    """An op result's ``owner`` is the operation that produced it."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        int_type = IntegerType.get_signless(32)
        producer = Operation.create("custom.op1", results=[int_type])
        result_owner = producer.result.owner
        assert result_owner == producer
|
|
|
|
|
|
# CHECK-LABEL: TEST: testBlockArgOwner
@run
def testBlockArgOwner():
    """A block argument's ``owner`` is the block that declares it."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    func.func @foo(%arg0: f32) {
      return
    }""",
        ctx,
    )
    # Named ``func_op`` (not ``func``) so the imported ``mlir.dialects.func``
    # module is not shadowed inside this test.
    func_op = module.body.operations[0]
    block = func_op.regions[0].blocks[0]
    assert block.arguments[0].owner == block
|
|
|
|
|
|
# CHECK-LABEL: TEST: testValueIsInstance
@run
def testValueIsInstance():
    """``BlockArgument``/``OpResult`` ``isinstance`` tell value kinds apart."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    func.func @foo(%arg0: f32) {
      %0 = "some_dialect.some_op"() : () -> f64
      return
    }""",
        ctx,
    )
    # ``func_op`` avoids shadowing the imported ``func`` dialect module.
    func_op = module.body.operations[0]
    entry_block = func_op.regions[0].blocks[0]

    # %arg0 is a block argument, not an op result.
    arg = entry_block.arguments[0]
    assert BlockArgument.isinstance(arg)
    assert not OpResult.isinstance(arg)

    # %0 is an op result, not a block argument.
    op = entry_block.operations[0]
    assert not BlockArgument.isinstance(op.results[0])
    assert OpResult.isinstance(op.results[0])
|
|
|
|
|
|
# CHECK-LABEL: TEST: testValueHash
@run
def testValueHash():
    """Different Python handles to the same SSA value hash equally."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    module = Module.parse(
        r"""
    func.func @foo(%arg0: f32) -> f32 {
      %0 = "some_dialect.some_op"(%arg0) : (f32) -> f32
      return %0 : f32
    }""",
        ctx,
    )

    # ``func_op`` avoids shadowing the imported ``func`` dialect module; the
    # single-element unpacking also asserts the module holds exactly one op.
    [func_op] = module.body.operations
    block = func_op.entry_block
    op, ret = block.operations
    # %arg0 reached as a block argument vs. as some_op's operand, and %0
    # reached as some_op's result vs. as return's operand.
    assert hash(block.arguments[0]) == hash(op.operands[0])
    assert hash(op.result) == hash(ret.operands[0])
|
|
|
|
|
|
# CHECK-LABEL: TEST: testValueUses
@run
def testValueUses():
    """Iterating ``Value.uses`` yields one use per consuming operand."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            value = Operation.create("custom.op1", results=[i32]).results[0]
            op1 = Operation.create("custom.op2", operands=[value])
            op2 = Operation.create("custom.op2", operands=[value])

        uses = list(value.uses)
        # op1 and op2 each read `value` once. Without this count the per-use
        # asserts below would pass vacuously on an empty sequence.
        assert len(uses) == 2

        # CHECK: Use owner: "custom.op2"
        # CHECK: Use operand_number: 0
        # CHECK: Use owner: "custom.op2"
        # CHECK: Use operand_number: 0
        for use in uses:
            assert use.owner in [op1, op2]
            print(f"Use owner: {use.owner}")
            print(f"Use operand_number: {use.operand_number}")
|
|
|
|
|
|
# CHECK-LABEL: TEST: testValueReplaceAllUsesWith
@run
def testValueReplaceAllUsesWith():
    """``replace_all_uses_with`` moves every use of a value onto another."""
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            value = Operation.create("custom.op1", results=[i32]).results[0]
            op1 = Operation.create("custom.op2", operands=[value])
            op2 = Operation.create("custom.op2", operands=[value])
            value2 = Operation.create("custom.op3", results=[i32]).results[0]
            value.replace_all_uses_with(value2)

        # The old value has no uses left...
        assert len(list(value.uses)) == 0

        # ...and both of its former uses now belong to value2. The explicit
        # count keeps the per-use asserts below from passing vacuously.
        new_uses = list(value2.uses)
        assert len(new_uses) == 2

        # CHECK: Use owner: "custom.op2"
        # CHECK: Use operand_number: 0
        # CHECK: Use owner: "custom.op2"
        # CHECK: Use operand_number: 0
        for use in new_uses:
            assert use.owner in [op1, op2]
            print(f"Use owner: {use.owner}")
            print(f"Use operand_number: {use.operand_number}")
|
|
|
|
|
|
# CHECK-LABEL: TEST: testValuePrintAsOperand
@run
def testValuePrintAsOperand():
    """Print SSA names via ``Value.get_name`` under several naming modes.

    The module built here holds two ops and a function ``@test``. That
    function's entry block holds two ops plus a nested function ``@test``
    with two more ops. Names are printed in these modes:
    - the default naming;
    - an ``AsmState`` scoped to the outer function;
    - ``use_local_scope=True``;
    - once more after detaching one op from the module.
    """
    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            value = Operation.create("custom.op1", results=[i32]).results[0]
            # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
            print(value)

            value2 = Operation.create("custom.op2", results=[i32]).results[0]
            # CHECK: Value(%[[VAL2:.*]] = "custom.op2"() : () -> i32)
            print(value2)

            outer_func = func.FuncOp("test", ([i32, i32], []))
            outer_entry = Block.create_at_start(
                outer_func.operation.regions[0], [i32, i32]
            )

            with InsertionPoint(outer_entry):
                value3 = Operation.create("custom.op3", results=[i32]).results[0]
                # CHECK: Value(%[[VAL3:.*]] = "custom.op3"() : () -> i32)
                print(value3)
                value4 = Operation.create("custom.op4", results=[i32]).results[0]
                # CHECK: Value(%[[VAL4:.*]] = "custom.op4"() : () -> i32)
                print(value4)

                # A second function nested inside the outer function's body.
                inner_func = func.FuncOp("test", ([i32, i32], []))
                inner_entry = Block.create_at_start(
                    inner_func.operation.regions[0], [i32, i32]
                )
                with InsertionPoint(inner_entry):
                    value5 = Operation.create("custom.op5", results=[i32]).results[0]
                    # CHECK: Value(%[[VAL5:.*]] = "custom.op5"() : () -> i32)
                    print(value5)
                    value6 = Operation.create("custom.op6", results=[i32]).results[0]
                    # CHECK: Value(%[[VAL6:.*]] = "custom.op6"() : () -> i32)
                    print(value6)

                    # Terminator of the inner function.
                    func.ReturnOp([])

                # Terminator of the outer function.
                func.ReturnOp([])

        # Default naming reproduces the names captured above.
        # CHECK: %[[VAL1]]
        print(value.get_name())
        # CHECK: %[[VAL2]]
        print(value2.get_name())
        # CHECK: %[[VAL3]]
        print(value3.get_name())
        # CHECK: %[[VAL4]]
        print(value4.get_name())

        print("With AsmState")
        # CHECK-LABEL: With AsmState
        state = AsmState(outer_func.operation, use_local_scope=True)
        # CHECK: %0
        print(value3.get_name(state=state))
        # CHECK: %1
        print(value4.get_name(state=state))

        print("With use_local_scope")
        # CHECK-LABEL: With use_local_scope
        # CHECK: %0
        print(value3.get_name(use_local_scope=True))
        # CHECK: %1
        print(value4.get_name(use_local_scope=True))

        # CHECK: %[[VAL5]]
        print(value5.get_name())
        # CHECK: %[[VAL6]]
        print(value6.get_name())

        # CHECK: %[[ARG0:.*]]
        print(outer_entry.arguments[0].get_name())
        # CHECK: %[[ARG1:.*]]
        print(outer_entry.arguments[1].get_name())

        # CHECK: %[[ARG2:.*]]
        print(inner_entry.arguments[0].get_name())
        # CHECK: %[[ARG3:.*]]
        print(inner_entry.arguments[1].get_name())

        # CHECK: module {
        # CHECK: %[[VAL1]] = "custom.op1"() : () -> i32
        # CHECK: %[[VAL2]] = "custom.op2"() : () -> i32
        # CHECK: func.func @test(%[[ARG0]]: i32, %[[ARG1]]: i32) {
        # CHECK: %[[VAL3]] = "custom.op3"() : () -> i32
        # CHECK: %[[VAL4]] = "custom.op4"() : () -> i32
        # CHECK: func @test(%[[ARG2]]: i32, %[[ARG3]]: i32) {
        # CHECK: %[[VAL5]] = "custom.op5"() : () -> i32
        # CHECK: %[[VAL6]] = "custom.op6"() : () -> i32
        # CHECK: return
        # CHECK: }
        # CHECK: return
        # CHECK: }
        # CHECK: }
        print(module)

        # After op2 is detached from the module, its result is expected to
        # print as %0.
        value2.owner.detach_from_parent()
        # CHECK: %0
        print(value2.get_name())
|
|
|
|
|
|
# CHECK-LABEL: TEST: testValueSetType
@run
def testValueSetType():
    """``Value.set_type`` retypes a result in place and keeps its SSA name."""
    context = Context()
    context.allow_unregistered_dialects = True
    with Location.unknown(context):
        narrow = IntegerType.get_signless(32)
        wide = IntegerType.get_signless(64)
        module = Module.create()
        with InsertionPoint(module.body):
            result = Operation.create("custom.op1", results=[narrow]).results[0]
            # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
            print(result)

            result.set_type(wide)
            # CHECK: Value(%[[VAL1]] = "custom.op1"() : () -> i64)
            print(result)

            # The defining op itself reflects the new result type.
            # CHECK: %[[VAL1]] = "custom.op1"() : () -> i64
            print(result.owner)
|
|
|
|
|
|
# CHECK-LABEL: TEST: testValueCasters
@run
def testValueCasters():
    """Custom value casters re-wrap values of a registered type on access.

    A caster registered for ``IntegerType`` intercepts i32 op results,
    operands and block arguments. It wraps each one in a NOP* subclass whose
    ``str`` reports the subclass name.

    Registering a second caster for the same type without ``replace=True``
    raises ``RuntimeError``. With ``replace=True`` the second caster replaces
    the first.
    """

    class NOPResult(OpResult):
        def __init__(self, v):
            super().__init__(v)

        def __str__(self):
            return super().__str__().replace(Value.__name__, NOPResult.__name__)

    class NOPValue(Value):
        def __init__(self, v):
            super().__init__(v)

        def __str__(self):
            return super().__str__().replace(Value.__name__, NOPValue.__name__)

    class NOPBlockArg(BlockArgument):
        def __init__(self, v):
            super().__init__(v)

        def __str__(self):
            return super().__str__().replace(Value.__name__, NOPBlockArg.__name__)

    @register_value_caster(IntegerType.static_typeid)
    def cast_int(v) -> SubClassValueT:
        print("in caster", v.__class__.__name__)
        # Try the specific kinds first; plain Value is the catch-all.
        for base, wrapper in (
            (OpResult, NOPResult),
            (BlockArgument, NOPBlockArg),
            (Value, NOPValue),
        ):
            if isinstance(v, base):
                return wrapper(v)

    ctx = Context()
    ctx.allow_unregistered_dialects = True
    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            results = Operation.create("custom.op1", results=[i32, i32]).results
            # CHECK: in caster OpResult
            # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("result", results[0].result_number, results[0])
            # CHECK: in caster OpResult
            # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("result", results[1].result_number, results[1])

            # Elements taken from a slice are cast as well.
            # CHECK: results slice 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("results slice", results[:1][0].result_number, results[:1][0])

            first, second = results
            # CHECK: in caster OpResult
            # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("result", first.result_number, results[0])
            # CHECK: in caster OpResult
            # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("result", second.result_number, results[1])

            consumer = Operation.create("custom.op2", operands=[first, second])
            # CHECK: "custom.op2"(%0#0, %0#1) : (i32, i32) -> ()
            print(consumer)

            # Operands reach the caster as plain Value, hence NOPValue.
            # CHECK: in caster Value
            # CHECK: operand 0 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("operand 0", consumer.operands[0])
            # CHECK: in caster Value
            # CHECK: operand 1 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32))
            print("operand 1", consumer.operands[1])

            # CHECK: in caster BlockArgument
            # CHECK: in caster BlockArgument
            @func.FuncOp.from_py_func(i32, i32)
            def reduction(arg0, arg1):
                # CHECK: as func arg 0 NOPBlockArg
                print("as func arg", arg0.arg_number, arg0.__class__.__name__)
                # CHECK: as func arg 1 NOPBlockArg
                print("as func arg", arg1.arg_number, arg1.__class__.__name__)

            # CHECK: args slice 0 NOPBlockArg(<block argument> of type 'i32' at index: 0)
            print(
                "args slice",
                reduction.func_op.arguments[:1][0].arg_number,
                reduction.func_op.arguments[:1][0],
            )

    # A second registration for the same type without replace=True is refused.
    try:

        @register_value_caster(IntegerType.static_typeid)
        def dont_cast_int_shouldnt_register(v):
            ...

    except RuntimeError as err:
        # CHECK: Value caster is already registered: {{.*}}cast_int
        print(err)

    # With replace=True the new caster takes over and passes values through.
    @register_value_caster(IntegerType.static_typeid, replace=True)
    def dont_cast_int(v) -> OpResult:
        assert isinstance(v, OpResult)
        print("don't cast", v.result_number, v)
        return v

    with Location.unknown(ctx):
        i32 = IntegerType.get_signless(32)
        module = Module.create()
        with InsertionPoint(module.body):
            # CHECK: don't cast 0 Value(%0 = "custom.op1"() : () -> i32)
            fresh = Operation.create("custom.op1", results=[i32]).result
            # CHECK: result 0 Value(%0 = "custom.op1"() : () -> i32)
            print("result", fresh.result_number, fresh)

            # CHECK: don't cast 0 Value(%1 = "custom.op2"() : () -> i32)
            fresh = Operation.create("custom.op2", results=[i32]).results[0]
            # CHECK: result 0 Value(%1 = "custom.op2"() : () -> i32)
            print("result", fresh.result_number, fresh)
|