//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements utilities for the Linalg dialect. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "llvm/ADT/SmallBitVector.h" using namespace mlir; /// Matches a ConstantIndexOp. /// TODO: This should probably just be a general matcher that uses matchConstant /// and checks the operation for an index type. detail::op_matcher mlir::matchConstantIndex() { return detail::op_matcher(); } llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank, ArrayRef shape) { llvm::SmallBitVector dimsToProject(shape.size()); for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) { if (shape[pos] == 1) { dimsToProject.set(pos); --rank; } } return dimsToProject; } Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr) { if (auto value = llvm::dyn_cast_if_present(ofr)) return value; auto attr = dyn_cast(llvm::dyn_cast_if_present(ofr)); assert(attr && "expect the op fold result casts to an integer attribute"); return b.create(loc, attr.getValue().getSExtValue()); } Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value) { if (targetType == value.getType()) return value; bool targetIsIndex = targetType.isIndex(); bool valueIsIndex = value.getType().isIndex(); if (targetIsIndex ^ valueIsIndex) return b.create(loc, targetType, value); auto targetIntegerType = dyn_cast(targetType); auto valueIntegerType = dyn_cast(value.getType()); assert(targetIntegerType && valueIntegerType && "unexpected cast between types other than integers and index"); assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness()); if (targetIntegerType.getWidth() > valueIntegerType.getWidth()) return b.create(loc, targetIntegerType, value); return b.create(loc, targetIntegerType, value); } static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand, IntegerType toType, bool isUnsigned) { // If operand is floating point, cast directly to the int type. if (isa(operand.getType())) { if (isUnsigned) return b.create(toType, operand); return b.create(toType, operand); } // Cast index operands directly to the int type. if (operand.getType().isIndex()) return b.create(toType, operand); if (auto fromIntType = dyn_cast(operand.getType())) { // Either extend or truncate. if (toType.getWidth() > fromIntType.getWidth()) { if (isUnsigned) return b.create(toType, operand); return b.create(toType, operand); } if (toType.getWidth() < fromIntType.getWidth()) return b.create(toType, operand); return operand; } return {}; } static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand, FloatType toType, bool isUnsigned) { // If operand is integer, cast directly to the float type. // Note that it is unclear how to cast from BF16<->FP16. if (isa(operand.getType())) { if (isUnsigned) return b.create(toType, operand); return b.create(toType, operand); } if (auto fromFpTy = dyn_cast(operand.getType())) { if (toType.getWidth() > fromFpTy.getWidth()) return b.create(toType, operand); if (toType.getWidth() < fromFpTy.getWidth()) return b.create(toType, operand); return operand; } return {}; } static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand, ComplexType targetType, bool isUnsigned) { if (auto fromComplexType = dyn_cast(operand.getType())) { if (isa(targetType.getElementType()) && isa(fromComplexType.getElementType())) { Value real = b.create(operand); Value imag = b.create(operand); Type targetETy = targetType.getElementType(); if (targetType.getElementType().getIntOrFloatBitWidth() < fromComplexType.getElementType().getIntOrFloatBitWidth()) { real = b.create(targetETy, real); imag = b.create(targetETy, imag); } else { real = b.create(targetETy, real); imag = b.create(targetETy, imag); } return b.create(targetType, real, imag); } } if (dyn_cast(operand.getType())) { FloatType toFpTy = cast(targetType.getElementType()); auto toBitwidth = toFpTy.getIntOrFloatBitWidth(); Value from = operand; if (from.getType().getIntOrFloatBitWidth() < toBitwidth) { from = b.create(toFpTy, from); } if (from.getType().getIntOrFloatBitWidth() > toBitwidth) { from = b.create(toFpTy, from); } Value zero = b.create( mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy); return b.create(targetType, from, zero); } if (dyn_cast(operand.getType())) { FloatType toFpTy = cast(targetType.getElementType()); Value from = operand; if (isUnsigned) { from = b.create(toFpTy, from); } else { from = b.create(toFpTy, from); } Value zero = b.create( mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy); return b.create(targetType, from, zero); } return {}; } Value mlir::convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast) { if (operand.getType() == toType) return operand; ImplicitLocOpBuilder ib(loc, b); Value result; if (auto intTy = dyn_cast(toType)) { result = convertScalarToIntDtype(ib, operand, intTy, isUnsignedCast); } else if (auto floatTy = dyn_cast(toType)) { result = convertScalarToFpDtype(ib, operand, floatTy, isUnsignedCast); } else if (auto complexTy = dyn_cast(toType)) { result = convertScalarToComplexDtype(ib, operand, complexTy, isUnsignedCast); } if (result) return result; emitWarning(loc) << "could not cast operand of type " << operand.getType() << " to " << toType; return operand; } SmallVector mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, ArrayRef valueOrAttrVec) { return llvm::to_vector<4>( llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value { return getValueOrCreateConstantIndexOp(b, loc, value); })); } Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value) { TypedAttr attr; if (isa(type)) { attr = builder.getIntegerAttr(type, value); } else { auto vecTy = cast(type); attr = SplatElementsAttr::get(vecTy, value); } return builder.create(loc, attr); } Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, int64_t value) { unsigned elementBitWidth = 0; if (auto intTy = dyn_cast(type)) elementBitWidth = intTy.getWidth(); else elementBitWidth = cast(type).getElementTypeBitWidth(); return createScalarOrSplatConstant(builder, loc, type, APInt(elementBitWidth, value)); } Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APFloat &value) { if (isa(type)) return builder.createOrFold( loc, type, builder.getFloatAttr(type, value)); TypedAttr splat = SplatElementsAttr::get(cast(type), value); return builder.createOrFold(loc, type, splat); } Value ArithBuilder::_and(Value lhs, Value rhs) { return b.create(loc, lhs, rhs); } Value ArithBuilder::add(Value lhs, Value rhs) { if (isa(lhs.getType())) return b.create(loc, lhs, rhs); return b.create(loc, lhs, rhs); } Value ArithBuilder::sub(Value lhs, Value rhs) { if (isa(lhs.getType())) return b.create(loc, lhs, rhs); return b.create(loc, lhs, rhs); } Value ArithBuilder::mul(Value lhs, Value rhs) { if (isa(lhs.getType())) return b.create(loc, lhs, rhs); return b.create(loc, lhs, rhs); } Value ArithBuilder::sgt(Value lhs, Value rhs) { if (isa(lhs.getType())) return b.create(loc, arith::CmpFPredicate::OGT, lhs, rhs); return b.create(loc, arith::CmpIPredicate::sgt, lhs, rhs); } Value ArithBuilder::slt(Value lhs, Value rhs) { if (isa(lhs.getType())) return b.create(loc, arith::CmpFPredicate::OLT, lhs, rhs); return b.create(loc, arith::CmpIPredicate::slt, lhs, rhs); } Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) { return b.create(loc, cmp, lhs, rhs); }