550 lines
13 KiB
Python
550 lines
13 KiB
Python
|
# Copyright 2016-2017 Tobias Grosser
|
||
|
#
|
||
|
# Use of this software is governed by the MIT license
|
||
|
#
|
||
|
# Written by Tobias Grosser, Weststrasse 47, CH-8003, Zurich
|
||
|
|
||
|
import sys
|
||
|
import isl
|
||
|
|
||
|
# Test that isl objects can be constructed.
|
||
|
#
|
||
|
# This tests:
|
||
|
# - construction from a string
|
||
|
# - construction from an integer
|
||
|
# - static constructor without a parameter
|
||
|
# - conversion construction
|
||
|
# - construction of empty union set
|
||
|
#
|
||
|
# The tests to construct from integers and strings cover functionality that
|
||
|
# is also tested in the parameter type tests, but here the presence of
|
||
|
# multiple overloaded constructors and overload resolution is tested.
|
||
|
#
|
||
|
def test_constructors():
    """Check that isl objects can be created through their constructors.

    Exercises construction of ``isl.val`` from a string, from an integer
    and through the parameterless static ``zero`` constructor, conversion
    of a ``basic_set`` into a ``set``, and the static ``empty`` constructor
    of ``union_set``.
    """
    # The value zero, obtained through three different constructors.
    assert isl.val("0").is_zero()
    assert isl.val(0).is_zero()
    assert isl.val.zero().is_zero()

    # Conversion construction: a set built from a basic_set must equal
    # the set parsed directly from the same string.
    basic = isl.basic_set("{ [1] }")
    reference = isl.set("{ [1] }")
    assert isl.set(basic).is_equal(reference)

    # Taking the union with an empty union set must not change anything.
    statements = isl.union_set("{ A[1]; B[2, 3] }")
    nothing = isl.union_set.empty()
    assert statements.is_equal(statements.union(nothing))
|
||
|
|
||
|
|
||
|
# Test integer function parameters for a particular integer value.
|
||
|
#
|
||
|
def test_int(i):
    """Check that ``isl.val(i)`` equals ``isl.val(str(i))`` for integer *i*."""
    assert isl.val(i).eq(isl.val(str(i)))
|
||
|
|
||
|
|
||
|
# Test integer function parameters.
|
||
|
#
|
||
|
# Verify that extreme values and zero work.
|
||
|
#
|
||
|
def test_parameters_int():
    """Run ``test_int`` on the largest and smallest ``sys.maxsize``-based values and on zero."""
    extremes_and_zero = (sys.maxsize, -sys.maxsize - 1, 0)
    for number in extremes_and_zero:
        test_int(number)
|
||
|
|
||
|
|
||
|
# Test isl object parameters.
|
||
|
#
|
||
|
# Verify that isl objects can be passed as lvalue and rvalue parameters.
|
||
|
# Also verify that isl object parameters are automatically type converted if
|
||
|
# there is an inheritance relation. Finally, test function calls without
|
||
|
# any additional parameters, apart from the isl object on which
|
||
|
# the method is called.
|
||
|
#
|
||
|
def test_parameters_obj():
    """Check passing isl objects as arguments to isl methods.

    Covers an argument held in a named variable, an argument that is
    the temporary result of another call, a ``basic_set`` passed where
    a ``set`` is compared, and a method that takes no argument besides
    the object it is called on.
    """
    zero_set = isl.set("{ [0] }")
    one_set = isl.set("{ [1] }")
    two_set = isl.set("{ [2] }")
    expected = isl.set("{ [i] : 0 <= i <= 2 }")

    # Intermediate result stored in a variable before being used.
    partial = zero_set.union(one_set)
    assert partial.union(two_set).is_equal(expected)

    # Intermediate result used directly as a temporary.
    assert zero_set.union(one_set).union(two_set).is_equal(expected)

    # A basic_set is accepted where the comparison is made against a set.
    assert zero_set.is_equal(isl.basic_set("{ [0] }"))

    # Method call with no argument other than the object itself.
    reciprocal = isl.val(2).inv()
    assert reciprocal.eq(isl.val("1/2"))
|
||
|
|
||
|
|
||
|
# Test different kinds of parameters to be passed to functions.
|
||
|
#
|
||
|
# This includes integer and isl object parameters.
|
||
|
#
|
||
|
def test_parameters():
    """Run the integer parameter tests followed by the object parameter tests."""
    for check in (test_parameters_int, test_parameters_obj):
        check()
|
||
|
|
||
|
|
||
|
# Test that isl objects are returned correctly.
|
||
|
#
|
||
|
# This only tests that after combining two objects, the result is successfully
|
||
|
# returned.
|
||
|
#
|
||
|
def test_return_obj():
    """Check that the object returned by combining two values is usable.

    Adds the values 1 and 2 and compares the returned object with 3.
    """
    total = isl.val("1").add(isl.val("2"))
    assert total.eq(isl.val("3"))
|
||
|
|
||
|
|
||
|
# Test that integer values are returned correctly.
|
||
|
#
|
||
|
def test_return_int():
    """Check the integer returned by ``sgn`` for a positive, a negative and a zero value."""
    positive, negative, neutral = (isl.val(text) for text in ("1", "-1", "0"))

    assert positive.sgn() > 0
    assert negative.sgn() < 0
    assert neutral.sgn() == 0
|
||
|
|
||
|
|
||
|
# Test that isl_bool values are returned correctly.
|
||
|
#
|
||
|
# In particular, check the conversion to bool in case of true and false.
|
||
|
#
|
||
|
def test_return_bool():
    """Check that ``is_empty`` results behave as Python truth values.

    The result must be truthy for a set with a false constraint and
    falsy for the unconstrained set.
    """
    nothing = isl.set("{ : false }")
    everything = isl.set("{ : }")

    assert nothing.is_empty()
    assert not everything.is_empty()
|
||
|
|
||
|
|
||
|
# Test that strings are returned correctly.
|
||
|
# Do so by calling the overloaded isl.ast_build.expr_from methods.
|
||
|
#
|
||
|
def test_return_string():
    """Check that strings returned by isl methods are correct.

    Builds AST expressions through ``expr_from``, once from a ``pw_aff``
    and once from a ``set``, and compares their ``to_C_str`` output with
    the expected text.
    """
    context = isl.set("[n] -> { : }")
    build = isl.ast_build.from_context(context)
    pw_aff = isl.pw_aff("[n] -> { [n] }")
    # Named "non_negative" rather than "set" so the builtin is not shadowed.
    non_negative = isl.set("[n] -> { : n >= 0 }")

    expr = build.expr_from(pw_aff)
    assert expr.to_C_str() == "n"

    expr = build.expr_from(non_negative)
    assert expr.to_C_str() == "n >= 0"
|
||
|
|
||
|
|
||
|
# Test that return values are handled correctly.
|
||
|
#
|
||
|
# Test that isl objects, integers, boolean values, and strings are
|
||
|
# returned correctly.
|
||
|
#
|
||
|
def test_return():
    """Run the return value tests for objects, integers, booleans and strings, in that order."""
    checks = (
        test_return_obj,
        test_return_int,
        test_return_bool,
        test_return_string,
    )
    for check in checks:
        check()
|
||
|
|
||
|
|
||
|
# A class that is used to test isl.id.user.
|
||
|
#
|
||
|
class S:
    """Simple user object carrying a single ``value`` attribute."""

    def __init__(self):
        # Fixed marker value; every new instance starts out with 42.
        self.value = 42
|
||
|
|
||
|
|
||
|
# Test isl.id.user.
|
||
|
#
|
||
|
# In particular, check that the object attached to an identifier
|
||
|
# can be retrieved again.
|
||
|
#
|
||
|
def test_user():
    """Check that the user object attached to an ``isl.id`` can be read back.

    Three identifiers are created: one with an integer user object,
    one without any user object and one with an instance of ``S``.
    ``user()`` is expected to return 5, ``None`` and the ``S`` instance
    (with ``value`` 42), respectively.
    """
    # Local names avoid "id", which would shadow the builtin.
    int_id = isl.id("test", 5)
    bare_id = isl.id("test2")
    obj_id = isl.id("S", S())

    assert int_id.user() == 5, f"unexpected user object {int_id.user()}"
    assert bare_id.user() is None, f"unexpected user object {bare_id.user()}"

    s = obj_id.user()
    assert isinstance(s, S), f"unexpected user object {s}"
    assert s.value == 42, f"unexpected user object {s}"
|
||
|
|
||
|
|
||
|
# Test that foreach functions are modeled correctly.
|
||
|
#
|
||
|
# Verify that closures are correctly called as callback of a 'foreach'
|
||
|
# function and that variables captured by the closure work correctly. Also
|
||
|
# check that the foreach function handles exceptions thrown from
|
||
|
# the closure and that it propagates the exception.
|
||
|
#
|
||
|
def test_foreach():
    """Check ``foreach_basic_set`` with a collecting and with a raising callback.

    A closure first collects every basic set of a three-element set;
    the three collected pieces must be subsets of the original set and
    pairwise different.  Afterwards a callback that raises is passed and
    the exception is expected to reach the caller.
    """
    s = isl.set("{ [0]; [1]; [2] }")

    # Named "pieces" rather than "list" so the builtin is not shadowed.
    pieces = []

    def add(bs):
        # Mutates the list captured from the enclosing scope.
        pieces.append(bs)

    s.foreach_basic_set(add)

    assert len(pieces) == 3
    assert pieces[0].is_subset(s)
    assert pieces[1].is_subset(s)
    assert pieces[2].is_subset(s)
    assert not pieces[0].is_equal(pieces[1])
    assert not pieces[0].is_equal(pieces[2])
    assert not pieces[1].is_equal(pieces[2])

    def fail(bs):
        raise Exception("fail")

    caught = False
    try:
        s.foreach_basic_set(fail)
    except Exception:
        # Deliberately not a bare "except:", which would also swallow
        # KeyboardInterrupt and SystemExit.
        caught = True
    assert caught
|
||
|
|
||
|
|
||
|
# Test the functionality of "foreach_scc" functions.
|
||
|
#
|
||
|
# In particular, test it on a list of elements that can be completely sorted
|
||
|
# but where two of the elements ("a" and "b") are incomparable.
|
||
|
#
|
||
|
def test_foreach_scc():
    """Check ``foreach_scc`` on a list of three identifiers.

    The identifiers "a", "b" and "c" are each associated with a map.
    ``follows`` decides the ordering between two identifiers from their
    maps, and ``add_single`` appends every reported component, which
    must consist of a single element, to the result.  The result is
    expected to hold "b", "c", "a" in that order.
    """
    data = {
        "a": isl.map("{ [0] -> [1] }"),
        "b": isl.map("{ [1] -> [0] }"),
        "c": isl.map("{ [i = 0:1] -> [i] }"),
    }

    # Local names avoid shadowing the builtins list, sorted, id and map.
    ids = isl.id_list(3)
    for name in data:
        ids = ids.add(name)

    ordered = isl.id_list(3)
    identity = data["a"].space().domain().identity_multi_pw_aff_on_domain()

    def follows(a, b):
        relation = data[b.name()].apply_domain(data[a.name()])
        return not relation.lex_ge_at(identity).is_empty()

    def add_single(scc):
        nonlocal ordered
        assert scc.size() == 1
        ordered = ordered.concat(scc)

    ids.foreach_scc(follows, add_single)

    assert ordered.size() == 3
    assert ordered.at(0).name() == "b"
    assert ordered.at(1).name() == "c"
    assert ordered.at(2).name() == "a"
|
||
|
|
||
|
|
||
|
# Test the functionality of "every" functions.
|
||
|
#
|
||
|
# In particular, test the generic functionality and
|
||
|
# test that exceptions are properly propagated.
|
||
|
#
|
||
|
def test_every():
    """Check ``every_set`` with several predicates and with a raising callback.

    The union set holds one set named A and one named B, so a predicate
    that holds for both yields a true result, and a predicate that fails
    for at least one of them yields a false result.  Finally, an
    exception raised inside the callback is expected to reach the caller.
    """
    us = isl.union_set("{ A[i]; B[j] }")

    def is_empty(s):
        return s.is_empty()

    assert not us.every_set(is_empty)

    def is_non_empty(s):
        return not s.is_empty()

    assert us.every_set(is_non_empty)

    def in_A(s):
        return s.is_subset(isl.set("{ A[x] }"))

    assert not us.every_set(in_A)

    def not_in_A(s):
        return not s.is_subset(isl.set("{ A[x] }"))

    assert not us.every_set(not_in_A)

    def fail(s):
        raise Exception("fail")

    caught = False
    try:
        # This used to call the misspelled "ever_set"; the resulting
        # AttributeError was swallowed by a bare "except:", so exception
        # propagation through every_set was never actually exercised.
        us.every_set(fail)
    except Exception:
        caught = True
    assert caught
|
||
|
|
||
|
|
||
|
# Check basic construction of spaces.
|
||
|
#
|
||
|
def test_space():
    """Check basic construction of spaces.

    Starting from the unit space, a named tuple "A" with 3 dimensions is
    added to obtain a set space and a named tuple "B" with 2 dimensions
    is added to that to obtain a map space.  The universes of both spaces
    are compared with the corresponding parsed objects.
    """
    unit = isl.space.unit()
    set_space = unit.add_named_tuple("A", 3)
    map_space = set_space.add_named_tuple("B", 2)

    # Local names avoid shadowing the builtins set and map.
    universe_set = isl.set.universe(set_space)
    universe_map = isl.map.universe(map_space)
    assert universe_set.is_equal(isl.set("{ A[*,*,*] }"))
    assert universe_map.is_equal(isl.map("{ A[*,*,*] -> B[*,*] }"))
|
||
|
|
||
|
|
||
|
# Construct a simple schedule tree with an outer sequence node and
|
||
|
# a single-dimensional band node in each branch, with one of them
|
||
|
# marked coincident.
|
||
|
#
|
||
|
def construct_schedule_tree():
    """Build and return a small schedule for two statements A and B.

    The tree starts at a domain node over both statements, below which
    a sequence with one filter per statement is inserted.  In each branch
    a one-dimensional partial schedule is inserted below the filter;
    member 0 of the band in the A branch is marked coincident.
    """
    stmt_a = isl.union_set("{ A[i] : 0 <= i < 10 }")
    stmt_b = isl.union_set("{ B[i] : 0 <= i < 20 }")

    # Domain node over both statements, then move to its child.
    cursor = isl.schedule_node.from_domain(stmt_a.union(stmt_b)).child(0)

    # Sequence with one filter for A followed by one for B.
    cursor = cursor.insert_sequence(isl.union_set_list(stmt_a).add(stmt_b))

    # First branch: go two levels down, insert the schedule for A,
    # mark it coincident and move two levels back up.
    sched_a = isl.multi_union_pw_aff("[ { A[i] -> [i] } ]")
    cursor = cursor.child(0).child(0).insert_partial_schedule(sched_a)
    cursor = cursor.member_set_coincident(0, True).ancestor(2)

    # Second branch: same navigation, without the coincidence mark.
    sched_b = isl.multi_union_pw_aff("[ { B[i] -> [i] } ]")
    cursor = cursor.child(1).child(0).insert_partial_schedule(sched_b)
    cursor = cursor.ancestor(2)

    return cursor.schedule()
|
||
|
|
||
|
|
||
|
# Test basic schedule tree functionality.
|
||
|
#
|
||
|
# In particular, create a simple schedule tree and
|
||
|
# - check that the root node is a domain node
|
||
|
# - test map_descendant_bottom_up
|
||
|
# - test foreach_descendant_top_down
|
||
|
# - test every_descendant
|
||
|
#
|
||
|
def test_schedule_tree():
    """Check basic schedule tree functionality.

    Using the schedule from ``construct_schedule_tree``, this checks

    - that the root node has type ``isl.schedule_node_domain``,
    - ``map_descendant_bottom_up``, including a raising callback,
    - ``foreach_descendant_top_down`` with callbacks returning True and
      False (expected to be called 8 times and once, respectively),
    - ``every_descendant``, including a raising callback, and
    - that the union of all filters equals the domain of the root.
    """
    schedule = construct_schedule_tree()
    root = schedule.root()

    # An exact type match is intended here, not an isinstance check.
    assert type(root) is isl.schedule_node_domain

    visited = 0

    def count_and_keep(node):
        nonlocal visited
        visited += 1
        return node

    root = root.map_descendant_bottom_up(count_and_keep)
    assert visited == 8

    def fail_map(node):
        raise Exception("fail")

    caught = False
    try:
        root.map_descendant_bottom_up(fail_map)
    except Exception:
        # Not a bare "except:", which would also swallow
        # KeyboardInterrupt and SystemExit.
        caught = True
    assert caught

    visited = 0

    def count_and_continue(node):
        nonlocal visited
        visited += 1
        return True

    root.foreach_descendant_top_down(count_and_continue)
    assert visited == 8

    visited = 0

    def count_and_stop(node):
        nonlocal visited
        visited += 1
        return False

    # With a callback returning False only one call is expected.
    root.foreach_descendant_top_down(count_and_stop)
    assert visited == 1

    def is_not_domain(node):
        return type(node) is not isl.schedule_node_domain

    assert root.child(0).every_descendant(is_not_domain)
    assert not root.every_descendant(is_not_domain)

    def fail(node):
        raise Exception("fail")

    caught = False
    try:
        root.every_descendant(fail)
    except Exception:
        caught = True
    assert caught

    domain = root.domain()
    collected = isl.union_set("{}")

    def collect_filters(node):
        nonlocal collected
        if type(node) is isl.schedule_node_filter:
            collected = collected.union(node.filter())
        # Always return True so that every node gets visited.
        return True

    root.every_descendant(collect_filters)
    assert domain.is_equal(collected)
|
||
|
|
||
|
|
||
|
# Test marking band members for unrolling.
|
||
|
# "schedule" is the schedule created by construct_schedule_tree.
|
||
|
# It schedules two statements, with 10 and 20 instances, respectively.
|
||
|
# Unrolling all band members therefore results in 30 at-domain calls
|
||
|
# by the AST generator.
|
||
|
#
|
||
|
def test_ast_build_unroll(schedule):
    """Check AST generation after marking band members for unrolling.

    Member 0 of every band node in *schedule* is marked for unrolling,
    an AST is generated from the modified schedule and the at-each-domain
    callback is expected to be called 30 times.
    """

    def mark_unroll(node):
        # Nodes of any other type are passed through unchanged.
        if type(node) != isl.schedule_node_band:
            return node
        return node.member_set_ast_loop_unroll(0)

    unrolled = schedule.root().map_descendant_bottom_up(mark_unroll).schedule()

    calls = [0]

    def inc_count_ast(node, build):
        calls[0] += 1
        return node

    builder = isl.ast_build().set_at_each_domain(inc_count_ast)
    builder.node_from(unrolled)
    assert calls[0] == 30
|
||
|
|
||
|
|
||
|
# Test basic AST generation from a schedule tree.
|
||
|
#
|
||
|
# In particular, create a simple schedule tree and
|
||
|
# - generate an AST from the schedule tree
|
||
|
# - test at_each_domain
|
||
|
# - test unrolling
|
||
|
#
|
||
|
def test_ast_build():
    """Check basic AST generation from a schedule tree.

    Using the schedule from ``construct_schedule_tree``, this checks

    - that ``set_at_each_domain`` leaves the build it is called on
      without a callback while the returned build calls it (twice),
    - that an exception raised in the callback reaches the caller of
      ``node_from``,
    - that replacing the callback on a second build does not affect
      the first one, and
    - unrolling, through ``test_ast_build_unroll``.
    """
    schedule = construct_schedule_tree()

    count_ast = 0

    def inc_count_ast(node, build):
        nonlocal count_ast
        count_ast += 1
        return node

    build = isl.ast_build()
    build_copy = build.set_at_each_domain(inc_count_ast)

    # The original build has no callback attached.
    build.node_from(schedule)
    assert count_ast == 0

    build_copy.node_from(schedule)
    assert count_ast == 2

    build = build_copy
    count_ast = 0
    build.node_from(schedule)
    assert count_ast == 2

    do_fail = True
    count_ast_fail = 0

    def fail_inc_count_ast(node, build):
        nonlocal count_ast_fail
        count_ast_fail += 1
        # Reads the current value of do_fail from the enclosing scope.
        if do_fail:
            raise Exception("fail")
        return node

    build = isl.ast_build()
    build = build.set_at_each_domain(fail_inc_count_ast)
    caught = False
    try:
        build.node_from(schedule)
    except Exception:
        # Not a bare "except:", which would also swallow
        # KeyboardInterrupt and SystemExit.
        caught = True
    assert caught
    assert count_ast_fail > 0

    # Attaching another callback yields a build that only calls
    # the new callback.
    build_copy = build
    build_copy = build_copy.set_at_each_domain(inc_count_ast)
    count_ast = 0
    build_copy.node_from(schedule)
    assert count_ast == 2

    # The build with the failing callback still calls that callback,
    # which no longer raises once do_fail is cleared.
    count_ast_fail = 0
    do_fail = False
    build.node_from(schedule)
    assert count_ast_fail == 2

    test_ast_build_unroll(schedule)
|
||
|
|
||
|
|
||
|
# Test basic AST expression generation from an affine expression.
|
||
|
#
|
||
|
def test_ast_build_expr():
    """Check AST expression generation from the affine expression ``n + 1``.

    The generated expression is expected to have type
    ``isl.ast_expr_op_add`` and to have two arguments.
    """
    affine = isl.pw_aff("[n] -> { [n + 1] }")
    builder = isl.ast_build.from_context(affine.domain())

    sum_expr = builder.expr_from(affine)
    assert type(sum_expr) == isl.ast_expr_op_add
    assert sum_expr.n_arg() == 2
|
||
|
|
||
|
|
||
|
# Test the isl Python interface
|
||
|
#
|
||
|
# This includes:
|
||
|
# - Object construction
|
||
|
# - Different parameter types
|
||
|
# - Different return types
|
||
|
# - isl.id.user
|
||
|
# - Foreach functions
|
||
|
# - Foreach SCC function
|
||
|
# - Every functions
|
||
|
# - Spaces
|
||
|
# - Schedule trees
|
||
|
# - AST generation
|
||
|
# - AST expression generation
|
||
|
#
|
||
|
# Run every test group in sequence; an uncaught failing assertion
# stops the script at the first failing group.
for _run_test in (
    test_constructors,
    test_parameters,
    test_return,
    test_user,
    test_foreach,
    test_foreach_scc,
    test_every,
    test_space,
    test_schedule_tree,
    test_ast_build,
    test_ast_build_expr,
):
    _run_test()
|