From 093f307098fa783968556dd73be2904edce0a342 Mon Sep 17 00:00:00 2001 From: Sam Vervaeck Date: Sun, 21 May 2023 20:14:41 +0200 Subject: [PATCH] Fix instance declarations not being correctly typechecked --- include/bolt/CST.hpp | 24 +++++++++++++++++++++--- src/CST.cc | 15 ++++++++++----- src/Checker.cc | 33 ++++++++++++++++++++++++++++----- 3 files changed, 59 insertions(+), 13 deletions(-) diff --git a/include/bolt/CST.hpp b/include/bolt/CST.hpp index 56ef96b7e..2ca134e8f 100644 --- a/include/bolt/CST.hpp +++ b/include/bolt/CST.hpp @@ -168,20 +168,26 @@ namespace bolt { }; + enum class SymbolKind { + Var, + Class, + Type, + }; + class Scope { Node* Source; - std::unordered_map Mapping; + std::unordered_multimap> Mapping; void scan(Node* X); - void addBindings(Pattern* X, Node* ToInsert); + void addBindings(Pattern* P, Node* ToInsert); public: Scope(Node* Source); - Node* lookup(SymbolPath Path); + Node* lookup(SymbolPath Path, SymbolKind Kind = SymbolKind::Var); Scope* getParentScope(); @@ -998,6 +1004,10 @@ namespace bolt { Token* getFirstToken() override; Token* getLastToken() override; + static bool classof(const Node* N) { + return N->getKind() == NodeKind::BindPattern; + } + }; class LiteralPattern : public Pattern { @@ -1012,6 +1022,10 @@ namespace bolt { Token* getFirstToken() override; Token* getLastToken() override; + static bool classof(const Node* N) { + return N->getKind() == NodeKind::LiteralPattern; + } + }; class Expression : public TypedNode { @@ -1410,6 +1424,10 @@ namespace bolt { Token* getFirstToken() override; Token* getLastToken() override; + static bool classof(const Node* N) { + return N->getKind() == NodeKind::InstanceDeclaration; + } + }; class ClassDeclaration : public Node { diff --git a/src/CST.cc b/src/CST.cc index a39c5d3fa..f3128e6c9 100644 --- a/src/CST.cc +++ b/src/CST.cc @@ -28,6 +28,7 @@ namespace bolt { case NodeKind::ClassDeclaration: { auto Decl = static_cast(X); + Mapping.emplace(Decl->Name->getCanonicalText(), std::make_tuple(Decl, SymbolKind::Class)); for (auto Element: Decl->Elements) { scan(Element); } @@ -52,7 +53,7 @@ namespace bolt { case NodeKind::BindPattern: { auto Y = static_cast(X); - Mapping.emplace(Y->Name->Text, ToInsert); + Mapping.emplace(Y->Name->Text, std::make_tuple(ToInsert, SymbolKind::Var)); break; } default: @@ -60,13 +61,13 @@ namespace bolt { } } - Node* Scope::lookup(SymbolPath Path) { + Node* Scope::lookup(SymbolPath Path, SymbolKind Kind) { ZEN_ASSERT(Path.Modules.empty()); auto Curr = this; do { auto Match = Curr->Mapping.find(Path.Name); - if (Match != Curr->Mapping.end()) { - return Match->second; + if (Match != Curr->Mapping.end() && std::get<1>(Match->second) == Kind) { + return std::get<0>(Match->second); } Curr = Curr->getParentScope(); } while (Curr != nullptr); @@ -99,9 +100,13 @@ namespace bolt { } Scope* Node::getScope() { - return this->Parent->getScope(); + return Parent->getScope(); } + /* ClassScope& Node::getClassScope() { */ + /* return Parent->getClassScope(); */ + /* } */ + TextLoc Token::getEndLoc() { auto EndLoc = StartLoc; EndLoc.advance(getText()); diff --git a/src/Checker.cc b/src/Checker.cc index 93458e5e3..9a0b0c95a 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -344,18 +344,27 @@ namespace bolt { case NodeKind::InstanceDeclaration: { auto Decl = static_cast(X); + + // Needed to set the associated Type on the CST node + for (auto TE: Decl->TypeExps) { + inferTypeExpression(TE); + } + auto Match = InstanceMap.find(Decl->Name->getCanonicalText()); if (Match == InstanceMap.end()) { InstanceMap.emplace(Decl->Name->getCanonicalText(), std::vector { Decl }); } else { Match->second.push_back(Decl); } + + // FIXME save Ctx on the node or dont do this at all auto Ctx = createInferContext(); Contexts.push_back(Ctx); for (auto Element: Decl->Elements) { forwardDeclare(Element); } Contexts.pop_back(); + break; } @@ -389,6 +398,25 @@ namespace bolt { } Let->Ty = Ty; + if (llvm::isa(Let->Parent)) { + auto Instance = static_cast(Let->Parent); + auto Class = llvm::cast(Instance->getScope()->lookup({ {}, Instance->Name->getCanonicalText() }, SymbolKind::Class)); + std::vector Params; + for (auto TE: Class->TypeVars) { + auto TV = createTypeVar(); + NewCtx->Env.emplace(TE->Name->getCanonicalText(), new Forall(TV)); + Params.push_back(TV); + } + // FIXME lookup should not go over parent envs + auto Let2 = llvm::cast(Class->getScope()->lookup({ {}, llvm::cast(Let->Pattern)->Name->getCanonicalText() }, SymbolKind::Var)); + if (Let2->TypeAssert) { + addConstraint(new CEqual(Ty, inferTypeExpression(Let2->TypeAssert->TypeExpression), Let)); + } + for (auto [Param, TE]: zen::zip(Params, Instance->TypeExps)) { + addConstraint(new CEqual(Param, TE->getType())); + } + } + if (Let->Body) { switch (Let->Body->getKind()) { case NodeKind::LetExprBody: @@ -447,11 +475,6 @@ namespace bolt { { auto Decl = static_cast(N); - // Needed to set the associated Type on the CST node - for (auto TE: Decl->TypeExps) { - inferTypeExpression(TE); - } - for (auto Element: Decl->Elements) { infer(Element); }