//===- TosaToArith.cpp - Lowering Tosa to Arith Dialect -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // These rewriters lower from the Tosa to the Arith dialect. // //===----------------------------------------------------------------------===// #include "mlir/Conversion/TosaToArith/TosaToArith.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; using namespace tosa; namespace { class ConstOpConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::ConstOp op, PatternRewriter &rewriter) const final { rewriter.replaceOpWithNewOp(op, op.getValue()); return success(); } }; Type matchContainerType(Type element, Type container) { if (auto shapedTy = dyn_cast(container)) return shapedTy.clone(element); return element; } TypedAttr getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) { if (auto shapedTy = dyn_cast(type)) { Type eTy = shapedTy.getElementType(); APInt valueInt(eTy.getIntOrFloatBitWidth(), value); return DenseIntElementsAttr::get(shapedTy, valueInt); } return rewriter.getIntegerAttr(type, value); } Value getConstantValue(Location loc, Type type, int64_t value, PatternRewriter &rewriter) { return rewriter.create( loc, getConstantAttr(type, value, rewriter)); } // This converts the TOSA ApplyScale operator to a set of arithmetic ops, // using 64-bit operations to perform the necessary multiply, bias, and shift. class ApplyScaleGenericOpConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::ApplyScaleOp op, PatternRewriter &rewriter) const final { Location loc = op.getLoc(); Value value = op.getValue(); Value multiplier32 = op.getMultiplier(); Type resultTy = op.getType(); Type valueTy = value.getType(); Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy); Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy); Value zero = getConstantValue(loc, valueTy, 0, rewriter); Value one64 = getConstantValue(loc, i64Ty, 1, rewriter); Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter); Value shift32 = rewriter.create(loc, i32Ty, op.getShift()); // Compute the multiplication in 64-bits then select the high / low parts. Value value64 = value; if (getElementTypeOrSelf(valueTy) != rewriter.getI64Type()) value64 = rewriter.create(loc, i64Ty, value); Value multiplier64 = rewriter.create(loc, i64Ty, multiplier32); Value multiply64 = rewriter.create(loc, value64, multiplier64); // Apply normal rounding. Value shift64 = rewriter.create(loc, i64Ty, shift32); Value round = rewriter.create(loc, one64, shift64); round = rewriter.create(loc, round, one64); multiply64 = rewriter.create(loc, multiply64, round); // Apply double rounding if necessary. if (op.getDoubleRound()) { int64_t roundInt = 1 << 30; Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter); Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter); Value positive = rewriter.create( loc, arith::CmpIPredicate::sge, value, zero); Value dir = rewriter.create(loc, positive, roundUp, roundDown); Value val = rewriter.create(loc, dir, multiply64); Value valid = rewriter.create( loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32); multiply64 = rewriter.create(loc, valid, val, multiply64); } Value result64 = rewriter.create(loc, multiply64, shift64); Value result32 = rewriter.create(loc, i32Ty, result64); rewriter.replaceOp(op, result32); return success(); } }; class ApplyScale32BitOpConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::ApplyScaleOp op, PatternRewriter &rewriter) const final { Location loc = op.getLoc(); Type resultTy = op.getType(); Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy); Value value = op.getValue(); if (getElementTypeOrSelf(value.getType()).getIntOrFloatBitWidth() > 32) { return failure(); } Value value32 = op.getValue(); Value multiplier32 = op.getMultiplier(); Value shift32 = rewriter.create(loc, i32Ty, op.getShift()); // Constants used during the scaling operation. Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter); Value one32 = getConstantValue(loc, i32Ty, 1, rewriter); Value two32 = getConstantValue(loc, i32Ty, 2, rewriter); Value thirty32 = getConstantValue(loc, i32Ty, 30, rewriter); Value thirtyTwo32 = getConstantValue(loc, i32Ty, 32, rewriter); // Compute the multiplication in 64-bits then select the high / low parts. // Grab out the high/low of the computation auto value64 = rewriter.create(loc, value32, multiplier32); Value low32 = value64.getLow(); Value high32 = value64.getHigh(); // Determine the direction and amount to shift the high bits. Value shiftOver32 = rewriter.create( loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32); Value roundHighBits = rewriter.create( loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32); Value shiftHighL = rewriter.create(loc, thirtyTwo32, shift32); Value shiftHighR = rewriter.create(loc, shift32, thirtyTwo32); shiftHighL = rewriter.create(loc, shiftOver32, zero32, shiftHighL); shiftHighR = rewriter.create(loc, shiftOver32, shiftHighR, zero32); // Conditionally perform our double round. if (op.getDoubleRound()) { Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter); Value valuePositive = rewriter.create( loc, arith::CmpIPredicate::sge, value32, zero32); Value roundDir = rewriter.create(loc, valuePositive, one32, negOne32); roundDir = rewriter.create(loc, shiftOver32, roundDir, zero32); Value shiftLow = rewriter.create(loc, low32, thirty32); Value rounded = rewriter.create(loc, shiftLow, roundDir); Value carry = rewriter.create(loc, rounded, two32); Value shiftRound = rewriter.create(loc, roundDir, thirty32); low32 = rewriter.create(loc, low32, shiftRound); high32 = rewriter.create(loc, high32, carry); } // Conditionally apply rounding in the low bits. { Value shiftSubOne = rewriter.create(loc, shift32, one32); Value roundBit = rewriter.create(loc, one32, shiftSubOne); roundBit = rewriter.create(loc, roundHighBits, zero32, roundBit); Value newLow32 = rewriter.create(loc, low32, roundBit); Value wasRounded = rewriter.create( loc, arith::CmpIPredicate::ugt, low32, newLow32); low32 = newLow32; Value rounded32 = rewriter.create(loc, i32Ty, wasRounded); high32 = rewriter.create(loc, high32, rounded32); } // Conditionally apply rounding in the high bits. { Value shiftSubOne = rewriter.create(loc, shiftHighR, one32); Value roundBit = rewriter.create(loc, one32, shiftSubOne); roundBit = rewriter.create(loc, roundHighBits, roundBit, zero32); high32 = rewriter.create(loc, high32, roundBit); } // Combine the correct high/low bits into the final rescale result. high32 = rewriter.create(loc, high32, shiftHighL); high32 = rewriter.create(loc, high32, shiftHighR); low32 = rewriter.create(loc, low32, shift32); low32 = rewriter.create(loc, shiftOver32, zero32, low32); // Apply the rounding behavior and shift to the final alignment. Value result = rewriter.create(loc, low32, high32); // Truncate if necessary. if (!getElementTypeOrSelf(resultTy).isInteger(32)) { result = rewriter.create(loc, resultTy, result); } rewriter.replaceOp(op, result); return success(); } }; } // namespace void mlir::tosa::populateTosaToArithConversionPatterns( RewritePatternSet *patterns) { patterns->add(patterns->getContext()); } void mlir::tosa::populateTosaRescaleToArithConversionPatterns( RewritePatternSet *patterns, bool include32Bit) { patterns->add(patterns->getContext(), 100); if (include32Bit) { patterns->add(patterns->getContext(), 200); } }