//===- DialectSymbolParser.cpp - MLIR Dialect Symbol Parser --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the parser for dialect symbols, such as extended
// attributes and types.
//
//===----------------------------------------------------------------------===//

#include "AsmParserImpl.h"
#include "Parser.h"
#include "mlir/AsmParser/AsmParserState.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>

using namespace mlir;
using namespace mlir::detail;
using llvm::MemoryBuffer;
using llvm::SourceMgr;

namespace {
|
|
/// This class provides the main implementation of the DialectAsmParser that
|
|
/// allows for dialects to parse attributes and types. This allows for dialect
|
|
/// hooking into the main MLIR parsing logic.
|
|
class CustomDialectAsmParser : public AsmParserImpl<DialectAsmParser> {
|
|
public:
|
|
CustomDialectAsmParser(StringRef fullSpec, Parser &parser)
|
|
: AsmParserImpl<DialectAsmParser>(parser.getToken().getLoc(), parser),
|
|
fullSpec(fullSpec) {}
|
|
~CustomDialectAsmParser() override = default;
|
|
|
|
/// Returns the full specification of the symbol being parsed. This allows
|
|
/// for using a separate parser if necessary.
|
|
StringRef getFullSymbolSpec() const override { return fullSpec; }
|
|
|
|
private:
|
|
/// The full symbol specification.
|
|
StringRef fullSpec;
|
|
};
|
|
} // namespace
|
|
|
|
///
|
|
/// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
|
|
/// pretty-dialect-sym-contents ::= pretty-dialect-sym-body
|
|
/// | '(' pretty-dialect-sym-contents+ ')'
|
|
/// | '[' pretty-dialect-sym-contents+ ']'
|
|
/// | '{' pretty-dialect-sym-contents+ '}'
|
|
/// | '[^[<({>\])}\0]+'
|
|
///
|
|
ParseResult Parser::parseDialectSymbolBody(StringRef &body,
|
|
bool &isCodeCompletion) {
|
|
// Symbol bodies are a relatively unstructured format that contains a series
|
|
// of properly nested punctuation, with anything else in the middle. Scan
|
|
// ahead to find it and consume it if successful, otherwise emit an error.
|
|
const char *curPtr = getTokenSpelling().data();
|
|
|
|
// Scan over the nested punctuation, bailing out on error and consuming until
|
|
// we find the end. We know that we're currently looking at the '<', so we can
|
|
// go until we find the matching '>' character.
|
|
assert(*curPtr == '<');
|
|
SmallVector<char, 8> nestedPunctuation;
|
|
const char *codeCompleteLoc = state.lex.getCodeCompleteLoc();
|
|
|
|
// Functor used to emit an unbalanced punctuation error.
|
|
auto emitPunctError = [&] {
|
|
return emitError() << "unbalanced '" << nestedPunctuation.back()
|
|
<< "' character in pretty dialect name";
|
|
};
|
|
// Functor used to check for unbalanced punctuation.
|
|
auto checkNestedPunctuation = [&](char expectedToken) -> ParseResult {
|
|
if (nestedPunctuation.back() != expectedToken)
|
|
return emitPunctError();
|
|
nestedPunctuation.pop_back();
|
|
return success();
|
|
};
|
|
do {
|
|
// Handle code completions, which may appear in the middle of the symbol
|
|
// body.
|
|
if (curPtr == codeCompleteLoc) {
|
|
isCodeCompletion = true;
|
|
nestedPunctuation.clear();
|
|
break;
|
|
}
|
|
|
|
char c = *curPtr++;
|
|
switch (c) {
|
|
case '\0':
|
|
// This also handles the EOF case.
|
|
if (!nestedPunctuation.empty())
|
|
return emitPunctError();
|
|
return emitError("unexpected nul or EOF in pretty dialect name");
|
|
case '<':
|
|
case '[':
|
|
case '(':
|
|
case '{':
|
|
nestedPunctuation.push_back(c);
|
|
continue;
|
|
|
|
case '-':
|
|
// The sequence `->` is treated as special token.
|
|
if (*curPtr == '>')
|
|
++curPtr;
|
|
continue;
|
|
|
|
case '>':
|
|
if (failed(checkNestedPunctuation('<')))
|
|
return failure();
|
|
break;
|
|
case ']':
|
|
if (failed(checkNestedPunctuation('[')))
|
|
return failure();
|
|
break;
|
|
case ')':
|
|
if (failed(checkNestedPunctuation('(')))
|
|
return failure();
|
|
break;
|
|
case '}':
|
|
if (failed(checkNestedPunctuation('{')))
|
|
return failure();
|
|
break;
|
|
case '"': {
|
|
// Dispatch to the lexer to lex past strings.
|
|
resetToken(curPtr - 1);
|
|
curPtr = state.curToken.getEndLoc().getPointer();
|
|
|
|
// Handle code completions, which may appear in the middle of the symbol
|
|
// body.
|
|
if (state.curToken.isCodeCompletion()) {
|
|
isCodeCompletion = true;
|
|
nestedPunctuation.clear();
|
|
break;
|
|
}
|
|
|
|
// Otherwise, ensure this token was actually a string.
|
|
if (state.curToken.isNot(Token::string))
|
|
return failure();
|
|
break;
|
|
}
|
|
|
|
default:
|
|
continue;
|
|
}
|
|
} while (!nestedPunctuation.empty());
|
|
|
|
// Ok, we succeeded, remember where we stopped, reset the lexer to know it is
|
|
// consuming all this stuff, and return.
|
|
resetToken(curPtr);
|
|
|
|
unsigned length = curPtr - body.begin();
|
|
body = StringRef(body.data(), length);
|
|
return success();
|
|
}
|
|
|
|
/// Parse an extended dialect symbol.
|
|
template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
|
|
static Symbol parseExtendedSymbol(Parser &p, AsmParserState *asmState,
|
|
SymbolAliasMap &aliases,
|
|
CreateFn &&createSymbol) {
|
|
Token tok = p.getToken();
|
|
|
|
// Handle code completion of the extended symbol.
|
|
StringRef identifier = tok.getSpelling().drop_front();
|
|
if (tok.isCodeCompletion() && identifier.empty())
|
|
return p.codeCompleteDialectSymbol(aliases);
|
|
|
|
// Parse the dialect namespace.
|
|
SMRange range = p.getToken().getLocRange();
|
|
SMLoc loc = p.getToken().getLoc();
|
|
p.consumeToken();
|
|
|
|
// Check to see if this is a pretty name.
|
|
auto [dialectName, symbolData] = identifier.split('.');
|
|
bool isPrettyName = !symbolData.empty() || identifier.back() == '.';
|
|
|
|
// Check to see if the symbol has trailing data, i.e. has an immediately
|
|
// following '<'.
|
|
bool hasTrailingData =
|
|
p.getToken().is(Token::less) &&
|
|
identifier.bytes_end() == p.getTokenSpelling().bytes_begin();
|
|
|
|
// If there is no '<' token following this, and if the typename contains no
|
|
// dot, then we are parsing a symbol alias.
|
|
if (!hasTrailingData && !isPrettyName) {
|
|
// Check for an alias for this type.
|
|
auto aliasIt = aliases.find(identifier);
|
|
if (aliasIt == aliases.end())
|
|
return (p.emitWrongTokenError("undefined symbol alias id '" + identifier +
|
|
"'"),
|
|
nullptr);
|
|
if (asmState) {
|
|
if constexpr (std::is_same_v<Symbol, Type>)
|
|
asmState->addTypeAliasUses(identifier, range);
|
|
else
|
|
asmState->addAttrAliasUses(identifier, range);
|
|
}
|
|
return aliasIt->second;
|
|
}
|
|
|
|
// If this isn't an alias, we are parsing a dialect-specific symbol. If the
|
|
// name contains a dot, then this is the "pretty" form. If not, it is the
|
|
// verbose form that looks like <...>.
|
|
if (!isPrettyName) {
|
|
// Point the symbol data to the end of the dialect name to start.
|
|
symbolData = StringRef(dialectName.end(), 0);
|
|
|
|
// Parse the body of the symbol.
|
|
bool isCodeCompletion = false;
|
|
if (p.parseDialectSymbolBody(symbolData, isCodeCompletion))
|
|
return nullptr;
|
|
symbolData = symbolData.drop_front();
|
|
|
|
// If the body contained a code completion it won't have the trailing `>`
|
|
// token, so don't drop it.
|
|
if (!isCodeCompletion)
|
|
symbolData = symbolData.drop_back();
|
|
} else {
|
|
loc = SMLoc::getFromPointer(symbolData.data());
|
|
|
|
// If the dialect's symbol is followed immediately by a <, then lex the body
|
|
// of it into prettyName.
|
|
if (hasTrailingData && p.parseDialectSymbolBody(symbolData))
|
|
return nullptr;
|
|
}
|
|
|
|
return createSymbol(dialectName, symbolData, loc);
|
|
}
|
|
|
|
/// Parse an extended attribute.
|
|
///
|
|
/// extended-attribute ::= (dialect-attribute | attribute-alias)
|
|
/// dialect-attribute ::= `#` dialect-namespace `<` attr-data `>`
|
|
/// (`:` type)?
|
|
/// | `#` alias-name pretty-dialect-sym-body? (`:` type)?
|
|
/// attribute-alias ::= `#` alias-name
|
|
///
|
|
Attribute Parser::parseExtendedAttr(Type type) {
|
|
MLIRContext *ctx = getContext();
|
|
Attribute attr = parseExtendedSymbol<Attribute>(
|
|
*this, state.asmState, state.symbols.attributeAliasDefinitions,
|
|
[&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute {
|
|
// Parse an optional trailing colon type.
|
|
Type attrType = type;
|
|
if (consumeIf(Token::colon) && !(attrType = parseType()))
|
|
return Attribute();
|
|
|
|
// If we found a registered dialect, then ask it to parse the attribute.
|
|
if (Dialect *dialect =
|
|
builder.getContext()->getOrLoadDialect(dialectName)) {
|
|
// Temporarily reset the lexer to let the dialect parse the attribute.
|
|
const char *curLexerPos = getToken().getLoc().getPointer();
|
|
resetToken(symbolData.data());
|
|
|
|
// Parse the attribute.
|
|
CustomDialectAsmParser customParser(symbolData, *this);
|
|
Attribute attr = dialect->parseAttribute(customParser, attrType);
|
|
resetToken(curLexerPos);
|
|
return attr;
|
|
}
|
|
|
|
// Otherwise, form a new opaque attribute.
|
|
return OpaqueAttr::getChecked(
|
|
[&] { return emitError(loc); }, StringAttr::get(ctx, dialectName),
|
|
symbolData, attrType ? attrType : NoneType::get(ctx));
|
|
});
|
|
|
|
// Ensure that the attribute has the same type as requested.
|
|
auto typedAttr = dyn_cast_or_null<TypedAttr>(attr);
|
|
if (type && typedAttr && typedAttr.getType() != type) {
|
|
emitError("attribute type different than expected: expected ")
|
|
<< type << ", but got " << typedAttr.getType();
|
|
return nullptr;
|
|
}
|
|
return attr;
|
|
}
|
|
|
|
/// Parse an extended type.
|
|
///
|
|
/// extended-type ::= (dialect-type | type-alias)
|
|
/// dialect-type ::= `!` dialect-namespace `<` `"` type-data `"` `>`
|
|
/// dialect-type ::= `!` alias-name pretty-dialect-attribute-body?
|
|
/// type-alias ::= `!` alias-name
|
|
///
|
|
Type Parser::parseExtendedType() {
|
|
MLIRContext *ctx = getContext();
|
|
return parseExtendedSymbol<Type>(
|
|
*this, state.asmState, state.symbols.typeAliasDefinitions,
|
|
[&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type {
|
|
// If we found a registered dialect, then ask it to parse the type.
|
|
if (auto *dialect = ctx->getOrLoadDialect(dialectName)) {
|
|
// Temporarily reset the lexer to let the dialect parse the type.
|
|
const char *curLexerPos = getToken().getLoc().getPointer();
|
|
resetToken(symbolData.data());
|
|
|
|
// Parse the type.
|
|
CustomDialectAsmParser customParser(symbolData, *this);
|
|
Type type = dialect->parseType(customParser);
|
|
resetToken(curLexerPos);
|
|
return type;
|
|
}
|
|
|
|
// Otherwise, form a new opaque type.
|
|
return OpaqueType::getChecked([&] { return emitError(loc); },
|
|
StringAttr::get(ctx, dialectName),
|
|
symbolData);
|
|
});
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
// mlir::parseAttribute/parseType
//===----------------------------------------------------------------------===//

/// Parses a symbol, of type 'T', and returns it if parsing was successful. If
|
|
/// parsing failed, nullptr is returned.
|
|
template <typename T, typename ParserFn>
|
|
static T parseSymbol(StringRef inputStr, MLIRContext *context,
|
|
size_t *numReadOut, bool isKnownNullTerminated,
|
|
ParserFn &&parserFn) {
|
|
// Set the buffer name to the string being parsed, so that it appears in error
|
|
// diagnostics.
|
|
auto memBuffer =
|
|
isKnownNullTerminated
|
|
? MemoryBuffer::getMemBuffer(inputStr,
|
|
/*BufferName=*/inputStr)
|
|
: MemoryBuffer::getMemBufferCopy(inputStr, /*BufferName=*/inputStr);
|
|
SourceMgr sourceMgr;
|
|
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
|
|
SymbolState aliasState;
|
|
ParserConfig config(context);
|
|
ParserState state(sourceMgr, config, aliasState, /*asmState=*/nullptr,
|
|
/*codeCompleteContext=*/nullptr);
|
|
Parser parser(state);
|
|
|
|
Token startTok = parser.getToken();
|
|
T symbol = parserFn(parser);
|
|
if (!symbol)
|
|
return T();
|
|
|
|
// Provide the number of bytes that were read.
|
|
Token endTok = parser.getToken();
|
|
size_t numRead =
|
|
endTok.getLoc().getPointer() - startTok.getLoc().getPointer();
|
|
if (numReadOut) {
|
|
*numReadOut = numRead;
|
|
} else if (numRead != inputStr.size()) {
|
|
parser.emitError(endTok.getLoc()) << "found trailing characters: '"
|
|
<< inputStr.drop_front(numRead) << "'";
|
|
return T();
|
|
}
|
|
return symbol;
|
|
}
|
|
|
|
/// Parses an attribute from `attrStr`, forwarding the caller's expected
/// `type` to the attribute parser. See parseSymbol for the meaning of
/// `numRead` and `isKnownNullTerminated`.
Attribute mlir::parseAttribute(StringRef attrStr, MLIRContext *context,
                               Type type, size_t *numRead,
                               bool isKnownNullTerminated) {
  auto parseWithType = [type](Parser &parser) -> Attribute {
    return parser.parseAttribute(type);
  };
  return parseSymbol<Attribute>(attrStr, context, numRead,
                                isKnownNullTerminated, parseWithType);
}

/// Parses a type from `typeStr`. See parseSymbol for the meaning of `numRead`
/// and `isKnownNullTerminated`.
Type mlir::parseType(StringRef typeStr, MLIRContext *context, size_t *numRead,
                     bool isKnownNullTerminated) {
  auto parseOneType = [](Parser &parser) -> Type { return parser.parseType(); };
  return parseSymbol<Type>(typeStr, context, numRead, isKnownNullTerminated,
                           parseOneType);
}