# RUN: %PYTHON %s | FileCheck %s import gc from mlir.ir import * def run(f): print("\nTEST:", f.__name__) f() gc.collect() assert Context._get_live_count() == 0 return f # CHECK-LABEL: TEST: testAffineMapCapsule @run def testAffineMapCapsule(): with Context() as ctx: am1 = AffineMap.get_empty(ctx) # CHECK: mlir.ir.AffineMap._CAPIPtr affine_map_capsule = am1._CAPIPtr print(affine_map_capsule) am2 = AffineMap._CAPICreate(affine_map_capsule) assert am2 == am1 assert am2.context is ctx # CHECK-LABEL: TEST: testAffineMapGet @run def testAffineMapGet(): with Context() as ctx: d0 = AffineDimExpr.get(0) d1 = AffineDimExpr.get(1) c2 = AffineConstantExpr.get(2) # CHECK: (d0, d1)[s0, s1, s2] -> () map0 = AffineMap.get(2, 3, []) print(map0) # CHECK: (d0, d1)[s0, s1, s2] -> (d1, 2) map1 = AffineMap.get(2, 3, [d1, c2]) print(map1) # CHECK: () -> (2) map2 = AffineMap.get(0, 0, [c2]) print(map2) # CHECK: (d0, d1) -> (d0, d1) map3 = AffineMap.get(2, 0, [d0, d1]) print(map3) # CHECK: (d0, d1) -> (d1) map4 = AffineMap.get(2, 0, [d1]) print(map4) # CHECK: (d0, d1, d2) -> (d2, d0, d1) map5 = AffineMap.get_permutation([2, 0, 1]) print(map5) assert map1 == AffineMap.get(2, 3, [d1, c2]) assert AffineMap.get(0, 0, []) == AffineMap.get_empty() assert map2 == AffineMap.get_constant(2) assert map3 == AffineMap.get_identity(2) assert map4 == AffineMap.get_minor_identity(2, 1) try: AffineMap.get(1, 1, [1]) except RuntimeError as e: # CHECK: Invalid expression when attempting to create an AffineMap print(e) try: AffineMap.get(1, 1, [None]) except RuntimeError as e: # CHECK: Invalid expression (None?) when attempting to create an AffineMap print(e) try: AffineMap.get_permutation([1, 0, 1]) except RuntimeError as e: # CHECK: Invalid permutation when attempting to create an AffineMap print(e) try: map3.get_submap([42]) except ValueError as e: # CHECK: result position out of bounds print(e) try: map3.get_minor_submap(42) except ValueError as e: # CHECK: number of results out of bounds print(e) try: map3.get_major_submap(42) except ValueError as e: # CHECK: number of results out of bounds print(e) # CHECK-LABEL: TEST: testAffineMapDerive @run def testAffineMapDerive(): with Context() as ctx: map5 = AffineMap.get_identity(5) # CHECK: (d0, d1, d2, d3, d4) -> (d1, d2, d3) map123 = map5.get_submap([1, 2, 3]) print(map123) # CHECK: (d0, d1, d2, d3, d4) -> (d0, d1) map01 = map5.get_major_submap(2) print(map01) # CHECK: (d0, d1, d2, d3, d4) -> (d3, d4) map34 = map5.get_minor_submap(2) print(map34) # CHECK-LABEL: TEST: testAffineMapProperties @run def testAffineMapProperties(): with Context(): d0 = AffineDimExpr.get(0) d1 = AffineDimExpr.get(1) d2 = AffineDimExpr.get(2) map1 = AffineMap.get(3, 0, [d2, d0]) map2 = AffineMap.get(3, 0, [d2, d0, d1]) map3 = AffineMap.get(3, 1, [d2, d0, d1]) # CHECK: False print(map1.is_permutation) # CHECK: True print(map1.is_projected_permutation) # CHECK: True print(map2.is_permutation) # CHECK: True print(map2.is_projected_permutation) # CHECK: False print(map3.is_permutation) # CHECK: False print(map3.is_projected_permutation) # CHECK-LABEL: TEST: testAffineMapExprs @run def testAffineMapExprs(): with Context(): d0 = AffineDimExpr.get(0) d1 = AffineDimExpr.get(1) d2 = AffineDimExpr.get(2) map3 = AffineMap.get(3, 1, [d2, d0, d1]) # CHECK: 3 print(map3.n_dims) # CHECK: 4 print(map3.n_inputs) # CHECK: 1 print(map3.n_symbols) assert map3.n_inputs == map3.n_dims + map3.n_symbols # CHECK: 3 print(len(map3.results)) for expr in map3.results: # CHECK: d2 # CHECK: d0 # CHECK: d1 print(expr) for expr in map3.results[-1:-4:-1]: # CHECK: d1 # CHECK: d0 # CHECK: d2 print(expr) assert list(map3.results) == [d2, d0, d1] # CHECK-LABEL: TEST: testCompressUnusedSymbols @run def testCompressUnusedSymbols(): with Context() as ctx: d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1), AffineDimExpr.get(2)) s0, s1, s2 = ( AffineSymbolExpr.get(0), AffineSymbolExpr.get(1), AffineSymbolExpr.get(2), ) maps = [ AffineMap.get(3, 3, [d2, d0, d1]), AffineMap.get(3, 3, [d2, d0 + s2, d1]), AffineMap.get(3, 3, [d1, d2, d0]), ] compressed_maps = AffineMap.compress_unused_symbols(maps, ctx) # CHECK: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0, d1)) # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2, d1)) # CHECK-SAME: AffineMap((d0, d1, d2)[s0, s1, s2] -> (d1, d2, d0)) print(maps) # CHECK: AffineMap((d0, d1, d2)[s0] -> (d2, d0, d1)) # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d2, d0 + s0, d1)) # CHECK-SAME: AffineMap((d0, d1, d2)[s0] -> (d1, d2, d0)) print(compressed_maps) # CHECK-LABEL: TEST: testReplace @run def testReplace(): with Context() as ctx: d0, d1, d2 = (AffineDimExpr.get(0), AffineDimExpr.get(1), AffineDimExpr.get(2)) s0, s1, s2 = ( AffineSymbolExpr.get(0), AffineSymbolExpr.get(1), AffineSymbolExpr.get(2), ) map1 = AffineMap.get(3, 3, [d2, d0 + s1 + s2, d1 + s0]) replace0 = map1.replace(s0, AffineConstantExpr.get(42), 3, 3) replace1 = map1.replace(s1, AffineConstantExpr.get(42), 3, 3) replace3 = map1.replace(s2, AffineConstantExpr.get(42), 3, 2) # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s1 + s2, d1 + 42) print(replace0) # CHECK: (d0, d1, d2)[s0, s1, s2] -> (d2, d0 + s2 + 42, d1 + s0) print(replace1) # CHECK: (d0, d1, d2)[s0, s1] -> (d2, d0 + s1 + 42, d1 + s0) print(replace3) # CHECK-LABEL: TEST: testHash @run def testHash(): with Context(): d0, d1 = AffineDimExpr.get(0), AffineDimExpr.get(1) m1 = AffineMap.get(2, 0, [d0, d1]) m2 = AffineMap.get(2, 0, [d1, d0]) assert hash(m1) == hash(AffineMap.get(2, 0, [d0, d1])) dictionary = dict() dictionary[m1] = 1 dictionary[m2] = 2 assert m1 in dictionary