// Check that the wide integer emulation produces the same result as wide
// calculations. Emulate i16 ops with i8 ops.

// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=8" \
// RUN: --convert-vector-to-scf --convert-scf-to-cf --convert-cf-to-llvm \
// RUN: --convert-vector-to-llvm --convert-func-to-llvm --convert-arith-to-llvm \
// RUN: --reconcile-unrealized-casts | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN: --shared-libs="%mlir_c_runner_utils,%mlir_runner_utils" | \
// RUN: FileCheck %s

// CHECK-NOT: Mismatch

//===----------------------------------------------------------------------===//
// Common Utility Functions
//===----------------------------------------------------------------------===//

// Prints both binary op operands and the first result. If the second result
// does not match, prints the second result and a 'Mismatch' message.
//   %lhs, %rhs : the operands that were fed to the op under test.
//   %res0      : first result; always printed.
//   %res1      : second result; printed only when it differs from %res0.
// The file-level `CHECK-NOT: Mismatch` directive turns any printed 'Mismatch'
// line into a test failure.
func.func @check_results(%lhs : i16, %rhs : i16, %res0 : i16, %res1 : i16) -> () {
  // Pack both operands into one vector so a single `vector.print` emits them
  // together.
  %vec_zero = arith.constant dense<0> : vector<2xi16>
  %ins0 = vector.insert %lhs, %vec_zero[0] : i16 into vector<2xi16>
  %operands = vector.insert %rhs, %ins0[1] : i16 into vector<2xi16>
  vector.print %operands : vector<2xi16>
  vector.print %res0 : i16
  %mismatch = arith.cmpi ne, %res0, %res1 : i16
  scf.if %mismatch -> () {
    vector.print %res1 : i16
    vector.print str "Mismatch"
  }
  return
}

// Returns `%i ^ (%i >> 8)`, using a logical (zero-filling) right shift, i.e.
// folds the upper byte of the input into its lower byte.
func.func @xorshift(%i : i16) -> (i16) {
  %cst8 = arith.constant 8 : i16
  %shifted = arith.shrui %i, %cst8 : i16
  %res = arith.xori %i, %shifted : i16
  return %res : i16
}

// Returns a hash of the input number. This is used when we want to sample a
// bunch of i16 inputs with close to uniform distribution but without fixed
// offsets between each sample.
// The hash is two rounds of `@xorshift` followed by a wrapping i16
// multiplication, first by an alternating bit pattern and then by a prime.
func.func @xhash(%i : i16) -> (i16) {
  %pattern = arith.constant 21845 : i16 // Alternating ones and zeros.
  %prime = arith.constant 25867 : i16 // Large i16 prime.
  %xi = func.call @xorshift(%i) : (i16) -> (i16)
  %inner = arith.muli %xi, %pattern : i16
  %xinner = func.call @xorshift(%inner) : (i16) -> (i16)
  %res = arith.muli %xinner, %prime : i16
  return %res : i16
}

//===----------------------------------------------------------------------===//
// Test arith.addi
//===----------------------------------------------------------------------===//

// Ops in this function will be emulated using i8 ops.
// Returns the wrapping i16 sum `%lhs + %rhs`.
func.func @emulate_addi(%lhs : i16, %rhs : i16) -> (i16) {
  %res = arith.addi %lhs, %rhs : i16
  return %res : i16
}

// Performs both wide and emulated `arith.addi`, and checks that the results
// match.
// The wide result is passed to `@check_results` as the first result and the
// emulated one as the second.
func.func @check_addi(%lhs : i16, %rhs : i16) -> () {
  %wide = arith.addi %lhs, %rhs : i16
  %emulated = func.call @emulate_addi(%lhs, %rhs) : (i16, i16) -> (i16)
  func.call @check_results(%lhs, %rhs, %wide, %emulated) : (i16, i16, i16, i16) -> ()
  return
}

// Checks that `arith.addi` is emulated properly by sampling the input space.
// In total, this test function checks 500 * 500 = 250k input pairs.
func.func @test_addi() -> () {
  %idx0 = arith.constant 0 : index
  %idx1 = arith.constant 1 : index
  %idx500 = arith.constant 500 : index

  %cst0 = arith.constant 0 : i16
  %cst1 = arith.constant 1 : i16

  // The induction variables have `index` type, so separate i16 counters are
  // carried through `iter_args` and hashed with `@xhash` to obtain operands.
  scf.for %lhs_idx = %idx0 to %idx500 step %idx1 iter_args(%lhs = %cst0) -> (i16) {
    %arg_lhs = func.call @xhash(%lhs) : (i16) -> (i16)

    scf.for %rhs_idx = %idx0 to %idx500 step %idx1 iter_args(%rhs = %cst0) -> (i16) {
      %arg_rhs = func.call @xhash(%rhs) : (i16) -> (i16)
      func.call @check_addi(%arg_lhs, %arg_rhs) : (i16, i16) -> ()

      %rhs_next = arith.addi %rhs, %cst1 : i16
      scf.yield %rhs_next : i16
    }

    %lhs_next = arith.addi %lhs, %cst1 : i16
    scf.yield %lhs_next : i16
  }

  return
}

//===----------------------------------------------------------------------===//
// Test arith.muli
//===----------------------------------------------------------------------===//

// Ops in this function will be emulated using i8 ops.
// Returns the wrapping i16 product `%lhs * %rhs`.
func.func @emulate_muli(%lhs : i16, %rhs : i16) -> (i16) {
  %res = arith.muli %lhs, %rhs : i16
  return %res : i16
}

// Performs both wide and emulated `arith.muli`, and checks that the results
// match.
// The wide result is passed to `@check_results` as the first result and the
// emulated one as the second.
func.func @check_muli(%lhs : i16, %rhs : i16) -> () {
  %wide = arith.muli %lhs, %rhs : i16
  %emulated = func.call @emulate_muli(%lhs, %rhs) : (i16, i16) -> (i16)
  func.call @check_results(%lhs, %rhs, %wide, %emulated) : (i16, i16, i16, i16) -> ()
  return
}

// Checks that `arith.muli` is emulated properly by sampling the input space.
// In total, this test function checks 500 * 500 = 250k input pairs.
func.func @test_muli() -> () {
  %idx0 = arith.constant 0 : index
  %idx1 = arith.constant 1 : index
  %idx500 = arith.constant 500 : index

  %cst0 = arith.constant 0 : i16
  %cst1 = arith.constant 1 : i16

  // The induction variables have `index` type, so separate i16 counters are
  // carried through `iter_args` and hashed with `@xhash` to obtain operands.
  scf.for %lhs_idx = %idx0 to %idx500 step %idx1 iter_args(%lhs = %cst0) -> (i16) {
    %arg_lhs = func.call @xhash(%lhs) : (i16) -> (i16)

    scf.for %rhs_idx = %idx0 to %idx500 step %idx1 iter_args(%rhs = %cst0) -> (i16) {
      %arg_rhs = func.call @xhash(%rhs) : (i16) -> (i16)
      func.call @check_muli(%arg_lhs, %arg_rhs) : (i16, i16) -> ()

      %rhs_next = arith.addi %rhs, %cst1 : i16
      scf.yield %rhs_next : i16
    }

    %lhs_next = arith.addi %lhs, %cst1 : i16
    scf.yield %lhs_next : i16
  }

  return
}

//===----------------------------------------------------------------------===//
// Test arith.shli
//===----------------------------------------------------------------------===//

// Ops in this function will be emulated using i8 ops.
// Returns `%lhs` shifted left by `%rhs` bits.
func.func @emulate_shli(%lhs : i16, %rhs : i16) -> (i16) {
  %res = arith.shli %lhs, %rhs : i16
  return %res : i16
}

// Performs both wide and emulated `arith.shli`, and checks that the results
// match.
// The wide result is passed to `@check_results` as the first result and the
// emulated one as the second.
func.func @check_shli(%lhs : i16, %rhs : i16) -> () {
  %wide = arith.shli %lhs, %rhs : i16
  %emulated = func.call @emulate_shli(%lhs, %rhs) : (i16, i16) -> (i16)
  func.call @check_results(%lhs, %rhs, %wide, %emulated) : (i16, i16, i16, i16) -> ()
  return
}

// Checks that `arith.shli` is emulated properly by sampling the input space.
// Checks all valid shift amounts for i16: 0 to 15.
// In total, this test function checks 100 * 16 = 1.6k input pairs.
func.func @test_shli() -> () {
  %idx0 = arith.constant 0 : index
  %idx1 = arith.constant 1 : index
  %idx16 = arith.constant 16 : index
  %idx100 = arith.constant 100 : index

  %cst0 = arith.constant 0 : i16
  %cst1 = arith.constant 1 : i16

  // The induction variables have `index` type, so separate i16 counters are
  // carried through `iter_args`. The shifted value is a hash of the outer
  // counter; the inner counter is used directly as the shift amount.
  scf.for %lhs_idx = %idx0 to %idx100 step %idx1 iter_args(%lhs = %cst0) -> (i16) {
    %arg_lhs = func.call @xhash(%lhs) : (i16) -> (i16)

    scf.for %rhs_idx = %idx0 to %idx16 step %idx1 iter_args(%rhs = %cst0) -> (i16) {
      func.call @check_shli(%arg_lhs, %rhs) : (i16, i16) -> ()
      %rhs_next = arith.addi %rhs, %cst1 : i16
      scf.yield %rhs_next : i16
    }

    %lhs_next = arith.addi %lhs, %cst1 : i16
    scf.yield %lhs_next : i16
  }

  return
}

//===----------------------------------------------------------------------===//
// Test arith.shrsi
//===----------------------------------------------------------------------===//

// Ops in this function will be emulated using i8 ops.
// Returns `%lhs` shifted right by `%rhs` bits with sign extension
// (arithmetic shift).
func.func @emulate_shrsi(%lhs : i16, %rhs : i16) -> (i16) {
  %res = arith.shrsi %lhs, %rhs : i16
  return %res : i16
}

// Performs both wide and emulated `arith.shrsi`, and checks that the results
// match.
// The wide result is passed to `@check_results` as the first result and the
// emulated one as the second.
func.func @check_shrsi(%lhs : i16, %rhs : i16) -> () {
  %wide = arith.shrsi %lhs, %rhs : i16
  %emulated = func.call @emulate_shrsi(%lhs, %rhs) : (i16, i16) -> (i16)
  func.call @check_results(%lhs, %rhs, %wide, %emulated) : (i16, i16, i16, i16) -> ()
  return
}

// Checks that `arith.shrsi` is emulated properly by sampling the input space.
// Checks all valid shift amounts for i16: 0 to 15.
// In total, this test function checks 100 * 16 = 1.6k input pairs.
func.func @test_shrsi() -> () {
  %idx0 = arith.constant 0 : index
  %idx1 = arith.constant 1 : index
  %idx16 = arith.constant 16 : index
  %idx100 = arith.constant 100 : index

  %cst0 = arith.constant 0 : i16
  %cst1 = arith.constant 1 : i16

  // The induction variables have `index` type, so separate i16 counters are
  // carried through `iter_args`. The shifted value is a hash of the outer
  // counter; the inner counter is used directly as the shift amount.
  scf.for %lhs_idx = %idx0 to %idx100 step %idx1 iter_args(%lhs = %cst0) -> (i16) {
    %arg_lhs = func.call @xhash(%lhs) : (i16) -> (i16)

    scf.for %rhs_idx = %idx0 to %idx16 step %idx1 iter_args(%rhs = %cst0) -> (i16) {
      func.call @check_shrsi(%arg_lhs, %rhs) : (i16, i16) -> ()
      %rhs_next = arith.addi %rhs, %cst1 : i16
      scf.yield %rhs_next : i16
    }

    %lhs_next = arith.addi %lhs, %cst1 : i16
    scf.yield %lhs_next : i16
  }

  return
}

//===----------------------------------------------------------------------===//
// Test arith.shrui
//===----------------------------------------------------------------------===//

// Ops in this function will be emulated using i8 ops.
// Returns `%lhs` shifted right by `%rhs` bits with zero fill (logical shift).
func.func @emulate_shrui(%lhs : i16, %rhs : i16) -> (i16) {
  %res = arith.shrui %lhs, %rhs : i16
  return %res : i16
}

// Performs both wide and emulated `arith.shrui`, and checks that the results
// match.
// The wide result is passed to `@check_results` as the first result and the
// emulated one as the second.
func.func @check_shrui(%lhs : i16, %rhs : i16) -> () {
  %wide = arith.shrui %lhs, %rhs : i16
  %emulated = func.call @emulate_shrui(%lhs, %rhs) : (i16, i16) -> (i16)
  func.call @check_results(%lhs, %rhs, %wide, %emulated) : (i16, i16, i16, i16) -> ()
  return
}

// Checks that `arith.shrui` is emulated properly by sampling the input space.
// Checks all valid shift amounts for i16: 0 to 15.
// In total, this test function checks 100 * 16 = 1.6k input pairs.
func.func @test_shrui() -> () {
  %idx0 = arith.constant 0 : index
  %idx1 = arith.constant 1 : index
  %idx16 = arith.constant 16 : index
  %idx100 = arith.constant 100 : index

  %cst0 = arith.constant 0 : i16
  %cst1 = arith.constant 1 : i16

  // The induction variables have `index` type, so separate i16 counters are
  // carried through `iter_args`. The shifted value is a hash of the outer
  // counter; the inner counter is used directly as the shift amount.
  scf.for %lhs_idx = %idx0 to %idx100 step %idx1 iter_args(%lhs = %cst0) -> (i16) {
    %arg_lhs = func.call @xhash(%lhs) : (i16) -> (i16)

    scf.for %rhs_idx = %idx0 to %idx16 step %idx1 iter_args(%rhs = %cst0) -> (i16) {
      func.call @check_shrui(%arg_lhs, %rhs) : (i16, i16) -> ()
      %rhs_next = arith.addi %rhs, %cst1 : i16
      scf.yield %rhs_next : i16
    }

    %lhs_next = arith.addi %lhs, %cst1 : i16
    scf.yield %lhs_next : i16
  }

  return
}

//===----------------------------------------------------------------------===//
// Entry Point
//===----------------------------------------------------------------------===//

// Entry point named by `-e entry` in the RUN line. Runs every per-op test in
// turn; any 'Mismatch' line they print is rejected by the `CHECK-NOT`
// directive at the top of the file.
func.func @entry() {
  func.call @test_addi() : () -> ()
  func.call @test_muli() : () -> ()
  func.call @test_shli() : () -> ()
  func.call @test_shrsi() : () -> ()
  func.call @test_shrui() : () -> ()
  return
}