//===-- lib/Evaluate/fold-real.cpp ----------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "fold-implementation.h" #include "fold-matmul.h" #include "fold-reduction.h" namespace Fortran::evaluate { template static Expr FoldTransformationalBessel( FunctionRef &&funcRef, FoldingContext &context) { CHECK(funcRef.arguments().size() == 3); /// Bessel runtime functions use `int` integer arguments. Convert integer /// arguments to Int4, any overflow error will be reported during the /// conversion folding. using Int4 = Type; if (auto args{ GetConstantArguments(context, funcRef.arguments())}) { const std::string &name{std::get(funcRef.proc().u).name}; if (auto elementalBessel{GetHostRuntimeWrapper(name)}) { std::vector> results; int n1{static_cast( std::get<0>(*args)->GetScalarValue().value().ToInt64())}; int n2{static_cast( std::get<1>(*args)->GetScalarValue().value().ToInt64())}; Scalar x{std::get<2>(*args)->GetScalarValue().value()}; for (int i{n1}; i <= n2; ++i) { results.emplace_back((*elementalBessel)(context, Scalar{i}, x)); } return Expr{Constant{ std::move(results), ConstantSubscripts{std::max(n2 - n1 + 1, 0)}}}; } else { context.messages().Say( "%s(integer(kind=4), real(kind=%d)) cannot be folded on host"_warn_en_US, name, T::kind); } } return Expr{std::move(funcRef)}; } // NORM2 template class Norm2Accumulator { using T = Type; public: Norm2Accumulator( const Constant &array, const Constant &maxAbs, Rounding rounding) : array_{array}, maxAbs_{maxAbs}, rounding_{rounding} {}; void operator()( Scalar &element, const ConstantSubscripts &at, bool /*first*/) { // Kahan summation of scaled elements: // Naively, // NORM2(A(:)) = SQRT(SUM(A(:)**2)) // For any T > 0, we have mathematically // SQRT(SUM(A(:)**2)) // = SQRT(T**2 * (SUM(A(:)**2) / T**2)) // = SQRT(T**2 * SUM(A(:)**2 / T**2)) // = SQRT(T**2 * SUM((A(:)/T)**2)) // = SQRT(T**2) * SQRT(SUM((A(:)/T)**2)) // = T * SQRT(SUM((A(:)/T)**2)) // By letting T = MAXVAL(ABS(A)), we ensure that // ALL(ABS(A(:)/T) <= 1), so ALL((A(:)/T)**2 <= 1), and the SUM will // not overflow unless absolutely necessary. auto scale{maxAbs_.At(maxAbsAt_)}; if (scale.IsZero()) { // Maximum value is zero, and so will the result be. // Avoid division by zero below. element = scale; } else { auto item{array_.At(at)}; auto scaled{item.Divide(scale).value}; auto square{scaled.Multiply(scaled).value}; auto next{square.Add(correction_, rounding_)}; overflow_ |= next.flags.test(RealFlag::Overflow); auto sum{element.Add(next.value, rounding_)}; overflow_ |= sum.flags.test(RealFlag::Overflow); correction_ = sum.value.Subtract(element, rounding_) .value.Subtract(next.value, rounding_) .value; element = sum.value; } } bool overflow() const { return overflow_; } void Done(Scalar &result) { // result+correction == SUM((data(:)/maxAbs)**2) // result = maxAbs * SQRT(result+correction) auto corrected{result.Add(correction_, rounding_)}; overflow_ |= corrected.flags.test(RealFlag::Overflow); correction_ = Scalar{}; auto root{corrected.value.SQRT().value}; auto product{root.Multiply(maxAbs_.At(maxAbsAt_))}; maxAbs_.IncrementSubscripts(maxAbsAt_); overflow_ |= product.flags.test(RealFlag::Overflow); result = product.value; } private: const Constant &array_; const Constant &maxAbs_; const Rounding rounding_; bool overflow_{false}; Scalar correction_{}; ConstantSubscripts maxAbsAt_{maxAbs_.lbounds()}; }; template static Expr> FoldNorm2(FoldingContext &context, FunctionRef> &&funcRef) { using T = Type; using Element = typename Constant::Element; std::optional dim; if (std::optional> arrayAndMask{ ProcessReductionArgs(context, funcRef.arguments(), dim, /*X=*/0, /*DIM=*/1)}) { MaxvalMinvalAccumulator maxAbsAccumulator{ RelationalOperator::GT, context, arrayAndMask->array}; const Element identity{}; Constant maxAbs{DoReduction(arrayAndMask->array, arrayAndMask->mask, dim, identity, maxAbsAccumulator)}; Norm2Accumulator norm2Accumulator{arrayAndMask->array, maxAbs, context.targetCharacteristics().roundingMode()}; Constant result{DoReduction(arrayAndMask->array, arrayAndMask->mask, dim, identity, norm2Accumulator)}; if (norm2Accumulator.overflow()) { context.messages().Say( "NORM2() of REAL(%d) data overflowed"_warn_en_US, KIND); } return Expr{std::move(result)}; } return Expr{std::move(funcRef)}; } template Expr> FoldIntrinsicFunction( FoldingContext &context, FunctionRef> &&funcRef) { using T = Type; using ComplexT = Type; using Int4 = Type; ActualArguments &args{funcRef.arguments()}; auto *intrinsic{std::get_if(&funcRef.proc().u)}; CHECK(intrinsic); std::string name{intrinsic->name}; if (name == "acos" || name == "acosh" || name == "asin" || name == "asinh" || (name == "atan" && args.size() == 1) || name == "atanh" || name == "bessel_j0" || name == "bessel_j1" || name == "bessel_y0" || name == "bessel_y1" || name == "cos" || name == "cosh" || name == "erf" || name == "erfc" || name == "erfc_scaled" || name == "exp" || name == "gamma" || name == "log" || name == "log10" || name == "log_gamma" || name == "sin" || name == "sinh" || name == "tan" || name == "tanh") { CHECK(args.size() == 1); if (auto callable{GetHostRuntimeWrapper(name)}) { return FoldElementalIntrinsic( context, std::move(funcRef), *callable); } else { context.messages().Say( "%s(real(kind=%d)) cannot be folded on host"_warn_en_US, name, KIND); } } else if (name == "amax0" || name == "amin0" || name == "amin1" || name == "amax1" || name == "dmin1" || name == "dmax1") { return RewriteSpecificMINorMAX(context, std::move(funcRef)); } else if (name == "atan" || name == "atan2") { std::string localName{name == "atan" ? "atan2" : name}; CHECK(args.size() == 2); if (auto callable{GetHostRuntimeWrapper(localName)}) { return FoldElementalIntrinsic( context, std::move(funcRef), *callable); } else { context.messages().Say( "%s(real(kind=%d), real(kind%d)) cannot be folded on host"_warn_en_US, name, KIND, KIND); } } else if (name == "bessel_jn" || name == "bessel_yn") { if (args.size() == 2) { // elemental // runtime functions use int arg if (auto callable{GetHostRuntimeWrapper(name)}) { return FoldElementalIntrinsic( context, std::move(funcRef), *callable); } else { context.messages().Say( "%s(integer(kind=4), real(kind=%d)) cannot be folded on host"_warn_en_US, name, KIND); } } else { return FoldTransformationalBessel(std::move(funcRef), context); } } else if (name == "abs") { // incl. zabs & cdabs // Argument can be complex or real if (auto *x{UnwrapExpr>(args[0])}) { return FoldElementalIntrinsic( context, std::move(funcRef), &Scalar::ABS); } else if (auto *z{UnwrapExpr>(args[0])}) { return FoldElementalIntrinsic(context, std::move(funcRef), ScalarFunc([&name, &context]( const Scalar &z) -> Scalar { ValueWithRealFlags> y{z.ABS()}; if (y.flags.test(RealFlag::Overflow)) { context.messages().Say( "complex ABS intrinsic folding overflow"_warn_en_US, name); } return y.value; })); } else { common::die(" unexpected argument type inside abs"); } } else if (name == "aimag") { if (auto *zExpr{UnwrapExpr>(args[0])}) { return Fold(context, Expr{ComplexComponent{true, std::move(*zExpr)}}); } } else if (name == "aint" || name == "anint") { // ANINT rounds ties away from zero, not to even common::RoundingMode mode{name == "aint" ? common::RoundingMode::ToZero : common::RoundingMode::TiesAwayFromZero}; return FoldElementalIntrinsic(context, std::move(funcRef), ScalarFunc( [&name, &context, mode](const Scalar &x) -> Scalar { ValueWithRealFlags> y{x.ToWholeNumber(mode)}; if (y.flags.test(RealFlag::Overflow)) { context.messages().Say( "%s intrinsic folding overflow"_warn_en_US, name); } return y.value; })); } else if (name == "dim") { return FoldElementalIntrinsic(context, std::move(funcRef), ScalarFunc([&context](const Scalar &x, const Scalar &y) -> Scalar { ValueWithRealFlags> result{x.DIM(y)}; if (result.flags.test(RealFlag::Overflow)) { context.messages().Say("DIM intrinsic folding overflow"_warn_en_US); } return result.value; })); } else if (name == "dot_product") { return FoldDotProduct(context, std::move(funcRef)); } else if (name == "dprod") { // Rewrite DPROD(x,y) -> DBLE(x)*DBLE(y) if (args.at(0) && args.at(1)) { const auto *xExpr{args[0]->UnwrapExpr()}; const auto *yExpr{args[1]->UnwrapExpr()}; if (xExpr && yExpr) { return Fold(context, ToReal(context, common::Clone(*xExpr)) * ToReal(context, common::Clone(*yExpr))); } } } else if (name == "epsilon") { return Expr{Scalar::EPSILON()}; } else if (name == "fraction") { return FoldElementalIntrinsic(context, std::move(funcRef), ScalarFunc( [](const Scalar &x) -> Scalar { return x.FRACTION(); })); } else if (name == "huge") { return Expr{Scalar::HUGE()}; } else if (name == "hypot") { CHECK(args.size() == 2); return FoldElementalIntrinsic(context, std::move(funcRef), ScalarFunc( [&](const Scalar &x, const Scalar &y) -> Scalar { ValueWithRealFlags> result{x.HYPOT(y)}; if (result.flags.test(RealFlag::Overflow)) { context.messages().Say( "HYPOT intrinsic folding overflow"_warn_en_US); } return result.value; })); } else if (name == "matmul") { return FoldMatmul(context, std::move(funcRef)); } else if (name == "max") { return FoldMINorMAX(context, std::move(funcRef), Ordering::Greater); } else if (name == "maxval") { return FoldMaxvalMinval(context, std::move(funcRef), RelationalOperator::GT, T::Scalar::HUGE().Negate()); } else if (name == "min") { return FoldMINorMAX(context, std::move(funcRef), Ordering::Less); } else if (name == "minval") { return FoldMaxvalMinval( context, std::move(funcRef), RelationalOperator::LT, T::Scalar::HUGE()); } else if (name == "mod") { CHECK(args.size() == 2); return FoldElementalIntrinsic(context, std::move(funcRef), ScalarFunc( [&context](const Scalar &x, const Scalar &y) -> Scalar { auto result{x.MOD(y)}; if (result.flags.test(RealFlag::DivideByZero)) { context.messages().Say( "second argument to MOD must not be zero"_warn_en_US); } return result.value; })); } else if (name == "modulo") { CHECK(args.size() == 2); return FoldElementalIntrinsic(context, std::move(funcRef), ScalarFunc( [&context](const Scalar &x, const Scalar &y) -> Scalar { auto result{x.MODULO(y)}; if (result.flags.test(RealFlag::DivideByZero)) { context.messages().Say( "second argument to MODULO must not be zero"_warn_en_US); } return result.value; })); } else if (name == "nearest") { if (const auto *sExpr{UnwrapExpr>(args[1])}) { return common::visit( [&](const auto &sVal) { using TS = ResultType; return FoldElementalIntrinsic(context, std::move(funcRef), ScalarFunc([&](const Scalar &x, const Scalar &s) -> Scalar { if (s.IsZero()) { context.messages().Say( "NEAREST: S argument is zero"_warn_en_US); } auto result{x.NEAREST(!s.IsNegative())}; if (result.flags.test(RealFlag::Overflow)) { context.messages().Say( "NEAREST intrinsic folding overflow"_warn_en_US); } else if (result.flags.test(RealFlag::InvalidArgument)) { context.messages().Say( "NEAREST intrinsic folding: bad argument"_warn_en_US); } return result.value; })); }, sExpr->u); } } else if (name == "norm2") { return FoldNorm2(context, std::move(funcRef)); } else if (name == "product") { auto one{Scalar::FromInteger(value::Integer<8>{1}).value}; return FoldProduct(context, std::move(funcRef), one); } else if (name == "real" || name == "dble") { if (auto *expr{args[0].value().UnwrapExpr()}) { return ToReal(context, std::move(*expr)); } } else if (name == "rrspacing") { return FoldElementalIntrinsic(context, std::move(funcRef), ScalarFunc( [](const Scalar &x) -> Scalar { return x.RRSPACING(); })); } else if (name == "scale") { if (const auto *byExpr{UnwrapExpr>(args[1])}) { return common::visit( [&](const auto &byVal) { using TBY = ResultType; return FoldElementalIntrinsic(context, std::move(funcRef), ScalarFunc( [&](const Scalar &x, const Scalar &y) -> Scalar { ValueWithRealFlags> result{x. // MSVC chokes on the keyword "template" here in a call to a // member function template. #ifndef _MSC_VER template #endif SCALE(y)}; if (result.flags.test(RealFlag::Overflow)) { context.messages().Say( "SCALE intrinsic folding overflow"_warn_en_US); } return result.value; })); }, byExpr->u); } } else if (name == "set_exponent") { if (const auto *iExpr{UnwrapExpr>(args[1])}) { return common::visit( [&](const auto &iVal) { using TY = ResultType; return FoldElementalIntrinsic(context, std::move(funcRef), ScalarFunc( [&](const Scalar &x, const Scalar &i) -> Scalar { return x.SET_EXPONENT(i.ToInt64()); })); }, iExpr->u); } } else if (name == "sign") { return FoldElementalIntrinsic( context, std::move(funcRef), &Scalar::SIGN); } else if (name == "spacing") { return FoldElementalIntrinsic(context, std::move(funcRef), ScalarFunc( [](const Scalar &x) -> Scalar { return x.SPACING(); })); } else if (name == "sqrt") { return FoldElementalIntrinsic(context, std::move(funcRef), ScalarFunc( [](const Scalar &x) -> Scalar { return x.SQRT().value; })); } else if (name == "sum") { return FoldSum(context, std::move(funcRef)); } else if (name == "tiny") { return Expr{Scalar::TINY()}; } else if (name == "__builtin_fma") { CHECK(args.size() == 3); } else if (name == "__builtin_ieee_next_after") { if (const auto *yExpr{UnwrapExpr>(args[1])}) { return common::visit( [&](const auto &yVal) { using TY = ResultType; return FoldElementalIntrinsic(context, std::move(funcRef), ScalarFunc([&](const Scalar &x, const Scalar &y) -> Scalar { bool upward{true}; switch (x.Compare(Scalar::Convert(y).value)) { case Relation::Unordered: context.messages().Say( "IEEE_NEXT_AFTER intrinsic folding: bad argument"_warn_en_US); return x; case Relation::Equal: return x; case Relation::Less: upward = true; break; case Relation::Greater: upward = false; break; } auto result{x.NEAREST(upward)}; if (result.flags.test(RealFlag::Overflow)) { context.messages().Say( "IEEE_NEXT_AFTER intrinsic folding overflow"_warn_en_US); } return result.value; })); }, yExpr->u); } } else if (name == "__builtin_ieee_next_up" || name == "__builtin_ieee_next_down") { bool upward{name == "__builtin_ieee_next_up"}; const char *iName{upward ? "IEEE_NEXT_UP" : "IEEE_NEXT_DOWN"}; return FoldElementalIntrinsic(context, std::move(funcRef), ScalarFunc([&](const Scalar &x) -> Scalar { auto result{x.NEAREST(upward)}; if (result.flags.test(RealFlag::Overflow)) { context.messages().Say( "%s intrinsic folding overflow"_warn_en_US, iName); } else if (result.flags.test(RealFlag::InvalidArgument)) { context.messages().Say( "%s intrinsic folding: bad argument"_warn_en_US, iName); } return result.value; })); } return Expr{std::move(funcRef)}; } #ifdef _MSC_VER // disable bogus warning about missing definitions #pragma warning(disable : 4661) #endif FOR_EACH_REAL_KIND(template class ExpressionBase, ) template class ExpressionBase; } // namespace Fortran::evaluate