Fix instance declarations not being correctly typechecked

This commit is contained in:
Sam Vervaeck 2023-05-21 20:14:41 +02:00
parent 3d19ce988c
commit 093f307098
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
3 changed files with 59 additions and 13 deletions

View file

@ -168,20 +168,26 @@ namespace bolt {
}; };
enum class SymbolKind {
Var,
Class,
Type,
};
class Scope { class Scope {
Node* Source; Node* Source;
std::unordered_map<ByteString, Node*> Mapping; std::unordered_multimap<ByteString, std::tuple<Node*, SymbolKind>> Mapping;
void scan(Node* X); void scan(Node* X);
void addBindings(Pattern* X, Node* ToInsert); void addBindings(Pattern* P, Node* ToInsert);
public: public:
Scope(Node* Source); Scope(Node* Source);
Node* lookup(SymbolPath Path); Node* lookup(SymbolPath Path, SymbolKind Kind = SymbolKind::Var);
Scope* getParentScope(); Scope* getParentScope();
@ -998,6 +1004,10 @@ namespace bolt {
Token* getFirstToken() override; Token* getFirstToken() override;
Token* getLastToken() override; Token* getLastToken() override;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::BindPattern;
}
}; };
class LiteralPattern : public Pattern { class LiteralPattern : public Pattern {
@ -1012,6 +1022,10 @@ namespace bolt {
Token* getFirstToken() override; Token* getFirstToken() override;
Token* getLastToken() override; Token* getLastToken() override;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::LiteralPattern;
}
}; };
class Expression : public TypedNode { class Expression : public TypedNode {
@ -1410,6 +1424,10 @@ namespace bolt {
Token* getFirstToken() override; Token* getFirstToken() override;
Token* getLastToken() override; Token* getLastToken() override;
static bool classof(const Node* N) {
return N->getKind() == NodeKind::InstanceDeclaration;
}
}; };
class ClassDeclaration : public Node { class ClassDeclaration : public Node {

View file

@ -28,6 +28,7 @@ namespace bolt {
case NodeKind::ClassDeclaration: case NodeKind::ClassDeclaration:
{ {
auto Decl = static_cast<ClassDeclaration*>(X); auto Decl = static_cast<ClassDeclaration*>(X);
Mapping.emplace(Decl->Name->getCanonicalText(), std::make_tuple(Decl, SymbolKind::Class));
for (auto Element: Decl->Elements) { for (auto Element: Decl->Elements) {
scan(Element); scan(Element);
} }
@ -52,7 +53,7 @@ namespace bolt {
case NodeKind::BindPattern: case NodeKind::BindPattern:
{ {
auto Y = static_cast<BindPattern*>(X); auto Y = static_cast<BindPattern*>(X);
Mapping.emplace(Y->Name->Text, ToInsert); Mapping.emplace(Y->Name->Text, std::make_tuple(ToInsert, SymbolKind::Var));
break; break;
} }
default: default:
@ -60,13 +61,13 @@ namespace bolt {
} }
} }
Node* Scope::lookup(SymbolPath Path) { Node* Scope::lookup(SymbolPath Path, SymbolKind Kind) {
ZEN_ASSERT(Path.Modules.empty()); ZEN_ASSERT(Path.Modules.empty());
auto Curr = this; auto Curr = this;
do { do {
auto Match = Curr->Mapping.find(Path.Name); auto Match = Curr->Mapping.find(Path.Name);
if (Match != Curr->Mapping.end()) { if (Match != Curr->Mapping.end() && std::get<1>(Match->second) == Kind) {
return Match->second; return std::get<0>(Match->second);
} }
Curr = Curr->getParentScope(); Curr = Curr->getParentScope();
} while (Curr != nullptr); } while (Curr != nullptr);
@ -99,9 +100,13 @@ namespace bolt {
} }
Scope* Node::getScope() { Scope* Node::getScope() {
return this->Parent->getScope(); return Parent->getScope();
} }
/* ClassScope& Node::getClassScope() { */
/* return Parent->getClassScope(); */
/* } */
TextLoc Token::getEndLoc() { TextLoc Token::getEndLoc() {
auto EndLoc = StartLoc; auto EndLoc = StartLoc;
EndLoc.advance(getText()); EndLoc.advance(getText());

View file

@ -344,18 +344,27 @@ namespace bolt {
case NodeKind::InstanceDeclaration: case NodeKind::InstanceDeclaration:
{ {
auto Decl = static_cast<InstanceDeclaration*>(X); auto Decl = static_cast<InstanceDeclaration*>(X);
// Needed to set the associated Type on the CST node
for (auto TE: Decl->TypeExps) {
inferTypeExpression(TE);
}
auto Match = InstanceMap.find(Decl->Name->getCanonicalText()); auto Match = InstanceMap.find(Decl->Name->getCanonicalText());
if (Match == InstanceMap.end()) { if (Match == InstanceMap.end()) {
InstanceMap.emplace(Decl->Name->getCanonicalText(), std::vector { Decl }); InstanceMap.emplace(Decl->Name->getCanonicalText(), std::vector { Decl });
} else { } else {
Match->second.push_back(Decl); Match->second.push_back(Decl);
} }
// FIXME save Ctx on the node or dont do this at all
auto Ctx = createInferContext(); auto Ctx = createInferContext();
Contexts.push_back(Ctx); Contexts.push_back(Ctx);
for (auto Element: Decl->Elements) { for (auto Element: Decl->Elements) {
forwardDeclare(Element); forwardDeclare(Element);
} }
Contexts.pop_back(); Contexts.pop_back();
break; break;
} }
@ -389,6 +398,25 @@ namespace bolt {
} }
Let->Ty = Ty; Let->Ty = Ty;
if (llvm::isa<InstanceDeclaration>(Let->Parent)) {
auto Instance = static_cast<InstanceDeclaration*>(Let->Parent);
auto Class = llvm::cast<ClassDeclaration>(Instance->getScope()->lookup({ {}, Instance->Name->getCanonicalText() }, SymbolKind::Class));
std::vector<TVar*> Params;
for (auto TE: Class->TypeVars) {
auto TV = createTypeVar();
NewCtx->Env.emplace(TE->Name->getCanonicalText(), new Forall(TV));
Params.push_back(TV);
}
// FIXME lookup should not go over parent envs
auto Let2 = llvm::cast<LetDeclaration>(Class->getScope()->lookup({ {}, llvm::cast<BindPattern>(Let->Pattern)->Name->getCanonicalText() }, SymbolKind::Var));
if (Let2->TypeAssert) {
addConstraint(new CEqual(Ty, inferTypeExpression(Let2->TypeAssert->TypeExpression), Let));
}
for (auto [Param, TE]: zen::zip(Params, Instance->TypeExps)) {
addConstraint(new CEqual(Param, TE->getType()));
}
}
if (Let->Body) { if (Let->Body) {
switch (Let->Body->getKind()) { switch (Let->Body->getKind()) {
case NodeKind::LetExprBody: case NodeKind::LetExprBody:
@ -447,11 +475,6 @@ namespace bolt {
{ {
auto Decl = static_cast<InstanceDeclaration*>(N); auto Decl = static_cast<InstanceDeclaration*>(N);
// Needed to set the associated Type on the CST node
for (auto TE: Decl->TypeExps) {
inferTypeExpression(TE);
}
for (auto Element: Decl->Elements) { for (auto Element: Decl->Elements) {
infer(Element); infer(Element);
} }