From 311f1d228b1694bb33ac420ad937bfab3d9af94c Mon Sep 17 00:00:00 2001 From: Sam Vervaeck Date: Sun, 21 Aug 2022 20:56:58 +0200 Subject: [PATCH] Improve diagnostics and type checking --- include/bolt/Checker.hpp | 25 ++++- include/bolt/Diagnostics.hpp | 79 +++++++++++++- include/bolt/Text.hpp | 8 +- src/Checker.cc | 44 ++++++-- src/Diagnostics.cc | 196 ++++++++++++++++++++++++++++++++++- src/main.cc | 79 +------------- 6 files changed, 341 insertions(+), 90 deletions(-) diff --git a/include/bolt/Checker.hpp b/include/bolt/Checker.hpp index db1161092..7505566a8 100644 --- a/include/bolt/Checker.hpp +++ b/include/bolt/Checker.hpp @@ -1,6 +1,7 @@ #pragma once +#include "bolt/Diagnostics.hpp" #include "zen/config.hpp" #include "bolt/ByteString.hpp" @@ -55,9 +56,10 @@ namespace bolt { const size_t Id; std::vector Args; + /* Name printed for this type in diagnostics; when it is empty, "C" followed by Id is printed instead. */ ByteString DisplayName; - inline TCon(const size_t Id, std::vector Args ): - Type(TypeKind::Con), Id(Id), Args(Args) {} + inline TCon(const size_t Id, std::vector Args, ByteString DisplayName): + Type(TypeKind::Con), Id(Id), Args(Args), DisplayName(DisplayName) {} }; @@ -107,9 +109,20 @@ namespace bolt { public: TVSet TVs; - std::vector Constriants; + std::vector Constraints; Type* Type; + /* Scheme over a bare type: TVs and Constraints stay default-constructed. */ inline Forall(class Type* Type): + Type(Type) {} + + /* Scheme with every member supplied by the caller. */ inline Forall( + TVSet TVs, + std::vector Constraints, + class Type* Type + ): TVs(TVs), + Constraints(Constraints), + Type(Type) {} + }; enum class SchemeKind : unsigned char { @@ -256,12 +269,16 @@ namespace bolt { class Checker { + /* Engine that receives the diagnostics this checker emits through DE.add(...). */ DiagnosticEngine& DE; + + /* Id given to the next TCon created in Checker::check; post-incremented on each use. */ size_t nextConTypeId = 0; size_t nextTypeVarId = 0; Type* inferExpression(Expression* Expression, InferContext& Env); void infer(Node* node, InferContext& Env); + /* NOTE(review): declared here, but no definition is visible in this patch - confirm one exists elsewhere. */ TCon* createPrimConType(); TVar* createTypeVar(); Type* instantiate(Scheme& S); @@ -272,6 +289,8 @@ namespace bolt { public: + /* Stores a reference to DE; the engine must therefore outlive the checker. */ Checker(DiagnosticEngine& DE); + void check(SourceFile* SF); }; diff --git a/include/bolt/Diagnostics.hpp b/include/bolt/Diagnostics.hpp index 447403ca0..4c50f5c76 100644 --- 
a/include/bolt/Diagnostics.hpp +++ b/include/bolt/Diagnostics.hpp @@ -3,15 +3,38 @@ #include #include +#include +#include +#include "bolt/ByteString.hpp" #include "bolt/String.hpp" #include "bolt/CST.hpp" namespace bolt { + class Type; + + /* Tag stored in every Diagnostic; ConsoleDiagnostics::addDiagnostic switches on it to pick the concrete type. */ enum class DiagnosticKind : unsigned char { + UnexpectedToken, + UnexpectedString, + BindingNotFound, + UnificationError, + }; + class Diagnostic : std::runtime_error { + + const DiagnosticKind Kind; + + protected: + + /* Protected: only the concrete diagnostics below construct the base, each passing its own kind. */ Diagnostic(DiagnosticKind Kind); + public: - Diagnostic(); + + DiagnosticKind getKind() const noexcept { + return Kind; + } + }; class UnexpectedTokenDiagnostic : public Diagnostic { @@ -21,7 +44,7 @@ namespace bolt { std::vector Expected; inline UnexpectedTokenDiagnostic(Token* Actual, std::vector Expected): - Actual(Actual), Expected(Expected) {} + Diagnostic(DiagnosticKind::UnexpectedToken), Actual(Actual), Expected(Expected) {} }; @@ -32,8 +55,58 @@ namespace bolt { String Actual; inline UnexpectedStringDiagnostic(TextLoc Location, String Actual): - Location(Location), Actual(Actual) {} + Diagnostic(DiagnosticKind::UnexpectedString), Location(Location), Actual(Actual) {} }; + /* Diagnostic for a name that has no binding; printed as "binding '<Name>' was not found". */ class BindingNotFoundDiagnostic : public Diagnostic { + public: + + ByteString Name; + Node* Initiator; + + inline BindingNotFoundDiagnostic(ByteString Name, Node* Initiator): + Diagnostic(DiagnosticKind::BindingNotFound), Name(Name), Initiator(Initiator) {} + + }; + + /* Diagnostic for two types that failed to unify; Left and Right are the mismatching types. */ class UnificationErrorDiagnostic : public Diagnostic { + public: + + Type* Left; + Type* Right; + + inline UnificationErrorDiagnostic(Type* Left, Type* Right): + Diagnostic(DiagnosticKind::UnificationError), Left(Left), Right(Right) {} + + }; + + /* Abstract sink for diagnostics; subclasses implement addDiagnostic. */ class DiagnosticEngine { + protected: + + public: + + virtual void addDiagnostic(const Diagnostic& Diagnostic) = 0; + + /* Convenience wrapper: builds a D from Args and forwards it to addDiagnostic. */ template + void add(Ts&&... Args) { + D Diag { std::forward(Args)... 
}; + addDiagnostic(Diag); + } + + virtual ~DiagnosticEngine() {} + + }; + + /* DiagnosticEngine that prints each diagnostic to an output stream (std::cerr unless another is given). */ class ConsoleDiagnostics : public DiagnosticEngine { + + std::ostream& Out; + + public: + + void addDiagnostic(const Diagnostic& Diagnostic) override; + + ConsoleDiagnostics(std::ostream& Out = std::cerr); + + }; } diff --git a/include/bolt/Text.hpp b/include/bolt/Text.hpp index d5707495d..fb9fa61d2 100644 --- a/include/bolt/Text.hpp +++ b/include/bolt/Text.hpp @@ -1,6 +1,7 @@ #ifndef BOLT_TEXT_HPP #define BOLT_TEXT_HPP +#include "ByteString.hpp" #include #include @@ -13,7 +14,7 @@ namespace bolt { size_t Line = 1; size_t Column = 1; - void advance(const std::string& Text) { + inline void advance(const std::string& Text) { for (auto Chr: Text) { if (Chr == '\n') { Line++; @@ -32,6 +33,11 @@ namespace bolt { TextLoc End; }; + /* NOTE(review): getText is only declared in this patch; no definition is visible here. */ class TextFile { + public: + ByteString getText(); + }; + } #endif // of #ifndef BOLT_TEXT_HPP diff --git a/src/Checker.cc b/src/Checker.cc index 8a254f750..32b58fc70 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -1,6 +1,7 @@ #include +#include "bolt/Diagnostics.hpp" #include "zen/config.hpp" #include "bolt/CST.hpp" @@ -26,6 +27,10 @@ namespace bolt { return F.Type; } + /* Registers scheme S under Name by calling Mapping.emplace. */ void TypeEnv::add(ByteString Name, Scheme S) { + Mapping.emplace(Name, S); + } + bool Type::hasTypeVar(const TVar* TV) { switch (Kind) { case TypeKind::Var: @@ -40,6 +45,18 @@ namespace bolt { } return Y->ReturnType->hasTypeVar(TV); } + /* A constructor type contains TV exactly when one of its arguments does. */ case TypeKind::Con: + { + auto Y = static_cast(this); + for (auto Ty: Y->Args) { + if (Ty->hasTypeVar(TV)) { + return true; + } + } + return false; + } + case TypeKind::Any: + return false; } } @@ -70,7 +87,7 @@ namespace bolt { for (auto Arg: Y->Args) { NewArgs.push_back(Arg->substitute(Sub)); } - return new TCon(Y->Id, Y->Args); + /* Build the result from the substituted arguments; passing Y->Args here threw the substitution away. */ return new TCon(Y->Id, NewArgs, Y->DisplayName); } } } @@ -79,6 +96,9 @@ namespace bolt { Constraints.push_back(C); } + Checker::Checker(DiagnosticEngine& DE): + DE(DE) {} + void Checker::infer(Node* X, InferContext& Ctx) { switch 
(X->Type) { @@ -143,14 +163,19 @@ namespace bolt { case NodeType::ConstantExpression: { auto Y = static_cast(X); + /* Collect the literal's type in Ty so both literal kinds leave through the assert-and-return below. */ Type* Ty = nullptr; switch (Y->Token->Type) { case NodeType::IntegerLiteral: - return Ctx.Env.lookupMono("Int"); + Ty = Ctx.Env.lookupMono("Int"); + break; case NodeType::StringLiteral: - return Ctx.Env.lookupMono("String"); + Ty = Ctx.Env.lookupMono("String"); + break; default: ZEN_UNREACHABLE } + ZEN_ASSERT(Ty != nullptr); + return Ty; } case NodeType::ReferenceExpression: @@ -158,7 +183,7 @@ namespace bolt { auto Y = static_cast(X); auto Scm = Ctx.Env.lookup(Y->Name->Text); if (Scm == nullptr) { - // TODO add diagnostic + DE.add(Y->Name->Text, Y->Name); return new TAny(); } return instantiate(*Scm); @@ -169,7 +194,7 @@ namespace bolt { auto Y = static_cast(X); auto Scm = Ctx.Env.lookup(Y->Operator->getText()); if (Scm == nullptr) { - // TODO add diagnostic + DE.add(Y->Operator->getText(), Y->Operator); return new TAny(); } auto OpTy = instantiate(*Scm); @@ -190,6 +215,11 @@ namespace bolt { void Checker::check(SourceFile *SF) { TypeEnv Global; + /* Seed the global environment with the built-in String and Int types and an Int-only "+" operator. */ auto StringTy = new TCon(nextConTypeId++, {}, "String"); + Global.add("String", Forall(StringTy)); + auto IntTy = new TCon(nextConTypeId++, {}, "Int"); + Global.add("Int", Forall(IntTy)); + Global.add("+", Forall(new TArrow({ IntTy, IntTy }, IntTy))); ConstraintSet Constraints; InferContext Toplevel { Constraints, Global }; infer(SF, Toplevel); @@ -199,6 +229,7 @@ namespace bolt { void Checker::solve(Constraint* Constraint) { std::stack Queue; + /* Start the work list with the constraint that was passed in. */ Queue.push(Constraint); TVSub Sub; while (!Queue.empty()) { @@ -225,8 +256,7 @@ namespace bolt { { auto Y = static_cast(Constraint); if (!unify(Y->Left, Y->Right, Sub)) { - // TODO diagnostic - fprintf(stderr, "unification error\n"); + DE.add(Y->Left, Y->Right); } break; } diff --git a/src/Diagnostics.cc b/src/Diagnostics.cc index 6acfda46f..c4fe3c6da 100644 --- a/src/Diagnostics.cc +++ b/src/Diagnostics.cc @@ -1,9 +1,201 @@ +#include + +#include "zen/config.hpp" + 
#include "bolt/Diagnostics.hpp" +#include "bolt/Checker.hpp" + +#define ANSI_RESET "\u001b[0m" +#define ANSI_BOLD "\u001b[1m" +#define ANSI_UNDERLINE "\u001b[4m" +#define ANSI_REVERSED "\u001b[7m" + /* SGR colour codes: 30-37 select the foreground, 40-47 the background; x5 is magenta and x6 is cyan. */ +#define ANSI_FG_BLACK "\u001b[30m" +#define ANSI_FG_RED "\u001b[31m" +#define ANSI_FG_GREEN "\u001b[32m" +#define ANSI_FG_YELLOW "\u001b[33m" +#define ANSI_FG_BLUE "\u001b[34m" +#define ANSI_FG_CYAN "\u001b[36m" +#define ANSI_FG_MAGENTA "\u001b[35m" +#define ANSI_FG_WHITE "\u001b[37m" + +#define ANSI_BG_BLACK "\u001b[40m" +#define ANSI_BG_RED "\u001b[41m" +#define ANSI_BG_GREEN "\u001b[42m" +#define ANSI_BG_YELLOW "\u001b[43m" +#define ANSI_BG_BLUE "\u001b[44m" +#define ANSI_BG_CYAN "\u001b[46m" +#define ANSI_BG_MAGENTA "\u001b[45m" +#define ANSI_BG_WHITE "\u001b[47m" namespace bolt { - Diagnostic::Diagnostic(): - std::runtime_error("a compiler error occurred without being caught") {} + Diagnostic::Diagnostic(DiagnosticKind Kind): + std::runtime_error("a compiler error occurred without being caught"), Kind(Kind) {} + + /* Human-readable name of a token kind, used in the "expected ..." message below. */ static std::string describe(NodeType Type) { + switch (Type) { + case NodeType::Identifier: + return "an identifier"; + case NodeType::CustomOperator: + return "an operator"; + case NodeType::IntegerLiteral: + return "an integer literal"; + case NodeType::EndOfFile: + return "end-of-file"; + case NodeType::BlockStart: + return "the start of a new indented block"; + case NodeType::BlockEnd: + return "the end of the current indented block"; + case NodeType::LineFoldEnd: + return "the end of the current line-fold"; + case NodeType::LParen: + return "'('"; + case NodeType::RParen: + return "')'"; + case NodeType::LBrace: + return "'['"; + case NodeType::RBrace: + return "']'"; + case NodeType::LBracket: + return "'{'"; + case NodeType::RBracket: + return "'}'"; + case NodeType::Colon: + return "':'"; + case NodeType::Equals: + return "'='"; + case NodeType::StringLiteral: + return "a string literal"; + case NodeType::Dot: + return "'.'"; + case 
NodeType::PubKeyword: + return "'pub'"; + case NodeType::LetKeyword: + return "'let'"; + case NodeType::MutKeyword: + return "'mut'"; + case NodeType::ReturnKeyword: + return "'return'"; + case NodeType::TypeKeyword: + return "'type'"; + default: + ZEN_UNREACHABLE + } + } + + /* Renders a type for diagnostics: "any", "a" plus the Id for variables, "(P1, P2) -> R" for arrows, and the display name (or "C" plus the Id) followed by the arguments for constructors. */ static std::string describe(const Type* Ty) { + switch (Ty->getKind()) { + case TypeKind::Any: + return "any"; + case TypeKind::Var: + return "a" + std::to_string(static_cast(Ty)->Id); + case TypeKind::Arrow: + { + auto Y = static_cast(Ty); + std::ostringstream Out; + Out << "("; + bool First = true; + for (auto PT: Y->ParamTypes) { + if (First) First = false; + else Out << ", "; + Out << describe(PT); + } + Out << ") -> " << describe(Y->ReturnType); + return Out.str(); + } + case TypeKind::Con: + { + auto Y = static_cast(Ty); + std::ostringstream Out; + if (!Y->DisplayName.empty()) { + Out << Y->DisplayName; + } else { + Out << "C" << Y->Id; + } + for (auto Arg: Y->Args) { + Out << " " << describe(Arg); + } + return Out.str(); + } + } + } + + + ConsoleDiagnostics::ConsoleDiagnostics(std::ostream& Out): + Out(Out) {} + + /* Formats D according to its kind and writes the message to Out. */ void ConsoleDiagnostics::addDiagnostic(const Diagnostic& D) { + + switch (D.getKind()) { + + case DiagnosticKind::BindingNotFound: + { + auto E = static_cast(D); + Out << ANSI_BOLD ANSI_FG_RED "error: " ANSI_RESET "binding '" << E.Name << "' was not found\n"; + //if (E.Initiator != nullptr) { + // writeExcerpt(E.Initiator->getRange()); + //} + break; + } + + case DiagnosticKind::UnexpectedToken: + { + auto E = static_cast(D); + Out << ":" << E.Actual->getStartLine() << ":" << E.Actual->getStartColumn() << ": expected "; + switch (E.Expected.size()) { + case 0: + Out << "nothing"; + break; + case 1: + Out << describe(E.Expected[0]); + break; + default: + auto Iter = E.Expected.begin(); + Out << describe(*Iter++); + /* Two or more entries here, so a second element exists; starting Prev from it avoids reading an uninitialised value in the loop. */ NodeType Prev = *Iter++; + while (Iter != E.Expected.end()) { + Out << ", " << describe(Prev); + Prev = *Iter++; + } + Out << " or " << describe(Prev); + break; + } + Out << " 
but instead got '" << E.Actual->getText() << "'\n"; + break; + } + + case DiagnosticKind::UnexpectedString: + { + auto E = static_cast(D); + Out << ":" << E.Location.Line << ":" << E.Location.Column << ": unexpected '"; + for (auto Chr: E.Actual) { + switch (Chr) { + case '\\': + Out << "\\\\"; + break; + case '\'': + Out << "\\'"; + break; + default: + Out << Chr; + break; + } + } + /* Close the quote opened above and end the line, as every other message does. */ Out << "'\n"; break; + } + + case DiagnosticKind::UnificationError: + { + auto E = static_cast(D); + Out << ANSI_FG_RED << ANSI_BOLD << "error: " << ANSI_RESET << "the types " << ANSI_FG_GREEN << describe(E.Left) << ANSI_RESET + << " and " << ANSI_FG_GREEN << describe(E.Right) << ANSI_RESET << " failed to match\n"; + break; + } + + } + + } } diff --git a/src/main.cc b/src/main.cc index 725488223..652eae794 100644 --- a/src/main.cc +++ b/src/main.cc @@ -29,58 +29,6 @@ String readFile(std::string Path) { return Out; } -std::string describe(NodeType Type) { - switch (Type) { - case NodeType::Identifier: - return "an identifier"; - case NodeType::CustomOperator: - return "an operator"; - case NodeType::IntegerLiteral: - return "an integer literal"; - case NodeType::EndOfFile: - return "end-of-file"; - case NodeType::BlockStart: - return "the start of a new indented block"; - case NodeType::BlockEnd: - return "the end of the current indented block"; - case NodeType::LineFoldEnd: - return "the end of the current line-fold"; - case NodeType::LParen: - return "'('"; - case NodeType::RParen: - return "')'"; - case NodeType::LBrace: - return "'['"; - case NodeType::RBrace: - return "']'"; - case NodeType::LBracket: - return "'{'"; - case NodeType::RBracket: - return "'}'"; - case NodeType::Colon: - return "':'"; - case NodeType::Equals: - return "'='"; - case NodeType::StringLiteral: - return "a string literal"; - case NodeType::Dot: - return "'.'"; - case NodeType::PubKeyword: - return "'pub'"; - case NodeType::LetKeyword: - return "'let'"; - case NodeType::MutKeyword: - return "'mut'"; - case 
NodeType::ReturnKeyword: - return "'return'"; - case NodeType::TypeKeyword: - return "'type'"; - default: - ZEN_UNREACHABLE - } -} - - int main(int argc, const char* argv[]) { if (argc < 2) { @@ -88,6 +36,8 @@ int main(int argc, const char* argv[]) { return 1; } + ConsoleDiagnostics DE; + auto Text = readFile(argv[1]); VectorStream Chars(Text, EOF); Scanner S(Chars); @@ -99,33 +49,14 @@ int main(int argc, const char* argv[]) { #ifdef NDEBUG try { SF = P.parseSourceFile(); - } catch (UnexpectedTokenDiagnostic& E) { - std::cerr << ":" << E.Actual->getStartLine() << ":" << E.Actual->getStartColumn() << ": expected "; - switch (E.Expected.size()) { - case 0: - std::cerr << "nothing"; - break; - case 1: - std::cerr << describe(E.Expected[0]); - break; - default: - auto Iter = E.Expected.begin(); - std::cerr << describe(*Iter++); - NodeType Prev; - while (Iter != E.Expected.end()) { - std::cerr << ", " << describe(Prev); - Prev = *Iter++; - } - std::cerr << " or " << describe(Prev); - break; - } - std::cerr << " but instead got '" << E.Actual->getText() << "'\n"; + } catch (Diagnostic& D) { + DE.addDiagnostic(D); /* Parsing threw before SF was assigned in this try block, so stop here instead of type-checking it. */ return 1; } #else SF = P.parseSourceFile(); #endif - Checker TheChecker; + Checker TheChecker { DE }; TheChecker.check(SF); return 0;