bolt/deps/llvm-18.1.8/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
2025-02-14 19:21:04 +01:00

473 lines
18 KiB
C++

//===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Presburger/PWMAFunction.h"
#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Analysis/Presburger/MPInt.h"
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include "mlir/Analysis/Presburger/PresburgerSpace.h"
#include "mlir/Analysis/Presburger/Utils.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
#include <optional>
using namespace mlir;
using namespace presburger;
void MultiAffineFunction::assertIsConsistent() const {
assert(space.getNumVars() - space.getNumRangeVars() + 1 ==
output.getNumColumns() &&
"Inconsistent number of output columns");
assert(space.getNumDomainVars() + space.getNumSymbolVars() ==
divs.getNumNonDivs() &&
"Inconsistent number of non-division variables in divs");
assert(space.getNumRangeVars() == output.getNumRows() &&
"Inconsistent number of output rows");
assert(space.getNumLocalVars() == divs.getNumDivs() &&
"Inconsistent number of divisions.");
assert(divs.hasAllReprs() && "All divisions should have a representation");
}
// Return the result of subtracting the two given vectors pointwise.
// The vectors must be of the same size.
// e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5].
static SmallVector<MPInt, 8> subtractExprs(ArrayRef<MPInt> vecA,
ArrayRef<MPInt> vecB) {
assert(vecA.size() == vecB.size() &&
"Cannot subtract vectors of differing lengths!");
SmallVector<MPInt, 8> result;
result.reserve(vecA.size());
for (unsigned i = 0, e = vecA.size(); i < e; ++i)
result.push_back(vecA[i] - vecB[i]);
return result;
}
PresburgerSet PWMAFunction::getDomain() const {
PresburgerSet domain = PresburgerSet::getEmpty(getDomainSpace());
for (const Piece &piece : pieces)
domain.unionInPlace(piece.domain);
return domain;
}
void MultiAffineFunction::print(raw_ostream &os) const {
space.print(os);
os << "Division Representation:\n";
divs.print(os);
os << "Output:\n";
output.print(os);
}
SmallVector<MPInt, 8>
MultiAffineFunction::valueAt(ArrayRef<MPInt> point) const {
assert(point.size() == getNumDomainVars() + getNumSymbolVars() &&
"Point has incorrect dimensionality!");
SmallVector<MPInt, 8> pointHomogenous{llvm::to_vector(point)};
// Get the division values at this point.
SmallVector<std::optional<MPInt>, 8> divValues = divs.divValuesAt(point);
// The given point didn't include the values of the divs which the output is a
// function of; we have computed one possible set of values and use them here.
pointHomogenous.reserve(pointHomogenous.size() + divValues.size());
for (const std::optional<MPInt> &divVal : divValues)
pointHomogenous.push_back(*divVal);
// The matrix `output` has an affine expression in the ith row, corresponding
// to the expression for the ith value in the output vector. The last column
// of the matrix contains the constant term. Let v be the input point with
// a 1 appended at the end. We can see that output * v gives the desired
// output vector.
pointHomogenous.emplace_back(1);
SmallVector<MPInt, 8> result = output.postMultiplyWithColumn(pointHomogenous);
assert(result.size() == getNumOutputs());
return result;
}
bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
assert(space.isCompatible(other.space) &&
"Spaces should be compatible for equality check.");
return getAsRelation().isEqual(other.getAsRelation());
}
bool MultiAffineFunction::isEqual(const MultiAffineFunction &other,
const IntegerPolyhedron &domain) const {
assert(space.isCompatible(other.space) &&
"Spaces should be compatible for equality check.");
IntegerRelation restrictedThis = getAsRelation();
restrictedThis.intersectDomain(domain);
IntegerRelation restrictedOther = other.getAsRelation();
restrictedOther.intersectDomain(domain);
return restrictedThis.isEqual(restrictedOther);
}
bool MultiAffineFunction::isEqual(const MultiAffineFunction &other,
const PresburgerSet &domain) const {
assert(space.isCompatible(other.space) &&
"Spaces should be compatible for equality check.");
return llvm::all_of(domain.getAllDisjuncts(),
[&](const IntegerRelation &disjunct) {
return isEqual(other, IntegerPolyhedron(disjunct));
});
}
void MultiAffineFunction::removeOutputs(unsigned start, unsigned end) {
assert(end <= getNumOutputs() && "Invalid range");
if (start >= end)
return;
space.removeVarRange(VarKind::Range, start, end);
output.removeRows(start, end - start);
}
void MultiAffineFunction::mergeDivs(MultiAffineFunction &other) {
assert(space.isCompatible(other.space) && "Functions should be compatible");
unsigned nDivs = getNumDivs();
unsigned divOffset = divs.getDivOffset();
other.divs.insertDiv(0, nDivs);
SmallVector<MPInt, 8> div(other.divs.getNumVars() + 1);
for (unsigned i = 0; i < nDivs; ++i) {
// Zero fill.
std::fill(div.begin(), div.end(), 0);
// Fill div with dividend from `divs`. Do not fill the constant.
std::copy(divs.getDividend(i).begin(), divs.getDividend(i).end() - 1,
div.begin());
// Fill constant.
div.back() = divs.getDividend(i).back();
other.divs.setDiv(i, div, divs.getDenom(i));
}
other.space.insertVar(VarKind::Local, 0, nDivs);
other.output.insertColumns(divOffset, nDivs);
auto merge = [&](unsigned i, unsigned j) {
// We only merge from local at pos j to local at pos i, where j > i.
if (i >= j)
return false;
// If i < nDivs, we are trying to merge duplicate divs in `this`. Since we
// do not want to merge duplicates in `this`, we ignore this call.
if (j < nDivs)
return false;
// Merge things in space and output.
other.space.removeVarRange(VarKind::Local, j, j + 1);
other.output.addToColumn(divOffset + i, divOffset + j, 1);
other.output.removeColumn(divOffset + j);
return true;
};
other.divs.removeDuplicateDivs(merge);
unsigned newDivs = other.divs.getNumDivs() - nDivs;
space.insertVar(VarKind::Local, nDivs, newDivs);
output.insertColumns(divOffset + nDivs, newDivs);
divs = other.divs;
// Check consistency.
assertIsConsistent();
other.assertIsConsistent();
}
PresburgerSet
MultiAffineFunction::getLexSet(OrderingKind comp,
const MultiAffineFunction &other) const {
assert(getSpace().isCompatible(other.getSpace()) &&
"Output space of funcs should be compatible");
// Create copies of functions and merge their local space.
MultiAffineFunction funcA = *this;
MultiAffineFunction funcB = other;
funcA.mergeDivs(funcB);
// We first create the set `result`, corresponding to the set where output
// of funcA is lexicographically larger/smaller than funcB. This is done by
// creating a PresburgerSet with the following constraints:
//
// (outA[0] > outB[0]) U
// (outA[0] = outB[0], outA[1] > outA[1]) U
// (outA[0] = outB[0], outA[1] = outA[1], outA[2] > outA[2]) U
// ...
// (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] > outB[n-1])
//
// where `n` is the number of outputs.
// If `lexMin` is set, the complement inequality is used:
//
// (outA[0] < outB[0]) U
// (outA[0] = outB[0], outA[1] < outA[1]) U
// (outA[0] = outB[0], outA[1] = outA[1], outA[2] < outA[2]) U
// ...
// (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1])
PresburgerSpace resultSpace = funcA.getDomainSpace();
PresburgerSet result =
PresburgerSet::getEmpty(resultSpace.getSpaceWithoutLocals());
IntegerPolyhedron levelSet(
/*numReservedInequalities=*/1 + 2 * resultSpace.getNumLocalVars(),
/*numReservedEqualities=*/funcA.getNumOutputs(),
/*numReservedCols=*/resultSpace.getNumVars() + 1, resultSpace);
// Add division inequalities to `levelSet`.
for (unsigned i = 0, e = funcA.getNumDivs(); i < e; ++i) {
levelSet.addInequality(getDivUpperBound(funcA.divs.getDividend(i),
funcA.divs.getDenom(i),
funcA.divs.getDivOffset() + i));
levelSet.addInequality(getDivLowerBound(funcA.divs.getDividend(i),
funcA.divs.getDenom(i),
funcA.divs.getDivOffset() + i));
}
for (unsigned level = 0; level < funcA.getNumOutputs(); ++level) {
// Create the expression `outA - outB` for this level.
SmallVector<MPInt, 8> subExpr =
subtractExprs(funcA.getOutputExpr(level), funcB.getOutputExpr(level));
// TODO: Implement all comparison cases.
switch (comp) {
case OrderingKind::LT:
// For less than, we add an upper bound of -1:
// outA - outB <= -1
// outA <= outB - 1
// outA < outB
levelSet.addBound(BoundType::UB, subExpr, MPInt(-1));
break;
case OrderingKind::GT:
// For greater than, we add a lower bound of 1:
// outA - outB >= 1
// outA > outB + 1
// outA > outB
levelSet.addBound(BoundType::LB, subExpr, MPInt(1));
break;
case OrderingKind::GE:
case OrderingKind::LE:
case OrderingKind::EQ:
case OrderingKind::NE:
assert(false && "Not implemented case");
}
// Union the set with the result.
result.unionInPlace(levelSet);
// The last inequality in `levelSet` is the bound we inserted. We remove
// that for next iteration.
levelSet.removeInequality(levelSet.getNumInequalities() - 1);
// Add equality `outA - outB == 0` for this level for next iteration.
levelSet.addEquality(subExpr);
}
return result;
}
/// Two PWMAFunctions are equal if they have the same dimensionalities,
/// the same domain, and take the same value at every point in the domain.
bool PWMAFunction::isEqual(const PWMAFunction &other) const {
if (!space.isCompatible(other.space))
return false;
if (!this->getDomain().isEqual(other.getDomain()))
return false;
// Check if, whenever the domains of a piece of `this` and a piece of `other`
// overlap, they take the same output value. If `this` and `other` have the
// same domain (checked above), then this check passes iff the two functions
// have the same output at every point in the domain.
return llvm::all_of(this->pieces, [&other](const Piece &pieceA) {
return llvm::all_of(other.pieces, [&pieceA](const Piece &pieceB) {
PresburgerSet commonDomain = pieceA.domain.intersect(pieceB.domain);
return pieceA.output.isEqual(pieceB.output, commonDomain);
});
});
}
void PWMAFunction::addPiece(const Piece &piece) {
assert(piece.isConsistent() && "Piece should be consistent");
assert(piece.domain.intersect(getDomain()).isIntegerEmpty() &&
"Piece should be disjoint from the function");
pieces.push_back(piece);
}
void PWMAFunction::print(raw_ostream &os) const {
space.print(os);
os << getNumPieces() << " pieces:\n";
for (const Piece &piece : pieces) {
os << "Domain of piece:\n";
piece.domain.print(os);
os << "Output of piece\n";
piece.output.print(os);
}
}
void PWMAFunction::dump() const { print(llvm::errs()); }
PWMAFunction PWMAFunction::unionFunction(
const PWMAFunction &func,
llvm::function_ref<PresburgerSet(Piece maf1, Piece maf2)> tiebreak) const {
assert(getNumOutputs() == func.getNumOutputs() &&
"Ranges of functions should be same.");
assert(getSpace().isCompatible(func.getSpace()) &&
"Space is not compatible.");
// The algorithm used here is as follows:
// - Add the output of pieceB for the part of the domain where both pieceA and
// pieceB are defined, and `tiebreak` chooses the output of pieceB.
// - Add the output of pieceA, where pieceB is not defined or `tiebreak`
// chooses
// pieceA over pieceB.
// - Add the output of pieceB, where pieceA is not defined.
// Add parts of the common domain where pieceB's output is used. Also
// add all the parts where pieceA's output is used, both common and
// non-common.
PWMAFunction result(getSpace());
for (const Piece &pieceA : pieces) {
PresburgerSet dom(pieceA.domain);
for (const Piece &pieceB : func.pieces) {
PresburgerSet better = tiebreak(pieceB, pieceA);
// Add the output of pieceB, where it is better than output of pieceA.
// The disjuncts in "better" will be disjoint as tiebreak should gurantee
// that.
result.addPiece({better, pieceB.output});
dom = dom.subtract(better);
}
// Add output of pieceA, where it is better than pieceB, or pieceB is not
// defined.
//
// `dom` here is guranteed to be disjoint from already added pieces
// because the pieces added before are either:
// - Subsets of the domain of other MAFs in `this`, which are guranteed
// to be disjoint from `dom`, or
// - They are one of the pieces added for `pieceB`, and we have been
// subtracting all such pieces from `dom`, so `dom` is disjoint from those
// pieces as well.
result.addPiece({dom, pieceA.output});
}
// Add parts of pieceB which are not shared with pieceA.
PresburgerSet dom = getDomain();
for (const Piece &pieceB : func.pieces)
result.addPiece({pieceB.domain.subtract(dom), pieceB.output});
return result;
}
/// A tiebreak function which breaks ties by comparing the outputs
/// lexicographically based on the given comparison operator.
/// This is templated since it is passed as a lambda.
template <OrderingKind comp>
static PresburgerSet tiebreakLex(const PWMAFunction::Piece &pieceA,
const PWMAFunction::Piece &pieceB) {
PresburgerSet result = pieceA.output.getLexSet(comp, pieceB.output);
result = result.intersect(pieceA.domain).intersect(pieceB.domain);
return result;
}
PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) {
return unionFunction(func, tiebreakLex</*comp=*/OrderingKind::LT>);
}
PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) {
return unionFunction(func, tiebreakLex</*comp=*/OrderingKind::GT>);
}
void MultiAffineFunction::subtract(const MultiAffineFunction &other) {
assert(space.isCompatible(other.space) &&
"Spaces should be compatible for subtraction.");
MultiAffineFunction copyOther = other;
mergeDivs(copyOther);
for (unsigned i = 0, e = getNumOutputs(); i < e; ++i)
output.addToRow(i, copyOther.getOutputExpr(i), MPInt(-1));
// Check consistency.
assertIsConsistent();
}
/// Adds division constraints corresponding to local variables, given a
/// relation and division representations of the local variables in the
/// relation.
static void addDivisionConstraints(IntegerRelation &rel,
const DivisionRepr &divs) {
assert(divs.hasAllReprs() &&
"All divisions in divs should have a representation");
assert(rel.getNumVars() == divs.getNumVars() &&
"Relation and divs should have the same number of vars");
assert(rel.getNumLocalVars() == divs.getNumDivs() &&
"Relation and divs should have the same number of local vars");
for (unsigned i = 0, e = divs.getNumDivs(); i < e; ++i) {
rel.addInequality(getDivUpperBound(divs.getDividend(i), divs.getDenom(i),
divs.getDivOffset() + i));
rel.addInequality(getDivLowerBound(divs.getDividend(i), divs.getDenom(i),
divs.getDivOffset() + i));
}
}
IntegerRelation MultiAffineFunction::getAsRelation() const {
// Create a relation corressponding to the input space plus the divisions
// used in outputs.
IntegerRelation result(PresburgerSpace::getRelationSpace(
space.getNumDomainVars(), 0, space.getNumSymbolVars(),
space.getNumLocalVars()));
// Add division constraints corresponding to divisions used in outputs.
addDivisionConstraints(result, divs);
// The outputs are represented as range variables in the relation. We add
// range variables for the outputs.
result.insertVar(VarKind::Range, 0, getNumOutputs());
// Add equalities such that the i^th range variable is equal to the i^th
// output expression.
SmallVector<MPInt, 8> eq(result.getNumCols());
for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) {
// TODO: Add functions to get VarKind offsets in output in MAF and use them
// here.
// The output expression does not contain range variables, while the
// equality does. So, we need to copy all variables and mark all range
// variables as 0 in the equality.
ArrayRef<MPInt> expr = getOutputExpr(i);
// Copy domain variables in `expr` to domain variables in `eq`.
std::copy(expr.begin(), expr.begin() + getNumDomainVars(), eq.begin());
// Fill the range variables in `eq` as zero.
std::fill(eq.begin() + result.getVarKindOffset(VarKind::Range),
eq.begin() + result.getVarKindEnd(VarKind::Range), 0);
// Copy remaining variables in `expr` to the remaining variables in `eq`.
std::copy(expr.begin() + getNumDomainVars(), expr.end(),
eq.begin() + result.getVarKindEnd(VarKind::Range));
// Set the i^th range var to -1 in `eq` to equate the output expression to
// this range var.
eq[result.getVarKindOffset(VarKind::Range) + i] = -1;
// Add the equality `rangeVar_i = output[i]`.
result.addEquality(eq);
}
return result;
}
void PWMAFunction::removeOutputs(unsigned start, unsigned end) {
space.removeVarRange(VarKind::Range, start, end);
for (Piece &piece : pieces)
piece.output.removeOutputs(start, end);
}
std::optional<SmallVector<MPInt, 8>>
PWMAFunction::valueAt(ArrayRef<MPInt> point) const {
assert(point.size() == getNumDomainVars() + getNumSymbolVars());
for (const Piece &piece : pieces)
if (piece.domain.containsPoint(point))
return piece.output.valueAt(point);
return std::nullopt;
}