// RUN: mlir-opt --split-input-file --tosa-to-arith="include-apply-rescale=true use-32-bit=true" %s -verify-diagnostics -o -| FileCheck %s
// RUN: mlir-opt --split-input-file --tosa-to-arith="include-apply-rescale=false" %s -verify-diagnostics -o -| FileCheck --check-prefix="SCALE" %s

// Regression tests for the tosa-to-arith conversion. The first invocation
// enables the apply_scale lowering and is verified with the default check
// prefix; the second disables it and only verifies that the op is left in
// place.

// A tosa.const is expected to become an arith.constant with the same value.
// CHECK-LABEL: func @const_test
func.func @const_test() -> (tensor<i32>) {
  // CHECK: [[C3:%.+]] = arith.constant dense<3> : tensor<i32>
  %result = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>

  // CHECK: return [[C3]]
  return %result : tensor<i32>
}

// -----

// Scalar i32 value: the expected lowering uses only 32-bit arithmetic, built
// around an extended multiply that yields separate low and high halves.
// CHECK-LABEL: @apply_scale_test_i32
// SCALE: tosa.apply_scale
func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) {
  // CHECK-DAG: %[[S32:.+]] = arith.extui %arg2 : i8 to i32
  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32
  // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i32
  // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : i32
  // CHECK-DAG: %[[C30:.+]] = arith.constant 30 : i32
  // CHECK-DAG: %[[C32:.+]] = arith.constant 32 : i32

  // Compute the low and high halves of the extended product.
  // CHECK-DAG: %[[LOW:.+]], %[[HI:.+]] = arith.mulsi_extended %arg0, %arg1

  // Determine whether the high bits need to shift left or right and by how much.
  // CHECK-DAG: %[[OVER31:.+]] = arith.cmpi sge, %[[S32]], %[[C32]]
  // CHECK-DAG: %[[OVER32:.+]] = arith.cmpi sgt, %[[S32]], %[[C32]]
  // CHECK-DAG: %[[HISHLN:.+]] = arith.subi %[[C32]], %[[S32]]
  // CHECK-DAG: %[[HISHRN:.+]] = arith.subi %[[S32]], %[[C32]]
  // CHECK-DAG: %[[HISHL:.+]] = arith.select %[[OVER31]], %[[C0]], %[[HISHLN]]
  // CHECK-DAG: %[[HISHR:.+]] = arith.select %[[OVER31]], %[[HISHRN]], %[[C0]]

  // Apply double rounding.
  // CHECK-DAG: %[[CN1:.+]] = arith.constant -1
  // CHECK-DAG: %[[POS:.+]] = arith.cmpi sge, %arg0, %[[C0]]
  // CHECK-DAG: %[[DIR:.+]] = arith.select %[[POS]], %[[C1]], %[[CN1]]
  // CHECK-DAG: %[[DRND:.+]] = arith.select %[[OVER31]], %[[DIR]], %[[C0]]
  // CHECK-DAG: %[[DSHFTR:.+]] = arith.shrui %[[LOW]], %[[C30]]
  // CHECK-DAG: %[[DRNDED:.+]] = arith.addi %[[DSHFTR]], %[[DRND]]
  // The shift amount must be the constant 2 captured above (a use, not a
  // fresh capture, so the operand is actually constrained).
  // CHECK-DAG: %[[DCARRY:.+]] = arith.shrsi %[[DRNDED]], %[[C2]]
  // CHECK-DAG: %[[DBIT:.+]] = arith.shli %[[DRND]], %[[C30]]
  // CHECK-DAG: %[[DLOW:.+]] = arith.addi %[[LOW]], %[[DBIT]]
  // CHECK-DAG: %[[DHI:.+]] = arith.addi %[[HI]], %[[DCARRY]]

  // Apply low-bit rounding.
  // CHECK-DAG: %[[SHFTM1:.+]] = arith.subi %[[S32]], %[[C1]]
  // CHECK-DAG: %[[LBIT:.+]] = arith.shli %[[C1]], %[[SHFTM1]]
  // CHECK-DAG: %[[HALF:.+]] = arith.select %[[OVER32]], %[[C0]], %[[LBIT]]
  // CHECK-DAG: %[[LADD:.+]] = arith.addi %[[DLOW]], %[[HALF]]
  // CHECK-DAG: %[[LLO:.+]] = arith.cmpi ugt, %[[DLOW]], %[[LADD]]
  // CHECK-DAG: %[[LCARRY:.+]] = arith.extui %[[LLO]] : i1 to i32
  // CHECK-DAG: %[[LRNDED:.+]] = arith.addi %[[DHI]], %[[LCARRY]]

  // Apply high-bit rounding.
  // CHECK-DAG: %[[HISHRM1:.+]] = arith.subi %[[HISHR]], %[[C1]]
  // CHECK-DAG: %[[LHISHFT:.+]] = arith.shli %[[C1]], %[[HISHRM1]]
  // CHECK-DAG: %[[LHI:.+]] = arith.select %[[OVER32]], %[[LHISHFT]], %[[C0]]
  // CHECK-DAG: %[[FHI:.+]] = arith.addi %[[LRNDED]], %[[LHI]]

  // Combine hi-low into the final result. The right shift must consume the
  // left-shifted high half captured on the preceding line.
  // CHECK-DAG: %[[HIL:.+]] = arith.shli %[[FHI]], %[[HISHL]]
  // CHECK-DAG: %[[HIALIGN:.+]] = arith.shrsi %[[HIL]], %[[HISHR]]
  // CHECK-DAG: %[[LOR:.+]] = arith.shrui %[[LADD]], %[[S32]]
  // CHECK-DAG: %[[LOWALIGN:.+]] = arith.select %[[OVER31]], %[[C0]], %[[LOR]]
  // CHECK-DAG: %[[RESULT:.+]] = arith.addi %[[LOWALIGN]], %[[HIALIGN]]

  // CHECK: return %[[RESULT]]
  %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i32, i32, i8) -> i32
  return %res : i32
}

// -----

// Vector operands: only verify that no apply_scale op survives the lowering.
// The pattern is unquoted so that it matches the op's custom assembly form.
// CHECK-LABEL: @apply_scale_test_vector
// SCALE: tosa.apply_scale
func.func @apply_scale_test_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>, %arg2 : vector<4xi8>) -> (vector<4xi32>) {
  // CHECK-NOT: tosa.apply_scale
  %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
  return %res : vector<4xi32>
}

// -----

// i48 value: the expected lowering sign-extends both multiplicands to i64,
// multiplies and rounds in 64 bits, then truncates the result to i32.
// CHECK-LABEL: @apply_scale_test_i48
// SCALE: tosa.apply_scale
func.func @apply_scale_test_i48(%arg0 : i48, %arg1 : i32, %arg2 : i8) -> (i32) {
  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i48
  // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i64
  // CHECK-DAG: %[[C31:.+]] = arith.constant 31 : i32

  // Multiply in 64 bits.
  // CHECK-DAG: %[[V64:.+]] = arith.extsi %arg0 : i48 to i64
  // CHECK-DAG: %[[M64:.+]] = arith.extsi %arg1 : i32 to i64
  // CHECK-DAG: %[[MUL:.+]] = arith.muli %[[V64]], %[[M64]]

  // Round normally.
  // CHECK-DAG: %[[S32:.+]] = arith.extui %arg2 : i8 to i32
  // CHECK-DAG: %[[S64:.+]] = arith.extui %[[S32]] : i32 to i64
  // CHECK-DAG: %[[ONEL:.+]] = arith.shli %[[C1]], %[[S64]] : i64
  // CHECK-DAG: %[[ONER:.+]] = arith.shrui %[[ONEL]], %[[C1]]
  // CHECK-DAG: %[[ROUND:.+]] = arith.addi %[[MUL]], %[[ONER]]

  // Apply double rounding.
  // CHECK-DAG: %[[DUP:.+]] = arith.constant 1073741824 : i64
  // CHECK-DAG: %[[DDOWN:.+]] = arith.constant -1073741824 : i64
  // CHECK-DAG: %[[POS:.+]] = arith.cmpi sge, %arg0, %[[C0]]
  // CHECK-DAG: %[[DBIT:.+]] = arith.select %[[POS]], %[[DUP]], %[[DDOWN]]
  // CHECK-DAG: %[[DRND:.+]] = arith.addi %[[DBIT]], %[[ROUND]]
  // CHECK-DAG: %[[USED:.+]] = arith.cmpi sgt, %[[S32]], %[[C31]] : i32
  // CHECK-DAG: %[[RES64:.+]] = arith.select %[[USED]], %[[DRND]], %[[ROUND]] : i64

  // Shift and truncate final answer.
  // CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
  // CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32

  // CHECK: return %[[TRUNC]]
  %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i48, i32, i8) -> i32
  return %res : i32
}

// -----

// i64 value: same expected 64-bit sequence as the i48 case, except that the
// value operand is already i64 and is used by the multiply directly.
// CHECK-LABEL: @apply_scale_test_i64
// SCALE: tosa.apply_scale
func.func @apply_scale_test_i64(%arg0 : i64, %arg1 : i32, %arg2 : i8) -> (i32) {
  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i64
  // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i64
  // CHECK-DAG: %[[C31:.+]] = arith.constant 31 : i32

  // Multiply in 64 bits.
  // CHECK-DAG: %[[M64:.+]] = arith.extsi %arg1 : i32 to i64
  // CHECK-DAG: %[[MUL:.+]] = arith.muli %arg0, %[[M64]]

  // Round normally.
  // CHECK-DAG: %[[S32:.+]] = arith.extui %arg2 : i8 to i32
  // CHECK-DAG: %[[S64:.+]] = arith.extui %[[S32]] : i32 to i64
  // CHECK-DAG: %[[ONEL:.+]] = arith.shli %[[C1]], %[[S64]] : i64
  // CHECK-DAG: %[[ONER:.+]] = arith.shrui %[[ONEL]], %[[C1]]
  // CHECK-DAG: %[[ROUND:.+]] = arith.addi %[[MUL]], %[[ONER]]

  // Apply double rounding.
  // CHECK-DAG: %[[DUP:.+]] = arith.constant 1073741824 : i64
  // CHECK-DAG: %[[DDOWN:.+]] = arith.constant -1073741824 : i64
  // CHECK-DAG: %[[POS:.+]] = arith.cmpi sge, %arg0, %[[C0]]
  // CHECK-DAG: %[[DBIT:.+]] = arith.select %[[POS]], %[[DUP]], %[[DDOWN]]
  // CHECK-DAG: %[[DRND:.+]] = arith.addi %[[DBIT]], %[[ROUND]]
  // CHECK-DAG: %[[USED:.+]] = arith.cmpi sgt, %[[S32]], %[[C31]] : i32
  // CHECK-DAG: %[[RES64:.+]] = arith.select %[[USED]], %[[DRND]], %[[ROUND]] : i64

  // Shift and truncate final answer.
  // CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
  // CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32

  // CHECK: return %[[TRUNC]]
  %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i64, i32, i8) -> i32
  return %res : i32
}