bolt/deps/llvm-18.1.8/mlir/test/python/dialects/arith_dialect.py

95 lines
2.9 KiB
Python
Raw Normal View History

2025-02-14 19:21:04 +01:00
# RUN: %PYTHON %s | FileCheck %s
from functools import partialmethod
from mlir.ir import *
import mlir.dialects.arith as arith
import mlir.dialects.func as func
def run(f):
    """Announce and immediately execute a FileCheck test case.

    Meant to be applied as a bare decorator. Prints a blank line followed by
    ``TEST: <function name>``, then calls ``f`` with no arguments. Nothing is
    returned, so the decorated module-level name ends up bound to ``None``.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
# CHECK-LABEL: TEST: testConstantOp
@run
def testConstantOps():
    """Build a module holding one f32 ``arith.constant`` and print the module."""
    with Context() as ctx, Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            arith.ConstantOp(value=42.42, result=f32)
        # CHECK: %cst = arith.constant 4.242000e+01 : f32
        print(module)
# CHECK-LABEL: TEST: testFastMathFlags
@run
def testFastMathFlags():
    """Print an ``arith.addf`` built with the ``nnan`` and ``ninf`` fast-math flags."""
    with Context() as ctx, Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f32 = F32Type.get()
            operand = arith.ConstantOp(value=42.42, result=f32)
            flags = arith.FastMathFlags.nnan | arith.FastMathFlags.ninf
            add = arith.AddFOp(operand, operand, fastmath=flags)
            # CHECK: %0 = arith.addf %cst, %cst fastmath<nnan,ninf> : f32
            print(add)
# CHECK-LABEL: TEST: testArithValue
@run
def testArithValue():
    """Exercise a custom value caster that gives values arithmetic operators.

    ``ArithValue`` is registered with ``register_value_caster`` for the F16,
    F32, F64 and integer type IDs. It maps ``+``, ``-`` and ``*`` onto the
    matching ``arith`` float (``...FOp``) or integer (``...IOp``) binary op.
    The body builds f16/f32/f64 constants, combines each with itself and
    prints the results for FileCheck to match.
    """

    def _binary_op(lhs, rhs, op: str) -> "ArithValue":
        """Emit the ``arith`` binary op named by ``op`` and return its result.

        Args:
            lhs: Left operand value.
            rhs: Right operand value.
            op: Lower-case base op name such as ``"add"``. It is capitalized
                and suffixed with ``F`` (float) or ``I`` (integer) to pick the
                ``arith.<Name>Op`` class.

        Raises:
            NotImplementedError: if the operands are not both float-typed or
                both integer-like.
        """
        op = op.capitalize()
        if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
            op += "F"
        # Both operands must be integer-like. Testing only lhs would let an
        # integer lhs paired with a non-integer rhs slip into an ...IOp.
        elif arith._is_integer_like_type(lhs.type) and arith._is_integer_like_type(
            rhs.type
        ):
            op += "I"
        else:
            raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}")
        op_cls = getattr(arith, f"{op}Op")
        return op_cls(lhs, rhs).result

    @register_value_caster(F16Type.static_typeid)
    @register_value_caster(F32Type.static_typeid)
    @register_value_caster(F64Type.static_typeid)
    @register_value_caster(IntegerType.static_typeid)
    class ArithValue(Value):
        """``Value`` subclass overloading ``+``, ``-`` and ``*`` via ``_binary_op``."""

        def __init__(self, v):
            super().__init__(v)

        __add__ = partialmethod(_binary_op, op="add")
        __sub__ = partialmethod(_binary_op, op="sub")
        __mul__ = partialmethod(_binary_op, op="mul")

        def __str__(self):
            # Swap the base class name for this subclass's name in the
            # printed form, so values print as ``ArithValue(...)``.
            return super().__str__().replace(Value.__name__, ArithValue.__name__)

    with Context() as ctx, Location.unknown():
        module = Module.create()
        f16_t = F16Type.get()
        f32_t = F32Type.get()
        f64_t = F64Type.get()

        with InsertionPoint(module.body):
            a = arith.constant(f16_t, 42.42)
            # CHECK: ArithValue(%cst = arith.constant 4.240
            print(a)

            b = a + a
            # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
            print(b)

            a = arith.constant(f32_t, 42.42)
            b = a - a
            # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
            print(b)

            a = arith.constant(f64_t, 42.42)
            b = a * a
            # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
            print(b)