# RUN: %PYTHON %s | FileCheck %s
from functools import partialmethod

from mlir.ir import *
import mlir.dialects.arith as arith
import mlir.dialects.func as func


def run(f):
    """Print a ``TEST: <name>`` banner for ``f`` and execute it immediately.

    Used as a decorator so that every test runs at import time and its output
    is preceded by a banner that the FileCheck label directives anchor on.
    Nothing is returned, so the decorated name ends up bound to ``None``.
    """
    print("\nTEST:", f.__name__)
    f()


# CHECK-LABEL: TEST: testConstantOps
@run
def testConstantOps():
    """Build a single f32 ``arith.constant`` and print the enclosing module."""
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            arith.ConstantOp(value=42.42, result=F32Type.get())
        # CHECK: %cst = arith.constant 4.242000e+01 : f32
        print(module)


# CHECK-LABEL: TEST: testFastMathFlags
@run
def testFastMathFlags():
    """Create an ``arith.addf`` carrying the ``nnan`` and ``ninf`` flags."""
    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            a = arith.ConstantOp(value=42.42, result=F32Type.get())
            r = arith.AddFOp(
                a, a, fastmath=arith.FastMathFlags.nnan | arith.FastMathFlags.ninf
            )
            # The printed op must show both flags that were OR-ed together above.
            # CHECK: %0 = arith.addf %cst, %cst fastmath<nnan,ninf> : f32
            print(r)


# CHECK-LABEL: TEST: testArithValue
@run
def testArithValue():
    """Exercise a custom ``Value`` subclass that overloads ``+``, ``-``, ``*``.

    ``ArithValue`` is registered through ``register_value_caster`` for the
    f16/f32/f64 and integer type ids; the expected output below shows op
    results being printed as ``ArithValue(...)`` rather than ``Value(...)``.
    """

    def _binary_op(lhs, rhs, op: str) -> "ArithValue":
        """Emit the arith op named by ``op`` for ``lhs`` and ``rhs``.

        ``op`` is a lowercase mnemonic stem such as ``"add"``. It is
        capitalized and suffixed with ``F`` (both operands float) or ``I``
        (both operands integer-like) to select e.g. ``arith.AddFOp``.

        Raises:
            NotImplementedError: if the operands are not both float or both
                integer-like (including mixed float/integer pairs).
        """
        op = op.capitalize()
        if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type):
            op += "F"
        elif arith._is_integer_like_type(lhs.type) and arith._is_integer_like_type(
            rhs.type
        ):
            # Both sides must be checked; testing lhs twice would let a mixed
            # int/float pair fall through to the integer op.
            op += "I"
        else:
            raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}")

        op = getattr(arith, f"{op}Op")
        return op(lhs, rhs).result

    @register_value_caster(F16Type.static_typeid)
    @register_value_caster(F32Type.static_typeid)
    @register_value_caster(F64Type.static_typeid)
    @register_value_caster(IntegerType.static_typeid)
    class ArithValue(Value):
        """``Value`` wrapper whose arithmetic operators build arith ops."""

        def __init__(self, v):
            super().__init__(v)

        # partialmethod binds the instance as ``lhs`` and fixes the op stem.
        __add__ = partialmethod(_binary_op, op="add")
        __sub__ = partialmethod(_binary_op, op="sub")
        __mul__ = partialmethod(_binary_op, op="mul")

        def __str__(self):
            # Reuse the base rendering but show this class's name instead.
            return super().__str__().replace(Value.__name__, ArithValue.__name__)

    with Context(), Location.unknown():
        module = Module.create()
        f16_t = F16Type.get()
        f32_t = F32Type.get()
        f64_t = F64Type.get()

        with InsertionPoint(module.body):
            a = arith.constant(f16_t, 42.42)
            # CHECK: ArithValue(%cst = arith.constant 4.240
            print(a)

            b = a + a
            # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
            print(b)

            a = arith.constant(f32_t, 42.42)
            b = a - a
            # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
            print(b)

            a = arith.constant(f64_t, 42.42)
            b = a * a
            # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
            print(b)