Improve diagnostics and type checking

This commit is contained in:
Sam Vervaeck 2022-08-21 20:56:58 +02:00
parent cd1e20d460
commit 311f1d228b
6 changed files with 341 additions and 90 deletions

View file

@ -1,6 +1,7 @@
#pragma once
#include "bolt/Diagnostics.hpp"
#include "zen/config.hpp"
#include "bolt/ByteString.hpp"
@ -55,9 +56,10 @@ namespace bolt {
const size_t Id;
std::vector<Type*> Args;
ByteString DisplayName;
inline TCon(const size_t Id, std::vector<Type*> Args ):
Type(TypeKind::Con), Id(Id), Args(Args) {}
inline TCon(const size_t Id, std::vector<Type*> Args, ByteString DisplayName):
Type(TypeKind::Con), Id(Id), Args(Args), DisplayName(DisplayName) {}
};
@ -107,9 +109,20 @@ namespace bolt {
public:
TVSet TVs;
std::vector<Constraint*> Constriants;
std::vector<Constraint*> Constraints;
Type* Type;
inline Forall(class Type* Type):
Type(Type) {}
inline Forall(
TVSet TVs,
std::vector<Constraint*> Constraints,
class Type* Type
): TVs(TVs),
Constraints(Constraints),
Type(Type) {}
};
enum class SchemeKind : unsigned char {
@ -256,12 +269,16 @@ namespace bolt {
class Checker {
DiagnosticEngine& DE;
size_t nextConTypeId = 0;
size_t nextTypeVarId = 0;
Type* inferExpression(Expression* Expression, InferContext& Env);
void infer(Node* node, InferContext& Env);
TCon* createPrimConType();
TVar* createTypeVar();
Type* instantiate(Scheme& S);
@ -272,6 +289,8 @@ namespace bolt {
public:
Checker(DiagnosticEngine& DE);
void check(SourceFile* SF);
};

View file

@ -3,15 +3,38 @@
#include <vector>
#include <stdexcept>
#include <memory>
#include <iostream>
#include "bolt/ByteString.hpp"
#include "bolt/String.hpp"
#include "bolt/CST.hpp"
namespace bolt {
class Type;
enum class DiagnosticKind : unsigned char {
UnexpectedToken,
UnexpectedString,
BindingNotFound,
UnificationError,
};
class Diagnostic : std::runtime_error {
const DiagnosticKind Kind;
protected:
Diagnostic(DiagnosticKind Kind);
public:
Diagnostic();
DiagnosticKind getKind() const noexcept {
return Kind;
}
};
class UnexpectedTokenDiagnostic : public Diagnostic {
@ -21,7 +44,7 @@ namespace bolt {
std::vector<NodeType> Expected;
inline UnexpectedTokenDiagnostic(Token* Actual, std::vector<NodeType> Expected):
Actual(Actual), Expected(Expected) {}
Diagnostic(DiagnosticKind::UnexpectedToken), Actual(Actual), Expected(Expected) {}
};
@ -32,8 +55,58 @@ namespace bolt {
String Actual;
inline UnexpectedStringDiagnostic(TextLoc Location, String Actual):
Location(Location), Actual(Actual) {}
Diagnostic(DiagnosticKind::UnexpectedString), Location(Location), Actual(Actual) {}
};
class BindingNotFoundDiagnostic : public Diagnostic {
public:
ByteString Name;
Node* Initiator;
inline BindingNotFoundDiagnostic(ByteString Name, Node* Initiator):
Diagnostic(DiagnosticKind::BindingNotFound), Name(Name), Initiator(Initiator) {}
};
class UnificationErrorDiagnostic : public Diagnostic {
public:
Type* Left;
Type* Right;
inline UnificationErrorDiagnostic(Type* Left, Type* Right):
Diagnostic(DiagnosticKind::UnificationError), Left(Left), Right(Right) {}
};
class DiagnosticEngine {
protected:
public:
virtual void addDiagnostic(const Diagnostic& Diagnostic) = 0;
template<typename D, typename ...Ts>
void add(Ts&&... Args) {
D Diag { std::forward<Ts>(Args)... };
addDiagnostic(Diag);
}
virtual ~DiagnosticEngine() {}
};
class ConsoleDiagnostics : public DiagnosticEngine {
std::ostream& Out;
public:
void addDiagnostic(const Diagnostic& Diagnostic) override;
ConsoleDiagnostics(std::ostream& Out = std::cerr);
};
}

View file

@ -1,6 +1,7 @@
#ifndef BOLT_TEXT_HPP
#define BOLT_TEXT_HPP
#include "ByteString.hpp"
#include <stddef.h>
#include <string>
@ -13,7 +14,7 @@ namespace bolt {
size_t Line = 1;
size_t Column = 1;
void advance(const std::string& Text) {
inline void advance(const std::string& Text) {
for (auto Chr: Text) {
if (Chr == '\n') {
Line++;
@ -32,6 +33,11 @@ namespace bolt {
TextLoc End;
};
class TextFile {
public:
ByteString getText();
};
}
#endif // of #ifndef BOLT_TEXT_HPP

View file

@ -1,6 +1,7 @@
#include <stack>
#include "bolt/Diagnostics.hpp"
#include "zen/config.hpp"
#include "bolt/CST.hpp"
@ -26,6 +27,10 @@ namespace bolt {
return F.Type;
}
void TypeEnv::add(ByteString Name, Scheme S) {
Mapping.emplace(Name, S);
}
bool Type::hasTypeVar(const TVar* TV) {
switch (Kind) {
case TypeKind::Var:
@ -40,6 +45,18 @@ namespace bolt {
}
return Y->ReturnType->hasTypeVar(TV);
}
case TypeKind::Con:
{
auto Y = static_cast<TCon*>(this);
for (auto Ty: Y->Args) {
if (Ty->hasTypeVar(TV)) {
return true;
}
}
return false;
}
case TypeKind::Any:
return false;
}
}
@ -70,7 +87,7 @@ namespace bolt {
for (auto Arg: Y->Args) {
NewArgs.push_back(Arg->substitute(Sub));
}
return new TCon(Y->Id, Y->Args);
return new TCon(Y->Id, Y->Args, Y->DisplayName);
}
}
}
@ -79,6 +96,9 @@ namespace bolt {
Constraints.push_back(C);
}
Checker::Checker(DiagnosticEngine& DE):
DE(DE) {}
void Checker::infer(Node* X, InferContext& Ctx) {
switch (X->Type) {
@ -143,14 +163,19 @@ namespace bolt {
case NodeType::ConstantExpression:
{
auto Y = static_cast<ConstantExpression*>(X);
Type* Ty = nullptr;
switch (Y->Token->Type) {
case NodeType::IntegerLiteral:
return Ctx.Env.lookupMono("Int");
Ty = Ctx.Env.lookupMono("Int");
break;
case NodeType::StringLiteral:
return Ctx.Env.lookupMono("String");
Ty = Ctx.Env.lookupMono("String");
break;
default:
ZEN_UNREACHABLE
}
ZEN_ASSERT(Ty != nullptr);
return Ty;
}
case NodeType::ReferenceExpression:
@ -158,7 +183,7 @@ namespace bolt {
auto Y = static_cast<ReferenceExpression*>(X);
auto Scm = Ctx.Env.lookup(Y->Name->Text);
if (Scm == nullptr) {
// TODO add diagnostic
DE.add<BindingNotFoundDiagnostic>(Y->Name->Text, Y->Name);
return new TAny();
}
return instantiate(*Scm);
@ -169,7 +194,7 @@ namespace bolt {
auto Y = static_cast<InfixExpression*>(X);
auto Scm = Ctx.Env.lookup(Y->Operator->getText());
if (Scm == nullptr) {
// TODO add diagnostic
DE.add<BindingNotFoundDiagnostic>(Y->Operator->getText(), Y->Operator);
return new TAny();
}
auto OpTy = instantiate(*Scm);
@ -190,6 +215,11 @@ namespace bolt {
void Checker::check(SourceFile *SF) {
TypeEnv Global;
auto StringTy = new TCon(nextConTypeId++, {}, "String");
Global.add("String", Forall(StringTy));
auto IntTy = new TCon(nextConTypeId++, {}, "Int");
Global.add("Int", Forall(IntTy));
Global.add("+", Forall(new TArrow({ IntTy, IntTy }, IntTy)));
ConstraintSet Constraints;
InferContext Toplevel { Constraints, Global };
infer(SF, Toplevel);
@ -199,6 +229,7 @@ namespace bolt {
void Checker::solve(Constraint* Constraint) {
std::stack<class Constraint*> Queue;
Queue.push(Constraint);
TVSub Sub;
while (!Queue.empty()) {
@ -225,8 +256,7 @@ namespace bolt {
{
auto Y = static_cast<CEqual*>(Constraint);
if (!unify(Y->Left, Y->Right, Sub)) {
// TODO diagnostic
fprintf(stderr, "unification error\n");
DE.add<UnificationErrorDiagnostic>(Y->Left, Y->Right);
}
break;
}

View file

@ -1,9 +1,201 @@
#include <sstream>
#include "zen/config.hpp"
#include "bolt/Diagnostics.hpp"
#include "bolt/Checker.hpp"
#define ANSI_RESET "\u001b[0m"
#define ANSI_BOLD "\u001b[1m"
#define ANSI_UNDERLINE "\u001b[4m"
#define ANSI_REVERSED "\u001b[7m"
#define ANSI_FG_BLACK "\u001b[30m"
#define ANSI_FG_RED "\u001b[31m"
#define ANSI_FG_GREEN "\u001b[32m"
#define ANSI_FG_YELLOW "\u001b[33m"
#define ANSI_FG_BLUE "\u001b[34m"
#define ANSI_FG_CYAN "\u001b[35m"
#define ANSI_FG_MAGENTA "\u001b[36m"
#define ANSI_FG_WHITE "\u001b[37m"
#define ANSI_BG_BLACK "\u001b[40m"
#define ANSI_BG_RED "\u001b[41m"
#define ANSI_BG_GREEN "\u001b[42m"
#define ANSI_BG_YELLOW "\u001b[43m"
#define ANSI_BG_BLUE "\u001b[44m"
#define ANSI_BG_CYAN "\u001b[45m"
#define ANSI_BG_MAGENTA "\u001b[46m"
#define ANSI_BG_WHITE "\u001b[47m"
namespace bolt {
Diagnostic::Diagnostic():
std::runtime_error("a compiler error occurred without being caught") {}
Diagnostic::Diagnostic(DiagnosticKind Kind):
std::runtime_error("a compiler error occurred without being caught"), Kind(Kind) {}
static std::string describe(NodeType Type) {
switch (Type) {
case NodeType::Identifier:
return "an identifier";
case NodeType::CustomOperator:
return "an operator";
case NodeType::IntegerLiteral:
return "an integer literal";
case NodeType::EndOfFile:
return "end-of-file";
case NodeType::BlockStart:
return "the start of a new indented block";
case NodeType::BlockEnd:
return "the end of the current indented block";
case NodeType::LineFoldEnd:
return "the end of the current line-fold";
case NodeType::LParen:
return "'('";
case NodeType::RParen:
return "')'";
case NodeType::LBrace:
return "'['";
case NodeType::RBrace:
return "']'";
case NodeType::LBracket:
return "'{'";
case NodeType::RBracket:
return "'}'";
case NodeType::Colon:
return "':'";
case NodeType::Equals:
return "'='";
case NodeType::StringLiteral:
return "a string literal";
case NodeType::Dot:
return "'.'";
case NodeType::PubKeyword:
return "'pub'";
case NodeType::LetKeyword:
return "'let'";
case NodeType::MutKeyword:
return "'mut'";
case NodeType::ReturnKeyword:
return "'return'";
case NodeType::TypeKeyword:
return "'type'";
default:
ZEN_UNREACHABLE
}
}
static std::string describe(const Type* Ty) {
switch (Ty->getKind()) {
case TypeKind::Any:
return "any";
case TypeKind::Var:
return "a" + std::to_string(static_cast<const TVar*>(Ty)->Id);
case TypeKind::Arrow:
{
auto Y = static_cast<const TArrow*>(Ty);
std::ostringstream Out;
Out << "(";
bool First = true;
for (auto PT: Y->ParamTypes) {
if (First) First = false;
else Out << ", ";
Out << describe(PT);
}
Out << ") -> " << describe(Y->ReturnType);
return Out.str();
}
case TypeKind::Con:
{
auto Y = static_cast<const TCon*>(Ty);
std::ostringstream Out;
if (!Y->DisplayName.empty()) {
Out << Y->DisplayName;
} else {
Out << "C" << Y->Id;
}
for (auto Arg: Y->Args) {
Out << " " << describe(Arg);
}
return Out.str();
}
}
}
ConsoleDiagnostics::ConsoleDiagnostics(std::ostream& Out):
Out(Out) {}
void ConsoleDiagnostics::addDiagnostic(const Diagnostic& D) {
switch (D.getKind()) {
case DiagnosticKind::BindingNotFound:
{
auto E = static_cast<const BindingNotFoundDiagnostic&>(D);
Out << ANSI_BOLD ANSI_FG_RED "error: " ANSI_RESET "binding '" << E.Name << "' was not found\n";
//if (E.Initiator != nullptr) {
// writeExcerpt(E.Initiator->getRange());
//}
break;
}
case DiagnosticKind::UnexpectedToken:
{
auto E = static_cast<const UnexpectedTokenDiagnostic&>(D);
Out << "<unknown.bolt>:" << E.Actual->getStartLine() << ":" << E.Actual->getStartColumn() << ": expected ";
switch (E.Expected.size()) {
case 0:
Out << "nothing";
break;
case 1:
Out << describe(E.Expected[0]);
break;
default:
auto Iter = E.Expected.begin();
Out << describe(*Iter++);
NodeType Prev;
while (Iter != E.Expected.end()) {
Out << ", " << describe(Prev);
Prev = *Iter++;
}
Out << " or " << describe(Prev);
break;
}
Out << " but instead got '" << E.Actual->getText() << "'\n";
break;
}
case DiagnosticKind::UnexpectedString:
{
auto E = static_cast<const UnexpectedStringDiagnostic&>(D);
Out << "<unknown.bolt>:" << E.Location.Line << ":" << E.Location.Column << ": unexpected '";
for (auto Chr: E.Actual) {
switch (Chr) {
case '\\':
Out << "\\\\";
break;
case '\'':
Out << "\\'";
break;
default:
Out << Chr;
break;
}
}
break;
}
case DiagnosticKind::UnificationError:
{
auto E = static_cast<const UnificationErrorDiagnostic&>(D);
Out << ANSI_FG_RED << ANSI_BOLD << "error: " << ANSI_RESET << "the types " << ANSI_FG_GREEN << describe(E.Left) << ANSI_RESET
<< " and " << ANSI_FG_GREEN << describe(E.Right) << ANSI_RESET << " failed to match\n";
break;
}
}
}
}

View file

@ -29,58 +29,6 @@ String readFile(std::string Path) {
return Out;
}
std::string describe(NodeType Type) {
switch (Type) {
case NodeType::Identifier:
return "an identifier";
case NodeType::CustomOperator:
return "an operator";
case NodeType::IntegerLiteral:
return "an integer literal";
case NodeType::EndOfFile:
return "end-of-file";
case NodeType::BlockStart:
return "the start of a new indented block";
case NodeType::BlockEnd:
return "the end of the current indented block";
case NodeType::LineFoldEnd:
return "the end of the current line-fold";
case NodeType::LParen:
return "'('";
case NodeType::RParen:
return "')'";
case NodeType::LBrace:
return "'['";
case NodeType::RBrace:
return "']'";
case NodeType::LBracket:
return "'{'";
case NodeType::RBracket:
return "'}'";
case NodeType::Colon:
return "':'";
case NodeType::Equals:
return "'='";
case NodeType::StringLiteral:
return "a string literal";
case NodeType::Dot:
return "'.'";
case NodeType::PubKeyword:
return "'pub'";
case NodeType::LetKeyword:
return "'let'";
case NodeType::MutKeyword:
return "'mut'";
case NodeType::ReturnKeyword:
return "'return'";
case NodeType::TypeKeyword:
return "'type'";
default:
ZEN_UNREACHABLE
}
}
int main(int argc, const char* argv[]) {
if (argc < 2) {
@ -88,6 +36,8 @@ int main(int argc, const char* argv[]) {
return 1;
}
ConsoleDiagnostics DE;
auto Text = readFile(argv[1]);
VectorStream<String> Chars(Text, EOF);
Scanner S(Chars);
@ -99,33 +49,14 @@ int main(int argc, const char* argv[]) {
#ifdef NDEBUG
try {
SF = P.parseSourceFile();
} catch (UnexpectedTokenDiagnostic& E) {
std::cerr << "<unknown.bolt>:" << E.Actual->getStartLine() << ":" << E.Actual->getStartColumn() << ": expected ";
switch (E.Expected.size()) {
case 0:
std::cerr << "nothing";
break;
case 1:
std::cerr << describe(E.Expected[0]);
break;
default:
auto Iter = E.Expected.begin();
std::cerr << describe(*Iter++);
NodeType Prev;
while (Iter != E.Expected.end()) {
std::cerr << ", " << describe(Prev);
Prev = *Iter++;
}
std::cerr << " or " << describe(Prev);
break;
}
std::cerr << " but instead got '" << E.Actual->getText() << "'\n";
} catch (Diagnostic& D) {
DE.addDiagnostic(D);
}
#else
SF = P.parseSourceFile();
#endif
Checker TheChecker;
Checker TheChecker { DE };
TheChecker.check(SF);
return 0;