1261 lines
41 KiB
C++
1261 lines
41 KiB
C++
//===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
|
#include "mlir/Dialect/EmitC/IR/EmitC.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/Dialect.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "mlir/Support/IndentedOstream.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
#include "mlir/Target/Cpp/CppEmitter.h"
|
|
#include "llvm/ADT/DenseMap.h"
|
|
#include "llvm/ADT/StringExtras.h"
|
|
#include "llvm/ADT/StringMap.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
#include <utility>
|
|
|
|
#define DEBUG_TYPE "translate-to-cpp"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::emitc;
|
|
using llvm::formatv;
|
|
|
|
/// Convenience functions to produce interleaved output with functions returning
|
|
/// a LogicalResult. This is different than those in STLExtras as functions used
|
|
/// on each element doesn't return a string.
|
|
template <typename ForwardIterator, typename UnaryFunctor,
|
|
typename NullaryFunctor>
|
|
inline LogicalResult
|
|
interleaveWithError(ForwardIterator begin, ForwardIterator end,
|
|
UnaryFunctor eachFn, NullaryFunctor betweenFn) {
|
|
if (begin == end)
|
|
return success();
|
|
if (failed(eachFn(*begin)))
|
|
return failure();
|
|
++begin;
|
|
for (; begin != end; ++begin) {
|
|
betweenFn();
|
|
if (failed(eachFn(*begin)))
|
|
return failure();
|
|
}
|
|
return success();
|
|
}
|
|
|
|
template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
|
|
inline LogicalResult interleaveWithError(const Container &c,
|
|
UnaryFunctor eachFn,
|
|
NullaryFunctor betweenFn) {
|
|
return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn);
|
|
}
|
|
|
|
template <typename Container, typename UnaryFunctor>
|
|
inline LogicalResult interleaveCommaWithError(const Container &c,
|
|
raw_ostream &os,
|
|
UnaryFunctor eachFn) {
|
|
return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; });
|
|
}
|
|
|
|
/// Return the precedence of a operator as an integer, higher values
|
|
/// imply higher precedence.
|
|
static int getOperatorPrecedence(Operation *operation) {
|
|
return llvm::TypeSwitch<Operation *, int>(operation)
|
|
.Case<emitc::AddOp>([&](auto op) { return 11; })
|
|
.Case<emitc::ApplyOp>([&](auto op) { return 13; })
|
|
.Case<emitc::CastOp>([&](auto op) { return 13; })
|
|
.Case<emitc::CmpOp>([&](auto op) {
|
|
switch (op.getPredicate()) {
|
|
case emitc::CmpPredicate::eq:
|
|
case emitc::CmpPredicate::ne:
|
|
return 8;
|
|
case emitc::CmpPredicate::lt:
|
|
case emitc::CmpPredicate::le:
|
|
case emitc::CmpPredicate::gt:
|
|
case emitc::CmpPredicate::ge:
|
|
return 9;
|
|
case emitc::CmpPredicate::three_way:
|
|
return 10;
|
|
}
|
|
})
|
|
.Case<emitc::DivOp>([&](auto op) { return 12; })
|
|
.Case<emitc::MulOp>([&](auto op) { return 12; })
|
|
.Case<emitc::RemOp>([&](auto op) { return 12; })
|
|
.Case<emitc::SubOp>([&](auto op) { return 11; })
|
|
.Case<emitc::CallOpaqueOp>([&](auto op) { return 14; });
|
|
llvm_unreachable("Unsupported operator");
|
|
}
|
|
|
|
namespace {
|
|
/// Emitter that uses dialect specific emitters to emit C++ code.
|
|
struct CppEmitter {
|
|
explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop);
|
|
|
|
/// Emits attribute or returns failure.
|
|
LogicalResult emitAttribute(Location loc, Attribute attr);
|
|
|
|
/// Emits operation 'op' with/without training semicolon or returns failure.
|
|
LogicalResult emitOperation(Operation &op, bool trailingSemicolon);
|
|
|
|
/// Emits type 'type' or returns failure.
|
|
LogicalResult emitType(Location loc, Type type);
|
|
|
|
/// Emits array of types as a std::tuple of the emitted types.
|
|
/// - emits void for an empty array;
|
|
/// - emits the type of the only element for arrays of size one;
|
|
/// - emits a std::tuple otherwise;
|
|
LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
|
|
|
|
/// Emits array of types as a std::tuple of the emitted types independently of
|
|
/// the array size.
|
|
LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
|
|
|
|
/// Emits an assignment for a variable which has been declared previously.
|
|
LogicalResult emitVariableAssignment(OpResult result);
|
|
|
|
/// Emits a variable declaration for a result of an operation.
|
|
LogicalResult emitVariableDeclaration(OpResult result,
|
|
bool trailingSemicolon);
|
|
|
|
/// Emits the variable declaration and assignment prefix for 'op'.
|
|
/// - emits separate variable followed by std::tie for multi-valued operation;
|
|
/// - emits single type followed by variable for single result;
|
|
/// - emits nothing if no value produced by op;
|
|
/// Emits final '=' operator where a type is produced. Returns failure if
|
|
/// any result type could not be converted.
|
|
LogicalResult emitAssignPrefix(Operation &op);
|
|
|
|
/// Emits a label for the block.
|
|
LogicalResult emitLabel(Block &block);
|
|
|
|
/// Emits the operands and atttributes of the operation. All operands are
|
|
/// emitted first and then all attributes in alphabetical order.
|
|
LogicalResult emitOperandsAndAttributes(Operation &op,
|
|
ArrayRef<StringRef> exclude = {});
|
|
|
|
/// Emits the operands of the operation. All operands are emitted in order.
|
|
LogicalResult emitOperands(Operation &op);
|
|
|
|
/// Emits value as an operands of an operation
|
|
LogicalResult emitOperand(Value value);
|
|
|
|
/// Emit an expression as a C expression.
|
|
LogicalResult emitExpression(ExpressionOp expressionOp);
|
|
|
|
/// Return the existing or a new name for a Value.
|
|
StringRef getOrCreateName(Value val);
|
|
|
|
/// Return the existing or a new label of a Block.
|
|
StringRef getOrCreateName(Block &block);
|
|
|
|
/// Whether to map an mlir integer to a unsigned integer in C++.
|
|
bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
|
|
|
|
/// RAII helper function to manage entering/exiting C++ scopes.
|
|
struct Scope {
|
|
Scope(CppEmitter &emitter)
|
|
: valueMapperScope(emitter.valueMapper),
|
|
blockMapperScope(emitter.blockMapper), emitter(emitter) {
|
|
emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
|
|
emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
|
|
}
|
|
~Scope() {
|
|
emitter.valueInScopeCount.pop();
|
|
emitter.labelInScopeCount.pop();
|
|
}
|
|
|
|
private:
|
|
llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
|
|
llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
|
|
CppEmitter &emitter;
|
|
};
|
|
|
|
/// Returns wether the Value is assigned to a C++ variable in the scope.
|
|
bool hasValueInScope(Value val);
|
|
|
|
// Returns whether a label is assigned to the block.
|
|
bool hasBlockLabel(Block &block);
|
|
|
|
/// Returns the output stream.
|
|
raw_indented_ostream &ostream() { return os; };
|
|
|
|
/// Returns if all variables for op results and basic block arguments need to
|
|
/// be declared at the beginning of a function.
|
|
bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
|
|
|
|
/// Get expression currently being emitted.
|
|
ExpressionOp getEmittedExpression() { return emittedExpression; }
|
|
|
|
/// Determine whether given value is part of the expression potentially being
|
|
/// emitted.
|
|
bool isPartOfCurrentExpression(Value value) {
|
|
if (!emittedExpression)
|
|
return false;
|
|
Operation *def = value.getDefiningOp();
|
|
if (!def)
|
|
return false;
|
|
auto operandExpression = dyn_cast<ExpressionOp>(def->getParentOp());
|
|
return operandExpression == emittedExpression;
|
|
};
|
|
|
|
private:
|
|
using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
|
|
using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
|
|
|
|
/// Output stream to emit to.
|
|
raw_indented_ostream os;
|
|
|
|
/// Boolean to enforce that all variables for op results and block
|
|
/// arguments are declared at the beginning of the function. This also
|
|
/// includes results from ops located in nested regions.
|
|
bool declareVariablesAtTop;
|
|
|
|
/// Map from value to name of C++ variable that contain the name.
|
|
ValueMapper valueMapper;
|
|
|
|
/// Map from block to name of C++ label.
|
|
BlockMapper blockMapper;
|
|
|
|
/// The number of values in the current scope. This is used to declare the
|
|
/// names of values in a scope.
|
|
std::stack<int64_t> valueInScopeCount;
|
|
std::stack<int64_t> labelInScopeCount;
|
|
|
|
/// State of the current expression being emitted.
|
|
ExpressionOp emittedExpression;
|
|
SmallVector<int> emittedExpressionPrecedence;
|
|
|
|
void pushExpressionPrecedence(int precedence) {
|
|
emittedExpressionPrecedence.push_back(precedence);
|
|
}
|
|
void popExpressionPrecedence() { emittedExpressionPrecedence.pop_back(); }
|
|
static int lowestPrecedence() { return 0; }
|
|
int getExpressionPrecedence() {
|
|
if (emittedExpressionPrecedence.empty())
|
|
return lowestPrecedence();
|
|
return emittedExpressionPrecedence.back();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Determine whether expression \p expressionOp should be emitted inline, i.e.
|
|
/// as part of its user. This function recommends inlining of any expressions
|
|
/// that can be inlined unless it is used by another expression, under the
|
|
/// assumption that any expression fusion/re-materialization was taken care of
|
|
/// by transformations run by the backend.
|
|
static bool shouldBeInlined(ExpressionOp expressionOp) {
|
|
// Do not inline if expression is marked as such.
|
|
if (expressionOp.getDoNotInline())
|
|
return false;
|
|
|
|
// Do not inline expressions with side effects to prevent side-effect
|
|
// reordering.
|
|
if (expressionOp.hasSideEffects())
|
|
return false;
|
|
|
|
// Do not inline expressions with multiple uses.
|
|
Value result = expressionOp.getResult();
|
|
if (!result.hasOneUse())
|
|
return false;
|
|
|
|
// Do not inline expressions used by other expressions, as any desired
|
|
// expression folding was taken care of by transformations.
|
|
Operation *user = *result.getUsers().begin();
|
|
return !user->getParentOfType<ExpressionOp>();
|
|
}
|
|
|
|
/// Emits the constant-like \p operation whose value is \p value. An empty
/// emitc::OpaqueAttr stands for "no value": it yields a bare declaration, or
/// nothing at all when the variable was already declared at the top of the
/// function.
static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation,
                                     Attribute value) {
  OpResult result = operation->getResult(0);

  auto opaqueAttr = dyn_cast<emitc::OpaqueAttr>(value);
  const bool hasNoValue = opaqueAttr && opaqueAttr.getValue().empty();

  if (emitter.shouldDeclareVariablesAtTop()) {
    // The variable was already declared when printing the FuncOp, so only an
    // assignment is needed; without a value there is nothing to assign.
    if (hasNoValue)
      return success();
    if (failed(emitter.emitVariableAssignment(result)))
      return failure();
    return emitter.emitAttribute(operation->getLoc(), value);
  }

  // Declaration only; the semicolon gets printed by the emitOperation
  // function.
  if (hasNoValue)
    return emitter.emitVariableDeclaration(result,
                                           /*trailingSemicolon=*/false);

  // Declaration followed by the initializer.
  if (failed(emitter.emitAssignPrefix(*operation)))
    return failure();
  return emitter.emitAttribute(operation->getLoc(), value);
}
|
|
|
|
/// Emits an emitc.constant by delegating to printConstantOp.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::ConstantOp constantOp) {
  return printConstantOp(emitter, constantOp.getOperation(),
                         constantOp.getValue());
}
|
|
|
|
/// Emits an emitc.variable by delegating to printConstantOp.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::VariableOp variableOp) {
  return printConstantOp(emitter, variableOp.getOperation(),
                         variableOp.getValue());
}
|
|
|
|
/// Emits an arith.constant by delegating to printConstantOp.
static LogicalResult printOperation(CppEmitter &emitter,
                                    arith::ConstantOp constantOp) {
  return printConstantOp(emitter, constantOp.getOperation(),
                         constantOp.getValue());
}
|
|
|
|
/// Emits a func.constant by delegating to printConstantOp with the op's value
/// attribute.
static LogicalResult printOperation(CppEmitter &emitter,
                                    func::ConstantOp constantOp) {
  return printConstantOp(emitter, constantOp.getOperation(),
                         constantOp.getValueAttr());
}
|
|
|
|
/// Emits an emitc.assign as an assignment to the variable defined by the
/// emitc.variable op that produces the assigned-to value.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::AssignOp assignOp) {
  Operation *varDef = assignOp.getVar().getDefiningOp();
  auto variableOp = cast<emitc::VariableOp>(varDef);

  if (failed(emitter.emitVariableAssignment(variableOp->getResult(0))))
    return failure();

  return emitter.emitOperand(assignOp.getValue());
}
|
|
|
|
/// Emits \p operation as `<assign prefix> lhs <binaryOperator> rhs`, where lhs
/// and rhs are the operation's first and second operands.
static LogicalResult printBinaryOperation(CppEmitter &emitter,
                                          Operation *operation,
                                          StringRef binaryOperator) {
  if (failed(emitter.emitAssignPrefix(*operation)) ||
      failed(emitter.emitOperand(operation->getOperand(0))))
    return failure();

  emitter.ostream() << " " << binaryOperator << " ";

  return emitter.emitOperand(operation->getOperand(1));
}
|
|
|
|
/// Emits an emitc.add as a binary `+` expression.
static LogicalResult printOperation(CppEmitter &emitter, emitc::AddOp addOp) {
  return printBinaryOperation(emitter, addOp.getOperation(), "+");
}
|
|
|
|
/// Emits an emitc.div as a binary `/` expression.
static LogicalResult printOperation(CppEmitter &emitter, emitc::DivOp divOp) {
  return printBinaryOperation(emitter, divOp.getOperation(), "/");
}
|
|
|
|
/// Emits an emitc.mul as a binary `*` expression.
static LogicalResult printOperation(CppEmitter &emitter, emitc::MulOp mulOp) {
  return printBinaryOperation(emitter, mulOp.getOperation(), "*");
}
|
|
|
|
/// Emits an emitc.rem as a binary `%` expression.
static LogicalResult printOperation(CppEmitter &emitter, emitc::RemOp remOp) {
  return printBinaryOperation(emitter, remOp.getOperation(), "%");
}
|
|
|
|
/// Emits an emitc.sub as a binary `-` expression.
static LogicalResult printOperation(CppEmitter &emitter, emitc::SubOp subOp) {
  return printBinaryOperation(emitter, subOp.getOperation(), "-");
}
|
|
|
|
/// Emits an emitc.cmp as a binary expression using the C++ operator that
/// corresponds to the op's predicate.
static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
  // Maps a predicate to its C++ spelling; a predicate matching none of the
  // enumerators yields an empty operator.
  auto spellingOf = [](emitc::CmpPredicate predicate) -> StringRef {
    switch (predicate) {
    case emitc::CmpPredicate::eq:
      return "==";
    case emitc::CmpPredicate::ne:
      return "!=";
    case emitc::CmpPredicate::lt:
      return "<";
    case emitc::CmpPredicate::le:
      return "<=";
    case emitc::CmpPredicate::gt:
      return ">";
    case emitc::CmpPredicate::ge:
      return ">=";
    case emitc::CmpPredicate::three_way:
      return "<=>";
    }
    return StringRef();
  };

  return printBinaryOperation(emitter, cmpOp.getOperation(),
                              spellingOf(cmpOp.getPredicate()));
}
|
|
|
|
/// Emits a cf.br: one assignment per successor block argument from the
/// matching branch operand, followed by a `goto` to the successor's label.
/// Fails if no label was created for the successor.
static LogicalResult printOperation(CppEmitter &emitter,
                                    cf::BranchOp branchOp) {
  raw_ostream &os = emitter.ostream();
  Block &target = *branchOp.getSuccessor();

  // Pass the branch operands to the successor through its argument variables.
  for (auto [operand, argument] :
       llvm::zip(branchOp.getOperands(), target.getArguments())) {
    os << emitter.getOrCreateName(argument) << " = "
       << emitter.getOrCreateName(operand) << ";\n";
  }

  os << "goto ";
  if (!emitter.hasBlockLabel(target))
    return branchOp.emitOpError("unable to find label for successor block");
  os << emitter.getOrCreateName(target);
  return success();
}
|
|
|
|
/// Emits a cf.cond_br as an `if (...) { ... } else { ... }` statement. Each
/// arm assigns its operands to the destination block's arguments and then
/// jumps to the destination's label. Fails if a destination has no label.
static LogicalResult printOperation(CppEmitter &emitter,
                                    cf::CondBranchOp condBranchOp) {
  raw_indented_ostream &os = emitter.ostream();
  Block &trueSuccessor = *condBranchOp.getTrueDest();
  Block &falseSuccessor = *condBranchOp.getFalseDest();

  // Emits the body of one arm: the argument assignments and the `goto`.
  auto emitArm = [&](auto operands, Block &successor) -> LogicalResult {
    for (auto [operand, argument] :
         llvm::zip(operands, successor.getArguments())) {
      os << emitter.getOrCreateName(argument) << " = "
         << emitter.getOrCreateName(operand) << ";\n";
    }

    os << "goto ";
    if (!emitter.hasBlockLabel(successor))
      return condBranchOp.emitOpError(
          "unable to find label for successor block");
    os << emitter.getOrCreateName(successor) << ";\n";
    return success();
  };

  os << "if (" << emitter.getOrCreateName(condBranchOp.getCondition())
     << ") {\n";
  os.indent();

  // If condition is true.
  if (failed(emitArm(condBranchOp.getTrueOperands(), trueSuccessor)))
    return failure();

  os.unindent() << "} else {\n";
  os.indent();

  // If condition is false.
  if (failed(emitArm(condBranchOp.getFalseOperands(), falseSuccessor)))
    return failure();

  os.unindent() << "}";
  return success();
}
|
|
|
|
/// Emits a func.call as `<assign prefix>callee(operands...)`.
static LogicalResult printOperation(CppEmitter &emitter, func::CallOp callOp) {
  Operation &op = *callOp.getOperation();
  if (failed(emitter.emitAssignPrefix(op)))
    return failure();

  raw_ostream &os = emitter.ostream();
  os << callOp.getCallee() << "(";
  if (failed(emitter.emitOperands(op)))
    return failure();
  os << ")";
  return success();
}
|
|
|
|
/// Emits an emitc.call_opaque as
/// `<assign prefix>callee<template args>(args)`. The template argument list is
/// printed only when the op carries template args; when the op has no explicit
/// args attribute, all operands are emitted in order instead.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::CallOpaqueOp callOpaqueOp) {
  raw_ostream &os = emitter.ostream();
  Operation &op = *callOpaqueOp.getOperation();

  if (failed(emitter.emitAssignPrefix(op)))
    return failure();
  os << callOpaqueOp.getCallee();

  // Emits a single (template) argument attribute.
  auto emitArg = [&](Attribute attr) -> LogicalResult {
    // Index attributes are treated specially as operand index.
    if (auto indexAttr = dyn_cast<IntegerAttr>(attr);
        indexAttr && indexAttr.getType().isIndex()) {
      int64_t idx = indexAttr.getInt();
      Value operand = op.getOperand(idx);
      // Operands defined by a literal op are accepted without a scope entry.
      auto literalDef =
          dyn_cast_if_present<LiteralOp>(operand.getDefiningOp());
      if (!literalDef && !emitter.hasValueInScope(operand))
        return op.emitOpError("operand ")
               << idx << "'s value not defined in scope";
      os << emitter.getOrCreateName(operand);
      return success();
    }
    return emitter.emitAttribute(op.getLoc(), attr);
  };

  if (auto templateArgs = callOpaqueOp.getTemplateArgs()) {
    os << "<";
    if (failed(interleaveCommaWithError(*templateArgs, os, emitArg)))
      return failure();
    os << ">";
  }

  os << "(";

  if (auto args = callOpaqueOp.getArgs()) {
    if (failed(interleaveCommaWithError(*args, os, emitArg)))
      return failure();
  } else if (failed(emitter.emitOperands(op))) {
    return failure();
  }

  os << ")";
  return success();
}
|
|
|
|
/// Emits an emitc.apply as `<assign prefix><operator><operand>`.
///
/// The operand is emitted through emitOperand, consistently with emitc.cast
/// and the binary operations, rather than by looking up a variable name
/// directly.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::ApplyOp applyOp) {
  raw_ostream &os = emitter.ostream();
  Operation &op = *applyOp.getOperation();

  if (failed(emitter.emitAssignPrefix(op)))
    return failure();
  os << applyOp.getApplicableOperator();

  return emitter.emitOperand(applyOp.getOperand());
}
|
|
|
|
/// Emits an emitc.cast as `<assign prefix>(<result type>) <operand>`.
static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
  Operation &op = *castOp.getOperation();
  if (failed(emitter.emitAssignPrefix(op)))
    return failure();

  raw_ostream &os = emitter.ostream();
  Type targetType = op.getResult(0).getType();

  os << "(";
  if (failed(emitter.emitType(op.getLoc(), targetType)))
    return failure();
  os << ") ";

  return emitter.emitOperand(castOp.getOperand());
}
|
|
|
|
/// Emits an emitc.expression as an assignment of the whole expression, unless
/// shouldBeInlined reports that it is emitted as part of its user, in which
/// case nothing is printed here.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::ExpressionOp expressionOp) {
  if (shouldBeInlined(expressionOp))
    return success();

  if (failed(emitter.emitAssignPrefix(*expressionOp.getOperation())))
    return failure();

  return emitter.emitExpression(expressionOp);
}
|
|
|
|
/// Emits an emitc.include as an `#include` directive, using angle brackets for
/// standard includes and double quotes otherwise.
static LogicalResult printOperation(CppEmitter &emitter,
                                    emitc::IncludeOp includeOp) {
  raw_ostream &os = emitter.ostream();

  const bool isStandard = includeOp.getIsStandardInclude();
  os << "#include " << (isStandard ? "<" : "\"") << includeOp.getInclude()
     << (isStandard ? ">" : "\"");

  return success();
}
|
|
|
|
/// Emits an emitc.for as
/// `for (<type> iv = <lb>; iv < <ub>; iv += <step>) { ... }`. Every op of the
/// loop region except the last one is emitted as the loop body.
static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
  raw_indented_ostream &os = emitter.ostream();
  Value inductionVar = forOp.getInductionVar();

  // Tells whether a value is an expression that will be inlined, and as such
  // should be wrapped in parentheses in order to guarantee its precedence and
  // associativity.
  auto isInlinedExpression = [&](Value value) {
    auto expressionOp =
        dyn_cast_if_present<ExpressionOp>(value.getDefiningOp());
    return expressionOp && shouldBeInlined(expressionOp);
  };

  // Initialization clause.
  os << "for (";
  if (failed(emitter.emitType(forOp.getLoc(), inductionVar.getType())))
    return failure();
  os << " " << emitter.getOrCreateName(inductionVar) << " = ";
  if (failed(emitter.emitOperand(forOp.getLowerBound())))
    return failure();

  // Condition clause.
  os << "; " << emitter.getOrCreateName(inductionVar) << " < ";
  Value upperBound = forOp.getUpperBound();
  const bool parenthesizeUpperBound = isInlinedExpression(upperBound);
  if (parenthesizeUpperBound)
    os << "(";
  if (failed(emitter.emitOperand(upperBound)))
    return failure();
  if (parenthesizeUpperBound)
    os << ")";

  // Increment clause.
  os << "; " << emitter.getOrCreateName(inductionVar) << " += ";
  if (failed(emitter.emitOperand(forOp.getStep())))
    return failure();
  os << ") {\n";
  os.indent();

  // Loop body: we skip the trailing yield op.
  auto bodyOps = forOp.getRegion().getOps();
  auto bodyEnd = bodyOps.end();
  for (auto opIt = bodyOps.begin(); std::next(opIt) != bodyEnd; ++opIt) {
    if (failed(emitter.emitOperation(*opIt, /*trailingSemicolon=*/true)))
      return failure();
  }

  os.unindent() << "}";

  return success();
}
|
|
|
|
/// Emits an emitc.if as `if (<condition>) { ... }`, followed by
/// ` else { ... }` when the else region is not empty. The last op of each
/// region is expected to be emitc::yield and is not emitted.
static LogicalResult printOperation(CppEmitter &emitter, emitc::IfOp ifOp) {
  raw_indented_ostream &os = emitter.ostream();

  // Helper function to emit all ops except the last one, expected to be
  // emitc::yield.
  auto emitAllExceptLast = [&emitter](Region &region) -> LogicalResult {
    Region::OpIterator it = region.op_begin(), end = region.op_end();
    for (; std::next(it) != end; ++it) {
      if (failed(emitter.emitOperation(*it, /*trailingSemicolon=*/true)))
        return failure();
    }
    assert(isa<emitc::YieldOp>(*it) &&
           "Expected last operation in the region to be emitc::yield");
    return success();
  };

  os << "if (";
  if (failed(emitter.emitOperand(ifOp.getCondition())))
    return failure();
  os << ") {\n";
  os.indent();
  if (failed(emitAllExceptLast(ifOp.getThenRegion())))
    return failure();
  os.unindent() << "}";

  Region &elseRegion = ifOp.getElseRegion();
  if (!elseRegion.empty()) {
    os << " else {\n";
    os.indent();
    if (failed(emitAllExceptLast(elseRegion)))
      return failure();
    os.unindent() << "}";
  }

  return success();
}
|
|
|
|
/// Emits a func.return:
/// - `return` alone when there are no operands;
/// - `return <operand>` for a single operand;
/// - `return std::make_tuple(...)` otherwise.
static LogicalResult printOperation(CppEmitter &emitter,
                                    func::ReturnOp returnOp) {
  raw_ostream &os = emitter.ostream();
  os << "return";

  const unsigned numOperands = returnOp.getNumOperands();
  if (numOperands == 0)
    return success();

  if (numOperands == 1) {
    os << " ";
    return emitter.emitOperand(returnOp.getOperand(0));
  }

  os << " std::make_tuple(";
  if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
    return failure();
  os << ")";
  return success();
}
|
|
|
|
/// Emits every operation nested directly in the module, inside a fresh emitter
/// scope and without trailing semicolons. Stops at the first failure.
static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
  CppEmitter::Scope moduleScope(emitter);

  for (Operation &nestedOp : moduleOp)
    if (failed(emitter.emitOperation(nestedOp, /*trailingSemicolon=*/false)))
      return failure();

  return success();
}
|
|
|
|
/// Emits a func.func as a C++ function definition: the signature, the optional
/// up-front variable declarations, declarations for the arguments of all
/// non-entry blocks, and finally the (labelled) blocks with their operations.
/// Functions with more than one block require variables to be declared at the
/// top.
static LogicalResult printOperation(CppEmitter &emitter,
                                    func::FuncOp functionOp) {
  Region::BlockListType &blocks = functionOp.getBlocks();
  const bool declareAtTop = emitter.shouldDeclareVariablesAtTop();

  // We need to declare variables at top if the function has multiple blocks.
  if (!declareAtTop && blocks.size() > 1)
    return functionOp.emitOpError(
        "with multiple blocks needs variables declared at top");

  CppEmitter::Scope scope(emitter);
  raw_indented_ostream &os = emitter.ostream();

  // Return type and name.
  if (failed(emitter.emitTypes(functionOp.getLoc(),
                               functionOp.getFunctionType().getResults())))
    return failure();
  os << " " << functionOp.getName() << "(";

  // Parameter list.
  auto emitParameter = [&](BlockArgument arg) -> LogicalResult {
    if (failed(emitter.emitType(functionOp.getLoc(), arg.getType())))
      return failure();
    os << " " << emitter.getOrCreateName(arg);
    return success();
  };
  if (failed(interleaveCommaWithError(functionOp.getArguments(), os,
                                      emitParameter)))
    return failure();
  os << ") {\n";
  os.indent();

  if (declareAtTop) {
    // Declare all variables that hold op results including those from nested
    // regions.
    WalkResult walkResult =
        functionOp.walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
          // Literals, ops nested in an expression and inlined expressions are
          // skipped.
          auto expressionOp = dyn_cast<emitc::ExpressionOp>(op);
          if (isa<emitc::LiteralOp>(op) ||
              isa<emitc::ExpressionOp>(op->getParentOp()) ||
              (expressionOp && shouldBeInlined(expressionOp)))
            return WalkResult::skip();

          for (OpResult opResult : op->getResults()) {
            if (failed(emitter.emitVariableDeclaration(
                    opResult, /*trailingSemicolon=*/true)))
              return WalkResult(
                  op->emitError("unable to declare result variable for op"));
          }
          return WalkResult::advance();
        });
    if (walkResult.wasInterrupted())
      return failure();
  }

  // Create label names for basic blocks.
  for (Block &block : blocks)
    emitter.getOrCreateName(block);

  // Declare variables for basic block arguments.
  for (Block &block : llvm::drop_begin(blocks)) {
    for (BlockArgument &arg : block.getArguments()) {
      if (emitter.hasValueInScope(arg))
        return functionOp.emitOpError(" block argument #")
               << arg.getArgNumber() << " is out of scope";
      if (failed(
              emitter.emitType(block.getParentOp()->getLoc(), arg.getType())))
        return failure();
      os << " " << emitter.getOrCreateName(arg) << ";\n";
    }
  }

  for (Block &block : blocks) {
    // Only print a label if the block has predecessors.
    if (!block.hasNoPredecessors() && failed(emitter.emitLabel(block)))
      return failure();

    for (Operation &op : block.getOperations()) {
      // When generating code for an emitc.if or cf.cond_br op no semicolon
      // needs to be printed after the closing brace.
      // When generating code for an emitc.for op, printing a trailing semicolon
      // is handled within the printOperation function.
      const bool needsSemicolon =
          !isa<cf::CondBranchOp, emitc::ForOp, emitc::IfOp, emitc::LiteralOp>(
              op);

      if (failed(
              emitter.emitOperation(op, /*trailingSemicolon=*/needsSemicolon)))
        return failure();
    }
  }

  os.unindent() << "}\n";
  return success();
}
|
|
|
|
/// Creates an emitter writing to \p os. Both the value and the label counter
/// stacks start with a single zero entry.
CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
    : os(os), declareVariablesAtTop(declareVariablesAtTop) {
  valueInScopeCount.emplace(0);
  labelInScopeCount.emplace(0);
}
|
|
|
|
/// Return the existing or a new name for a Value.
|
|
StringRef CppEmitter::getOrCreateName(Value val) {
|
|
if (auto literal = dyn_cast_if_present<emitc::LiteralOp>(val.getDefiningOp()))
|
|
return literal.getValue();
|
|
if (!valueMapper.count(val))
|
|
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
|
|
return *valueMapper.begin(val);
|
|
}
|
|
|
|
/// Return the existing or a new label for a Block.
|
|
StringRef CppEmitter::getOrCreateName(Block &block) {
|
|
if (!blockMapper.count(&block))
|
|
blockMapper.insert(&block, formatv("label{0}", ++labelInScopeCount.top()));
|
|
return *blockMapper.begin(&block);
|
|
}
|
|
|
|
/// Only explicitly unsigned integers map to an unsigned C++ integer; signless
/// and signed integers do not.
bool CppEmitter::shouldMapToUnsigned(IntegerType::SignednessSemantics val) {
  switch (val) {
  case IntegerType::Unsigned:
    return true;
  case IntegerType::Signless:
  case IntegerType::Signed:
    return false;
  }
  llvm_unreachable("Unexpected IntegerType::SignednessSemantics");
}
|
|
|
|
/// Returns whether a name is registered for \p val in the value mapper.
bool CppEmitter::hasValueInScope(Value val) {
  return valueMapper.count(val) != 0;
}
|
|
|
|
/// Returns whether a label is registered for \p block in the block mapper.
bool CppEmitter::hasBlockLabel(Block &block) {
  return blockMapper.count(&block) != 0;
}
|
|
|
|
/// Emits \p attr as C++ source text. Handled are float and integer attributes
/// (scalar and dense, the latter as brace-enclosed lists), emitc opaque
/// attributes, symbol references and type attributes. Any other attribute is
/// reported as an error at \p loc.
LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
  // Prints an integer in base 10; 1-bit values are printed as true/false.
  auto printInt = [&](const APInt &val, bool isUnsigned) {
    if (val.getBitWidth() == 1) {
      os << (val.getBoolValue() ? "true" : "false");
      return;
    }
    SmallString<128> strValue;
    val.toString(strValue, 10, !isUnsigned, false);
    os << strValue;
  };

  // Prints a floating point value; non-finite values use the NAN/INFINITY
  // macros.
  auto printFloat = [&](const APFloat &val) {
    if (val.isFinite()) {
      SmallString<128> strValue;
      // Use default values of toString except don't truncate zeros.
      val.toString(strValue, 0, 0, false);
      // Single and double precision values get an explicit cast prefix.
      switch (llvm::APFloatBase::SemanticsToEnum(val.getSemantics())) {
      case llvm::APFloatBase::S_IEEEsingle:
        os << "(float)";
        break;
      case llvm::APFloatBase::S_IEEEdouble:
        os << "(double)";
        break;
      default:
        break;
      }
      os << strValue;
      return;
    }
    if (val.isNaN()) {
      os << "NAN";
      return;
    }
    if (val.isInfinity()) {
      if (val.isNegative())
        os << "-";
      os << "INFINITY";
    }
  };

  // Print floating point attributes.
  if (auto fAttr = dyn_cast<FloatAttr>(attr)) {
    printFloat(fAttr.getValue());
    return success();
  }
  if (auto denseFloats = dyn_cast<DenseFPElementsAttr>(attr)) {
    os << '{';
    interleaveComma(denseFloats, os,
                    [&](const APFloat &val) { printFloat(val); });
    os << '}';
    return success();
  }

  // Print integer attributes.
  if (auto iAttr = dyn_cast<IntegerAttr>(attr)) {
    Type attrType = iAttr.getType();
    if (auto iType = dyn_cast<IntegerType>(attrType)) {
      printInt(iAttr.getValue(), shouldMapToUnsigned(iType.getSignedness()));
      return success();
    }
    if (isa<IndexType>(attrType)) {
      printInt(iAttr.getValue(), false);
      return success();
    }
  }
  if (auto denseInts = dyn_cast<DenseIntElementsAttr>(attr)) {
    Type elementType = cast<TensorType>(denseInts.getType()).getElementType();
    auto iType = dyn_cast<IntegerType>(elementType);
    if (iType || isa<IndexType>(elementType)) {
      // Index elements are always printed as signed values.
      const bool isUnsigned =
          iType && shouldMapToUnsigned(iType.getSignedness());
      os << '{';
      interleaveComma(denseInts, os,
                      [&](const APInt &val) { printInt(val, isUnsigned); });
      os << '}';
      return success();
    }
  }

  // Print opaque attributes.
  if (auto oAttr = dyn_cast<emitc::OpaqueAttr>(attr)) {
    os << oAttr.getValue();
    return success();
  }

  // Print symbolic reference attributes.
  if (auto sAttr = dyn_cast<SymbolRefAttr>(attr)) {
    if (sAttr.getNestedReferences().size() > 1)
      return emitError(loc, "attribute has more than 1 nested reference");
    os << sAttr.getRootReference().getValue();
    return success();
  }

  // Print type attributes.
  if (auto type = dyn_cast<TypeAttr>(attr))
    return emitType(loc, type.getValue());

  return emitError(loc, "cannot emit attribute: ") << attr;
}
|
|
|
|
/// Emits the body of `expressionOp` by emitting its root operation without a
/// trailing semicolon. The precedence stack is asserted to be empty on entry
/// and again once the root operation has been emitted.
LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) {
  assert(emittedExpressionPrecedence.empty() &&
         "Expected precedence stack to be empty");

  // Record which expression is being emitted and seed the precedence stack
  // with the precedence of its root operation.
  Operation *root = expressionOp.getRootOp();
  emittedExpression = expressionOp;
  pushExpressionPrecedence(getOperatorPrecedence(root));

  LogicalResult rootStatus = emitOperation(*root, /*trailingSemicolon=*/false);
  if (failed(rootStatus))
    return failure();

  // Undo the bookkeeping done above.
  popExpressionPrecedence();
  assert(emittedExpressionPrecedence.empty() &&
         "Expected precedence stack to be empty");
  emittedExpression = nullptr;

  return success();
}
|
|
|
|
/// Emits `value` as an operand. Three forms are possible:
///  - the value belongs to the expression currently being emitted: its
///    defining operation is emitted in place, parenthesized when its
///    precedence is lower than the current expression precedence;
///  - the value is produced by an `ExpressionOp` that should be inlined: the
///    whole expression is emitted in place;
///  - otherwise the name of the value is emitted, which fails if the value is
///    neither produced by a `LiteralOp` nor in scope.
LogicalResult CppEmitter::emitOperand(Value value) {
  if (!isPartOfCurrentExpression(value)) {
    // Inlinable expressions are emitted in place rather than by name.
    if (auto expressionOp =
            dyn_cast_if_present<ExpressionOp>(value.getDefiningOp()))
      if (shouldBeInlined(expressionOp))
        return emitExpression(expressionOp);

    // Only literals may be referenced without being in scope.
    auto literalOp = dyn_cast_if_present<LiteralOp>(value.getDefiningOp());
    if (literalOp || hasValueInScope(value)) {
      os << getOrCreateName(value);
      return success();
    }
    return failure();
  }

  Operation *def = value.getDefiningOp();
  assert(def && "Expected operand to be defined by an operation");

  const int operandPrecedence = getOperatorPrecedence(def);
  const bool needsParentheses = operandPrecedence < getExpressionPrecedence();

  // Inside parentheses the operand is emitted under the lowest precedence;
  // otherwise its own precedence becomes the current one.
  if (needsParentheses)
    os << "(";
  pushExpressionPrecedence(needsParentheses ? lowestPrecedence()
                                            : operandPrecedence);

  if (failed(emitOperation(*def, /*trailingSemicolon=*/false)))
    return failure();

  if (needsParentheses)
    os << ")";
  popExpressionPrecedence();

  return success();
}
|
|
|
|
/// Emits all operands of `op` as a comma-separated list, stopping at the
/// first operand that fails to emit.
LogicalResult CppEmitter::emitOperands(Operation &op) {
  auto emitSingleOperand = [&](Value operand) -> LogicalResult {
    // While an expression is being emitted, the lowest precedence is pushed
    // for the duration of each operand's emission and popped afterwards.
    if (getEmittedExpression())
      pushExpressionPrecedence(lowestPrecedence());

    if (failed(emitOperand(operand)))
      return failure();

    if (getEmittedExpression())
      popExpressionPrecedence();

    return success();
  };

  return interleaveCommaWithError(op.getOperands(), os, emitSingleOperand);
}
|
|
|
|
/// Emits the operands of `op` followed by its attributes as one
/// comma-separated list. Each attribute is preceded by a `/* name */` comment.
/// Attributes whose name is listed in `exclude` are skipped entirely: they
/// contribute neither a value nor a separator to the output.
///
/// Returns failure as soon as an operand or attribute cannot be emitted.
LogicalResult
CppEmitter::emitOperandsAndAttributes(Operation &op,
                                      ArrayRef<StringRef> exclude) {
  if (failed(emitOperands(op)))
    return failure();

  // A separator is due before an attribute as soon as anything precedes it in
  // the list: an operand or a previously emitted attribute. Filtering is done
  // before the separator decision so that excluded attributes never produce a
  // stray leading, doubled or trailing ", ".
  bool needsSeparator = op.getNumOperands() > 0;
  for (NamedAttribute attr : op.getAttrs()) {
    if (llvm::is_contained(exclude, attr.getName().strref()))
      continue;

    if (needsSeparator)
      os << ", ";
    needsSeparator = true;

    os << "/* " << attr.getName().getValue() << " */";
    if (failed(emitAttribute(op.getLoc(), attr.getValue())))
      return failure();
  }

  return success();
}
|
|
|
|
/// Emits `<name> = ` for `result`. The result must already be in scope;
/// otherwise an op error is reported on its defining operation.
LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
  if (hasValueInScope(result)) {
    os << getOrCreateName(result) << " = ";
    return success();
  }

  return result.getDefiningOp()->emitOpError(
      "result variable for the operation has not been declared");
}
|
|
|
|
/// Emits `<type> <name>` for `result`, followed by `;` and a newline when
/// `trailingSemicolon` is set. Reports an error on the defining operation if
/// the result is already in scope, and fails if its type cannot be emitted.
LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
                                                  bool trailingSemicolon) {
  if (hasValueInScope(result))
    return result.getDefiningOp()->emitError(
        "result variable for the operation already declared");

  Location declLoc = result.getOwner()->getLoc();
  if (failed(emitType(declLoc, result.getType())))
    return failure();

  os << " " << getOrCreateName(result);
  if (trailingSemicolon)
    os << ";\n";

  return success();
}
|
|
|
|
/// Emits the left-hand side that receives the results of `op`:
///  - nothing when `op` is emitted as part of an expression or has no result;
///  - `<name> = ` or `<type> <name> = ` for a single result, depending on
///    whether variables are declared at the top;
///  - `std::tie(<names>) = ` for several results, preceded by one declaration
///    per result when variables are not declared at the top.
LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
  // Operations inside an expression do not assign to a variable.
  if (getEmittedExpression())
    return success();

  const unsigned resultCount = op.getNumResults();
  if (resultCount == 0)
    return success();

  if (resultCount == 1) {
    OpResult soleResult = op.getResult(0);
    if (shouldDeclareVariablesAtTop())
      return emitVariableAssignment(soleResult);

    if (failed(emitVariableDeclaration(soleResult,
                                       /*trailingSemicolon=*/false)))
      return failure();
    os << " = ";
    return success();
  }

  // Several results: declare each of them first unless that already happened
  // at the top, then bind them all at once.
  if (!shouldDeclareVariablesAtTop()) {
    for (OpResult declared : op.getResults())
      if (failed(emitVariableDeclaration(declared,
                                         /*trailingSemicolon=*/true)))
        return failure();
  }

  os << "std::tie(";
  interleaveComma(op.getResults(), os,
                  [&](Value tied) { os << getOrCreateName(tied); });
  os << ") = ";

  return success();
}
|
|
|
|
/// Emits `<label>:` followed by a newline for `block`. Reports an error on the
/// parent operation when no label is known for the block.
LogicalResult CppEmitter::emitLabel(Block &block) {
  if (hasBlockLabel(block)) {
    // FIXME: Add feature in `raw_indented_ostream` to ignore indent for block
    // label instead of using `getOStream`.
    os.getOStream() << getOrCreateName(block) << ":\n";
    return success();
  }

  return block.getParentOp()->emitError("label for block not found");
}
|
|
|
|
/// Dispatches `op` to the matching `printOperation` overload and terminates
/// the emitted statement with a newline, preceded by `;` when
/// `trailingSemicolon` is set. No terminator is written for literals, for
/// operations emitted as part of an expression, or for expression ops that
/// are inlined. Operations without a printer produce an op error.
LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
  // Every supported op is forwarded to the `printOperation` overload selected
  // by its concrete type.
  auto printTyped = [&](auto typedOp) {
    return printOperation(*this, typedOp);
  };

  LogicalResult status =
      llvm::TypeSwitch<Operation *, LogicalResult>(&op)
          // Builtin ops.
          .Case<ModuleOp>(printTyped)
          // CF ops.
          .Case<cf::BranchOp, cf::CondBranchOp>(printTyped)
          // EmitC ops.
          .Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp,
                emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
                emitc::ConstantOp, emitc::DivOp, emitc::ExpressionOp,
                emitc::ForOp, emitc::IfOp, emitc::IncludeOp, emitc::MulOp,
                emitc::RemOp, emitc::SubOp, emitc::VariableOp>(printTyped)
          // Func ops.
          .Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(
              printTyped)
          // Arithmetic ops.
          .Case<arith::ConstantOp>(printTyped)
          // Literals emit nothing on their own.
          .Case<emitc::LiteralOp>([&](auto) { return success(); })
          .Default([&](Operation *) {
            return op.emitOpError("unable to find printer for op");
          });

  if (failed(status))
    return failure();

  // Nothing was written for a literal, so there is no statement to terminate.
  if (isa<emitc::LiteralOp>(op))
    return success();

  // Operations inside an expression are not statements of their own.
  if (getEmittedExpression())
    return success();

  // Neither are expression ops that are inlined.
  if (auto expressionOp = dyn_cast<emitc::ExpressionOp>(&op))
    if (shouldBeInlined(expressionOp))
      return success();

  os << (trailingSemicolon ? ";\n" : "\n");

  return success();
}
|
|
|
|
/// Emits the C++ spelling of `type`:
///  - integers of width 1 as `bool`, of width 8/16/32/64 as `intN_t` or
///    `uintN_t`;
///  - floats of width 32 and 64 as `float` and `double`;
///  - `index` as `size_t`;
///  - ranked, statically shaped tensors as `Tensor<elem, dims...>`;
///  - tuples as `std::tuple<...>`;
///  - EmitC opaque types verbatim and EmitC pointers as `<pointee>*`.
/// Anything else is reported as an error at `loc`.
LogicalResult CppEmitter::emitType(Location loc, Type type) {
  if (auto intType = dyn_cast<IntegerType>(type)) {
    const unsigned bitWidth = intType.getWidth();
    if (bitWidth == 1) {
      os << "bool";
      return success();
    }
    if (bitWidth == 8 || bitWidth == 16 || bitWidth == 32 || bitWidth == 64) {
      const bool isUnsigned = shouldMapToUnsigned(intType.getSignedness());
      os << (isUnsigned ? "uint" : "int") << bitWidth << "_t";
      return success();
    }
    return emitError(loc, "cannot emit integer type ") << type;
  }

  if (auto floatType = dyn_cast<FloatType>(type)) {
    const unsigned bitWidth = floatType.getWidth();
    if (bitWidth == 32) {
      os << "float";
      return success();
    }
    if (bitWidth == 64) {
      os << "double";
      return success();
    }
    return emitError(loc, "cannot emit float type ") << type;
  }

  if (isa<IndexType>(type)) {
    os << "size_t";
    return success();
  }

  if (auto tensorType = dyn_cast<TensorType>(type)) {
    if (!tensorType.hasRank())
      return emitError(loc, "cannot emit unranked tensor type");
    if (!tensorType.hasStaticShape())
      return emitError(loc, "cannot emit tensor type with non static shape");

    // The element type comes first, followed by one argument per dimension.
    os << "Tensor<";
    if (failed(emitType(loc, tensorType.getElementType())))
      return failure();
    for (auto extent : tensorType.getShape())
      os << ", " << extent;
    os << ">";
    return success();
  }

  if (auto tupleType = dyn_cast<TupleType>(type))
    return emitTupleType(loc, tupleType.getTypes());

  if (auto opaqueType = dyn_cast<emitc::OpaqueType>(type)) {
    os << opaqueType.getValue();
    return success();
  }

  if (auto pointerType = dyn_cast<emitc::PointerType>(type)) {
    if (failed(emitType(loc, pointerType.getPointee())))
      return failure();
    os << "*";
    return success();
  }

  return emitError(loc, "cannot emit type ") << type;
}
|
|
|
|
/// Emits a list of types as a single C++ type: `void` for an empty list, the
/// type itself for a single entry, and a `std::tuple` otherwise.
LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
  if (types.empty()) {
    os << "void";
    return success();
  }

  if (types.size() == 1)
    return emitType(loc, types.front());

  return emitTupleType(loc, types);
}
|
|
|
|
/// Emits `std::tuple<...>` with the comma-separated spelling of `types` as
/// template arguments. Fails if any element type cannot be emitted.
LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
  auto emitElementType = [&](Type elementType) {
    return emitType(loc, elementType);
  };

  os << "std::tuple<";
  if (failed(interleaveCommaWithError(types, os, emitElementType)))
    return failure();
  os << ">";

  return success();
}
|
|
|
|
/// Translates `op` to C++ by emitting it to `os` through a `CppEmitter`
/// configured with `declareVariablesAtTop`. No trailing semicolon is requested
/// for `op` itself.
LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
                                    bool declareVariablesAtTop) {
  CppEmitter cppEmitter(os, declareVariablesAtTop);
  LogicalResult translation =
      cppEmitter.emitOperation(*op, /*trailingSemicolon=*/false);
  return translation;
}
|