//===- SlowMPInt.cpp - MLIR SlowMPInt Class -------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Analysis/Presburger/SlowMPInt.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/Support/raw_ostream.h" #include #include #include #include using namespace mlir; using namespace presburger; using namespace detail; SlowMPInt::SlowMPInt(int64_t val) : val(64, val, /*isSigned=*/true) {} SlowMPInt::SlowMPInt() : SlowMPInt(0) {} SlowMPInt::SlowMPInt(const llvm::APInt &val) : val(val) {} SlowMPInt &SlowMPInt::operator=(int64_t val) { return *this = SlowMPInt(val); } SlowMPInt::operator int64_t() const { return val.getSExtValue(); } llvm::hash_code detail::hash_value(const SlowMPInt &x) { return hash_value(x.val); } /// --------------------------------------------------------------------------- /// Printing. /// --------------------------------------------------------------------------- void SlowMPInt::print(llvm::raw_ostream &os) const { os << val; } void SlowMPInt::dump() const { print(llvm::errs()); } llvm::raw_ostream &detail::operator<<(llvm::raw_ostream &os, const SlowMPInt &x) { x.print(os); return os; } /// --------------------------------------------------------------------------- /// Convenience operator overloads for int64_t. 
/// --------------------------------------------------------------------------- SlowMPInt &detail::operator+=(SlowMPInt &a, int64_t b) { return a += SlowMPInt(b); } SlowMPInt &detail::operator-=(SlowMPInt &a, int64_t b) { return a -= SlowMPInt(b); } SlowMPInt &detail::operator*=(SlowMPInt &a, int64_t b) { return a *= SlowMPInt(b); } SlowMPInt &detail::operator/=(SlowMPInt &a, int64_t b) { return a /= SlowMPInt(b); } SlowMPInt &detail::operator%=(SlowMPInt &a, int64_t b) { return a %= SlowMPInt(b); } bool detail::operator==(const SlowMPInt &a, int64_t b) { return a == SlowMPInt(b); } bool detail::operator!=(const SlowMPInt &a, int64_t b) { return a != SlowMPInt(b); } bool detail::operator>(const SlowMPInt &a, int64_t b) { return a > SlowMPInt(b); } bool detail::operator<(const SlowMPInt &a, int64_t b) { return a < SlowMPInt(b); } bool detail::operator<=(const SlowMPInt &a, int64_t b) { return a <= SlowMPInt(b); } bool detail::operator>=(const SlowMPInt &a, int64_t b) { return a >= SlowMPInt(b); } SlowMPInt detail::operator+(const SlowMPInt &a, int64_t b) { return a + SlowMPInt(b); } SlowMPInt detail::operator-(const SlowMPInt &a, int64_t b) { return a - SlowMPInt(b); } SlowMPInt detail::operator*(const SlowMPInt &a, int64_t b) { return a * SlowMPInt(b); } SlowMPInt detail::operator/(const SlowMPInt &a, int64_t b) { return a / SlowMPInt(b); } SlowMPInt detail::operator%(const SlowMPInt &a, int64_t b) { return a % SlowMPInt(b); } bool detail::operator==(int64_t a, const SlowMPInt &b) { return SlowMPInt(a) == b; } bool detail::operator!=(int64_t a, const SlowMPInt &b) { return SlowMPInt(a) != b; } bool detail::operator>(int64_t a, const SlowMPInt &b) { return SlowMPInt(a) > b; } bool detail::operator<(int64_t a, const SlowMPInt &b) { return SlowMPInt(a) < b; } bool detail::operator<=(int64_t a, const SlowMPInt &b) { return SlowMPInt(a) <= b; } bool detail::operator>=(int64_t a, const SlowMPInt &b) { return SlowMPInt(a) >= b; } SlowMPInt detail::operator+(int64_t a, 
const SlowMPInt &b) { return SlowMPInt(a) + b; } SlowMPInt detail::operator-(int64_t a, const SlowMPInt &b) { return SlowMPInt(a) - b; } SlowMPInt detail::operator*(int64_t a, const SlowMPInt &b) { return SlowMPInt(a) * b; } SlowMPInt detail::operator/(int64_t a, const SlowMPInt &b) { return SlowMPInt(a) / b; } SlowMPInt detail::operator%(int64_t a, const SlowMPInt &b) { return SlowMPInt(a) % b; } static unsigned getMaxWidth(const APInt &a, const APInt &b) { return std::max(a.getBitWidth(), b.getBitWidth()); } /// --------------------------------------------------------------------------- /// Comparison operators. /// --------------------------------------------------------------------------- // TODO: consider instead making APInt::compare available and using that. bool SlowMPInt::operator==(const SlowMPInt &o) const { unsigned width = getMaxWidth(val, o.val); return val.sext(width) == o.val.sext(width); } bool SlowMPInt::operator!=(const SlowMPInt &o) const { unsigned width = getMaxWidth(val, o.val); return val.sext(width) != o.val.sext(width); } bool SlowMPInt::operator>(const SlowMPInt &o) const { unsigned width = getMaxWidth(val, o.val); return val.sext(width).sgt(o.val.sext(width)); } bool SlowMPInt::operator<(const SlowMPInt &o) const { unsigned width = getMaxWidth(val, o.val); return val.sext(width).slt(o.val.sext(width)); } bool SlowMPInt::operator<=(const SlowMPInt &o) const { unsigned width = getMaxWidth(val, o.val); return val.sext(width).sle(o.val.sext(width)); } bool SlowMPInt::operator>=(const SlowMPInt &o) const { unsigned width = getMaxWidth(val, o.val); return val.sext(width).sge(o.val.sext(width)); } /// --------------------------------------------------------------------------- /// Arithmetic operators. /// --------------------------------------------------------------------------- /// Bring a and b to have the same width and then call op(a, b, overflow). 
/// If the overflow bit becomes set, resize a and b to double the width and /// call op(a, b, overflow), returning its result. The operation with double /// widths should not also overflow. APInt runOpWithExpandOnOverflow( const APInt &a, const APInt &b, llvm::function_ref op) { bool overflow; unsigned width = getMaxWidth(a, b); APInt ret = op(a.sext(width), b.sext(width), overflow); if (!overflow) return ret; width *= 2; ret = op(a.sext(width), b.sext(width), overflow); assert(!overflow && "double width should be sufficient to avoid overflow!"); return ret; } SlowMPInt SlowMPInt::operator+(const SlowMPInt &o) const { return SlowMPInt( runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::sadd_ov))); } SlowMPInt SlowMPInt::operator-(const SlowMPInt &o) const { return SlowMPInt( runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::ssub_ov))); } SlowMPInt SlowMPInt::operator*(const SlowMPInt &o) const { return SlowMPInt( runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::smul_ov))); } SlowMPInt SlowMPInt::operator/(const SlowMPInt &o) const { return SlowMPInt( runOpWithExpandOnOverflow(val, o.val, std::mem_fn(&APInt::sdiv_ov))); } SlowMPInt detail::abs(const SlowMPInt &x) { return x >= 0 ? x : -x; } SlowMPInt detail::ceilDiv(const SlowMPInt &lhs, const SlowMPInt &rhs) { if (rhs == -1) return -lhs; unsigned width = getMaxWidth(lhs.val, rhs.val); return SlowMPInt(llvm::APIntOps::RoundingSDiv( lhs.val.sext(width), rhs.val.sext(width), APInt::Rounding::UP)); } SlowMPInt detail::floorDiv(const SlowMPInt &lhs, const SlowMPInt &rhs) { if (rhs == -1) return -lhs; unsigned width = getMaxWidth(lhs.val, rhs.val); return SlowMPInt(llvm::APIntOps::RoundingSDiv( lhs.val.sext(width), rhs.val.sext(width), APInt::Rounding::DOWN)); } // The RHS is always expected to be positive, and the result /// is always non-negative. 
SlowMPInt detail::mod(const SlowMPInt &lhs, const SlowMPInt &rhs) {
  assert(rhs >= 1 && "mod is only supported for positive divisors!");
  // operator% is a truncating remainder (srem), so its sign follows lhs.
  // Shift a negative remainder up by rhs to land in [0, rhs).
  SlowMPInt rem = lhs % rhs;
  return rem < 0 ? rem + rhs : rem;
}

/// Returns the greatest common divisor of two non-negative values.
SlowMPInt detail::gcd(const SlowMPInt &a, const SlowMPInt &b) {
  assert(a >= 0 && b >= 0 && "operands must be non-negative!");
  unsigned width = getMaxWidth(a.val, b.val);
  return SlowMPInt(llvm::APIntOps::GreatestCommonDivisor(a.val.sext(width),
                                                         b.val.sext(width)));
}

/// Returns the least common multiple of 'a' and 'b' (computed on their
/// absolute values). If either operand is zero, the result is zero.
SlowMPInt detail::lcm(const SlowMPInt &a, const SlowMPInt &b) {
  SlowMPInt x = abs(a);
  SlowMPInt y = abs(b);
  // gcd(0, 0) is 0, so (x * y) / gcd(x, y) would divide by zero when both
  // operands are zero. With a single zero operand the formula already yields
  // 0, so returning early is equivalent there.
  if (x == 0 || y == 0)
    return SlowMPInt(0);
  return (x * y) / gcd(x, y);
}

/// Truncating remainder: the result takes the sign of the dividend.
/// This operation cannot overflow.
SlowMPInt SlowMPInt::operator%(const SlowMPInt &o) const {
  unsigned width = getMaxWidth(val, o.val);
  return SlowMPInt(val.sext(width).srem(o.val.sext(width)));
}

SlowMPInt SlowMPInt::operator-() const {
  if (val.isMinSignedValue()) {
    /// Overflow only occurs when the value is the minimum possible value, so
    /// widen to twice the width before negating.
    APInt ret = val.sext(2 * val.getBitWidth());
    return SlowMPInt(-ret);
  }
  return SlowMPInt(-val);
}

/// ---------------------------------------------------------------------------
/// Assignment operators, preincrement, predecrement.
/// ---------------------------------------------------------------------------
SlowMPInt &SlowMPInt::operator+=(const SlowMPInt &o) {
  *this = *this + o;
  return *this;
}
SlowMPInt &SlowMPInt::operator-=(const SlowMPInt &o) {
  *this = *this - o;
  return *this;
}
SlowMPInt &SlowMPInt::operator*=(const SlowMPInt &o) {
  *this = *this * o;
  return *this;
}
SlowMPInt &SlowMPInt::operator/=(const SlowMPInt &o) {
  *this = *this / o;
  return *this;
}
SlowMPInt &SlowMPInt::operator%=(const SlowMPInt &o) {
  *this = *this % o;
  return *this;
}

SlowMPInt &SlowMPInt::operator++() {
  *this += 1;
  return *this;
}

SlowMPInt &SlowMPInt::operator--() {
  *this -= 1;
  return *this;
}