Add support for record patterns and improve implicit forall

This commit is contained in:
Sam Vervaeck 2024-01-22 01:11:06 +01:00
parent 34b7693229
commit de29a77cd3
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
6 changed files with 185 additions and 21 deletions

View file

@ -145,6 +145,7 @@ namespace bolt {
BindPattern, BindPattern,
LiteralPattern, LiteralPattern,
RecordPatternField, RecordPatternField,
RecordPattern,
NamedRecordPattern, NamedRecordPattern,
NamedTuplePattern, NamedTuplePattern,
TuplePattern, TuplePattern,
@ -1383,22 +1384,61 @@ namespace bolt {
class RecordPatternField : public Node { class RecordPatternField : public Node {
public: public:
DotDot* DotDot;
Identifier* Name; Identifier* Name;
Equals* Equals; Equals* Equals;
Pattern* Pattern; Pattern* Pattern;
inline RecordPatternField( inline RecordPatternField(
class DotDot* DotDot,
Identifier* Name, Identifier* Name,
class Equals* Equals, class Equals* Equals,
class Pattern* Pattern class Pattern* Pattern
): Node(NodeKind::RecordPatternField), ): Node(NodeKind::RecordPatternField),
DotDot(DotDot),
Name(Name), Name(Name),
Equals(Equals), Equals(Equals),
Pattern(Pattern) {} Pattern(Pattern) {}
inline RecordPatternField(
Identifier* Name,
class Equals* Equals,
class Pattern* Pattern
): RecordPatternField(nullptr, Name, Equals, Pattern) {}
inline RecordPatternField(
class DotDot* DotDot
): RecordPatternField(DotDot, nullptr, nullptr, nullptr) {}
inline RecordPatternField(
class DotDot* DotDot,
class Pattern* Pattern
): RecordPatternField(DotDot, nullptr, nullptr, Pattern) {}
inline RecordPatternField( inline RecordPatternField(
Identifier* Name Identifier* Name
): RecordPatternField(Name, nullptr, nullptr) {} ): RecordPatternField(nullptr, Name, nullptr, nullptr) {}
Token* getFirstToken() const override;
Token* getLastToken() const override;
};
class RecordPattern : public Pattern {
public:
LBrace* LBrace;
std::vector<std::tuple<RecordPatternField*, Comma*>> Fields;
RBrace* RBrace;
inline RecordPattern(
class LBrace* LBrace,
std::vector<std::tuple<RecordPatternField*, Comma*>> Fields,
class RBrace* RBrace
): Pattern(NodeKind::RecordPattern),
LBrace(LBrace),
Fields(Fields),
RBrace(RBrace) {}
Token* getFirstToken() const override; Token* getFirstToken() const override;
Token* getLastToken() const override; Token* getLastToken() const override;

View file

@ -77,6 +77,7 @@ namespace bolt {
BOLT_GEN_CASE(BindPattern) BOLT_GEN_CASE(BindPattern)
BOLT_GEN_CASE(LiteralPattern) BOLT_GEN_CASE(LiteralPattern)
BOLT_GEN_CASE(RecordPatternField) BOLT_GEN_CASE(RecordPatternField)
BOLT_GEN_CASE(RecordPattern)
BOLT_GEN_CASE(NamedRecordPattern) BOLT_GEN_CASE(NamedRecordPattern)
BOLT_GEN_CASE(NamedTuplePattern) BOLT_GEN_CASE(NamedTuplePattern)
BOLT_GEN_CASE(TuplePattern) BOLT_GEN_CASE(TuplePattern)
@ -372,6 +373,10 @@ namespace bolt {
static_cast<D*>(this)->visitNode(N); static_cast<D*>(this)->visitNode(N);
} }
void visitRecordPattern(RecordPattern* N) {
static_cast<D*>(this)->visitPattern(N);
}
void visitNamedRecordPattern(NamedRecordPattern* N) { void visitNamedRecordPattern(NamedRecordPattern* N) {
static_cast<D*>(this)->visitPattern(N); static_cast<D*>(this)->visitPattern(N);
} }
@ -592,6 +597,7 @@ namespace bolt {
BOLT_GEN_CHILD_CASE(BindPattern) BOLT_GEN_CHILD_CASE(BindPattern)
BOLT_GEN_CHILD_CASE(LiteralPattern) BOLT_GEN_CHILD_CASE(LiteralPattern)
BOLT_GEN_CHILD_CASE(RecordPatternField) BOLT_GEN_CHILD_CASE(RecordPatternField)
BOLT_GEN_CHILD_CASE(RecordPattern)
BOLT_GEN_CHILD_CASE(NamedRecordPattern) BOLT_GEN_CHILD_CASE(NamedRecordPattern)
BOLT_GEN_CHILD_CASE(NamedTuplePattern) BOLT_GEN_CHILD_CASE(NamedTuplePattern)
BOLT_GEN_CHILD_CASE(TuplePattern) BOLT_GEN_CHILD_CASE(TuplePattern)
@ -867,7 +873,12 @@ namespace bolt {
} }
void visitEachChild(RecordPatternField* N) { void visitEachChild(RecordPatternField* N) {
BOLT_VISIT(N->Name); if (N->DotDot) {
BOLT_VISIT(N->DotDot);
}
if (N->Name) {
BOLT_VISIT(N->Name);
}
if (N->Equals) { if (N->Equals) {
BOLT_VISIT(N->Equals); BOLT_VISIT(N->Equals);
} }
@ -876,6 +887,17 @@ namespace bolt {
} }
} }
void visitEachChild(RecordPattern* N) {
BOLT_VISIT(N->LBrace);
for (auto [Field, Comma]: N->Fields) {
BOLT_VISIT(Field);
if (Comma) {
BOLT_VISIT(Comma);
}
}
BOLT_VISIT(N->RBrace);
}
void visitEachChild(NamedRecordPattern* N) { void visitEachChild(NamedRecordPattern* N) {
for (auto [Name, Dot]: N->ModulePath) { for (auto [Name, Dot]: N->ModulePath) {
BOLT_VISIT(Name); BOLT_VISIT(Name);

View file

@ -244,7 +244,7 @@ namespace bolt {
void forwardDeclareFunctionDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints); void forwardDeclareFunctionDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints);
Type* inferExpression(Expression* Expression); Type* inferExpression(Expression* Expression);
Type* inferTypeExpression(TypeExpression* TE, bool IsPoly = true); Type* inferTypeExpression(TypeExpression* TE, bool AutoVars = true);
Type* inferLiteral(Literal* Lit); Type* inferLiteral(Literal* Lit);
Type* inferPattern(Pattern* Pattern, ConstraintSet* Constraints = new ConstraintSet, TVSet* TVs = new TVSet); Type* inferPattern(Pattern* Pattern, ConstraintSet* Constraints = new ConstraintSet, TVSet* TVs = new TVSet);

View file

@ -174,13 +174,25 @@ namespace bolt {
addSymbol(Y->Name->Text, Decl, SymbolKind::Var); addSymbol(Y->Name->Text, Decl, SymbolKind::Var);
break; break;
} }
case NodeKind::RecordPattern:
{
auto Y = static_cast<RecordPattern*>(X);
for (auto [Field, Comma]: Y->Fields) {
if (Field->Pattern) {
visitPattern(Field->Pattern, Decl);
} else if (Field->Name) {
addSymbol(Field->Name->Text, Decl, SymbolKind::Var);
}
}
break;
}
case NodeKind::NamedRecordPattern: case NodeKind::NamedRecordPattern:
{ {
auto Y = static_cast<NamedRecordPattern*>(X); auto Y = static_cast<NamedRecordPattern*>(X);
for (auto [Field, Comma]: Y->Fields) { for (auto [Field, Comma]: Y->Fields) {
if (Field->Pattern) { if (Field->Pattern) {
visitPattern(Field->Pattern, Decl); visitPattern(Field->Pattern, Decl);
} else { } else if (Field->Name) {
addSymbol(Field->Name->Text, Decl, SymbolKind::Var); addSymbol(Field->Name->Text, Decl, SymbolKind::Var);
} }
} }
@ -507,6 +519,14 @@ namespace bolt {
return Name; return Name;
} }
Token* RecordPattern::getFirstToken() const {
return LBrace;
}
Token* RecordPattern::getLastToken() const {
return RBrace;
}
Token* NamedRecordPattern::getFirstToken() const { Token* NamedRecordPattern::getFirstToken() const {
if (!ModulePath.empty()) { if (!ModulePath.empty()) {
return std::get<0>(ModulePath.back()); return std::get<0>(ModulePath.back());

View file

@ -266,14 +266,14 @@ namespace bolt {
case NodeKind::LetDeclaration: case NodeKind::LetDeclaration:
{ {
// Function declarations are handled separately in forwardDeclareLetDeclaration() // Function declarations are handled separately in forwardDeclareLetDeclaration() and inferExpression()
auto Decl = static_cast<LetDeclaration*>(X); auto Decl = static_cast<LetDeclaration*>(X);
if (!Decl->isVariable()) { if (!Decl->isVariable()) {
break; break;
} }
Type* Ty; Type* Ty;
if (Decl->TypeAssert) { if (Decl->TypeAssert) {
Ty = inferTypeExpression(Decl->TypeAssert->TypeExpression, false); Ty = inferTypeExpression(Decl->TypeAssert->TypeExpression);
} else { } else {
Ty = createTypeVar(); Ty = createTypeVar();
} }
@ -291,6 +291,7 @@ namespace bolt {
for (auto TE: Decl->TVs) { for (auto TE: Decl->TVs) {
auto TV = createRigidVar(TE->Name->getCanonicalText()); auto TV = createRigidVar(TE->Name->getCanonicalText());
Decl->Ctx->TVs->emplace(TV); Decl->Ctx->TVs->emplace(TV);
Decl->Ctx->Env.add(TE->Name->getCanonicalText(), new Forall(TV), SymKind::Type);
Vars.push_back(TV); Vars.push_back(TV);
} }
@ -313,7 +314,7 @@ namespace bolt {
std::vector<Type*> ParamTypes; std::vector<Type*> ParamTypes;
for (auto Element: TupleMember->Elements) { for (auto Element: TupleMember->Elements) {
// inferTypeExpression will look up any TVars that were part of the signature of Decl // inferTypeExpression will look up any TVars that were part of the signature of Decl
ParamTypes.push_back(inferTypeExpression(Element)); ParamTypes.push_back(inferTypeExpression(Element, false));
} }
Decl->Ctx->Parent->Env.add( Decl->Ctx->Parent->Env.add(
TupleMember->Name->getCanonicalText(), TupleMember->Name->getCanonicalText(),
@ -350,6 +351,8 @@ namespace bolt {
std::vector<Type*> Vars; std::vector<Type*> Vars;
for (auto TE: Decl->Vars) { for (auto TE: Decl->Vars) {
auto TV = createRigidVar(TE->Name->getCanonicalText()); auto TV = createRigidVar(TE->Name->getCanonicalText());
Decl->Ctx->TVs->emplace(TV);
Decl->Ctx->Env.add(TE->Name->getCanonicalText(), new Forall(TV), SymKind::Type);
Vars.push_back(TV); Vars.push_back(TV);
} }
@ -370,7 +373,7 @@ namespace bolt {
FieldsTy = new Type( FieldsTy = new Type(
TField( TField(
Field->Name->getCanonicalText(), Field->Name->getCanonicalText(),
new Type(TPresent(inferTypeExpression(Field->TypeExpression))), new Type(TPresent(inferTypeExpression(Field->TypeExpression, false))),
FieldsTy FieldsTy
) )
); );
@ -811,7 +814,7 @@ namespace bolt {
} }
} }
Type* Checker::inferTypeExpression(TypeExpression* N, bool IsPoly) { Type* Checker::inferTypeExpression(TypeExpression* N, bool AutoVars) {
switch (N->getKind()) { switch (N->getKind()) {
@ -833,9 +836,9 @@ namespace bolt {
case NodeKind::AppTypeExpression: case NodeKind::AppTypeExpression:
{ {
auto AppTE = static_cast<AppTypeExpression*>(N); auto AppTE = static_cast<AppTypeExpression*>(N);
Type* Ty = inferTypeExpression(AppTE->Op, IsPoly); Type* Ty = inferTypeExpression(AppTE->Op, AutoVars);
for (auto Arg: AppTE->Args) { for (auto Arg: AppTE->Args) {
Ty = new Type(TApp(Ty, inferTypeExpression(Arg, IsPoly))); Ty = new Type(TApp(Ty, inferTypeExpression(Arg, AutoVars)));
} }
N->setType(Ty); N->setType(Ty);
return Ty; return Ty;
@ -846,10 +849,10 @@ namespace bolt {
auto VarTE = static_cast<VarTypeExpression*>(N); auto VarTE = static_cast<VarTypeExpression*>(N);
auto Ty = lookupMono(VarTE->Name->getCanonicalText(), SymKind::Type); auto Ty = lookupMono(VarTE->Name->getCanonicalText(), SymKind::Type);
if (Ty == nullptr) { if (Ty == nullptr) {
if (IsPoly && Config.typeVarsRequireForall()) { if (!AutoVars || Config.typeVarsRequireForall()) {
DE.add<BindingNotFoundDiagnostic>(VarTE->Name->getCanonicalText(), VarTE->Name); DE.add<BindingNotFoundDiagnostic>(VarTE->Name->getCanonicalText(), VarTE->Name);
} }
Ty = IsPoly ? createRigidVar(VarTE->Name->getCanonicalText()) : createTypeVar(); Ty = createRigidVar(VarTE->Name->getCanonicalText());
addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty), SymKind::Type); addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty), SymKind::Type);
} }
ZEN_ASSERT(Ty->isVar()); ZEN_ASSERT(Ty->isVar());
@ -860,9 +863,9 @@ namespace bolt {
case NodeKind::RecordTypeExpression: case NodeKind::RecordTypeExpression:
{ {
auto RecTE = static_cast<RecordTypeExpression*>(N); auto RecTE = static_cast<RecordTypeExpression*>(N);
auto Ty = RecTE->Rest ? inferTypeExpression(RecTE->Rest, IsPoly) : new Type(TNil()); auto Ty = RecTE->Rest ? inferTypeExpression(RecTE->Rest, AutoVars) : new Type(TNil());
for (auto [Field, Comma]: RecTE->Fields) { for (auto [Field, Comma]: RecTE->Fields) {
Ty = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(inferTypeExpression(Field->TE, IsPoly))), Ty)); Ty = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(inferTypeExpression(Field->TE, AutoVars))), Ty));
} }
N->setType(Ty); N->setType(Ty);
return Ty; return Ty;
@ -873,7 +876,7 @@ namespace bolt {
auto TupleTE = static_cast<TupleTypeExpression*>(N); auto TupleTE = static_cast<TupleTypeExpression*>(N);
std::vector<Type*> ElementTypes; std::vector<Type*> ElementTypes;
for (auto [TE, Comma]: TupleTE->Elements) { for (auto [TE, Comma]: TupleTE->Elements) {
ElementTypes.push_back(inferTypeExpression(TE, IsPoly)); ElementTypes.push_back(inferTypeExpression(TE, AutoVars));
} }
auto Ty = new Type(TTuple(ElementTypes)); auto Ty = new Type(TTuple(ElementTypes));
N->setType(Ty); N->setType(Ty);
@ -883,7 +886,7 @@ namespace bolt {
case NodeKind::NestedTypeExpression: case NodeKind::NestedTypeExpression:
{ {
auto NestedTE = static_cast<NestedTypeExpression*>(N); auto NestedTE = static_cast<NestedTypeExpression*>(N);
auto Ty = inferTypeExpression(NestedTE->TE, IsPoly); auto Ty = inferTypeExpression(NestedTE->TE, AutoVars);
N->setType(Ty); N->setType(Ty);
return Ty; return Ty;
} }
@ -893,9 +896,9 @@ namespace bolt {
auto ArrowTE = static_cast<ArrowTypeExpression*>(N); auto ArrowTE = static_cast<ArrowTypeExpression*>(N);
std::vector<Type*> ParamTypes; std::vector<Type*> ParamTypes;
for (auto ParamType: ArrowTE->ParamTypes) { for (auto ParamType: ArrowTE->ParamTypes) {
ParamTypes.push_back(inferTypeExpression(ParamType, IsPoly)); ParamTypes.push_back(inferTypeExpression(ParamType, AutoVars));
} }
auto ReturnType = inferTypeExpression(ArrowTE->ReturnType, IsPoly); auto ReturnType = inferTypeExpression(ArrowTE->ReturnType, AutoVars);
auto Ty = Type::buildArrow(ParamTypes, ReturnType); auto Ty = Type::buildArrow(ParamTypes, ReturnType);
N->setType(Ty); N->setType(Ty);
return Ty; return Ty;
@ -907,7 +910,7 @@ namespace bolt {
for (auto [C, Comma]: QTE->Constraints) { for (auto [C, Comma]: QTE->Constraints) {
inferConstraintExpression(C); inferConstraintExpression(C);
} }
auto Ty = inferTypeExpression(QTE->TE, IsPoly); auto Ty = inferTypeExpression(QTE->TE, AutoVars);
N->setType(Ty); N->setType(Ty);
return Ty; return Ty;
} }
@ -1111,6 +1114,15 @@ namespace bolt {
return Ty; return Ty;
} }
RecordPatternField* getRestField(std::vector<std::tuple<RecordPatternField*, Comma*>> Fields) {
for (auto [Field, Comma]: Fields) {
if (Field->DotDot) {
return Field;
}
}
return nullptr;
}
Type* Checker::inferPattern( Type* Checker::inferPattern(
Pattern* Pattern, Pattern* Pattern,
ConstraintSet* Constraints, ConstraintSet* Constraints,
@ -1145,6 +1157,34 @@ namespace bolt {
return RetTy; return RetTy;
} }
case NodeKind::RecordPattern:
{
auto P = static_cast<RecordPattern*>(Pattern);
auto RestField = getRestField(P->Fields);
Type* RecordTy;
if (RestField == nullptr) {
RecordTy = new Type(TNil());
} else if (RestField->Pattern) {
RecordTy = inferPattern(RestField->Pattern);
} else {
RecordTy = createTypeVar();
}
for (auto [Field, Comma]: P->Fields) {
if (Field->DotDot) {
continue;
}
Type* FieldTy;
if (Field->Pattern) {
FieldTy = inferPattern(Field->Pattern, Constraints, TVs);
} else {
FieldTy = createTypeVar();
addBinding(Field->Name->getCanonicalText(), new Forall(TVs, Constraints, FieldTy), SymKind::Var);
}
RecordTy = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(FieldTy)), RecordTy));
}
return RecordTy;
}
case NodeKind::NamedRecordPattern: case NodeKind::NamedRecordPattern:
{ {
auto P = static_cast<NamedRecordPattern*>(Pattern); auto P = static_cast<NamedRecordPattern*>(Pattern);
@ -1153,8 +1193,19 @@ namespace bolt {
DE.add<BindingNotFoundDiagnostic>(P->Name->getCanonicalText(), P->Name); DE.add<BindingNotFoundDiagnostic>(P->Name->getCanonicalText(), P->Name);
return createTypeVar(); return createTypeVar();
} }
auto RecordTy = new Type(TNil()); auto RestField = getRestField(P->Fields);
Type* RecordTy;
if (RestField == nullptr) {
RecordTy = new Type(TNil());
} else if (RestField->Pattern) {
RecordTy = inferPattern(RestField->Pattern);
} else {
RecordTy = createTypeVar();
}
for (auto [Field, Comma]: P->Fields) { for (auto [Field, Comma]: P->Fields) {
if (Field->DotDot) {
continue;
}
Type* FieldTy; Type* FieldTy;
if (Field->Pattern) { if (Field->Pattern) {
FieldTy = inferPattern(Field->Pattern, Constraints, TVs); FieldTy = inferPattern(Field->Pattern, Constraints, TVs);

View file

@ -158,6 +158,23 @@ finish:
if (T0->getKind() == NodeKind::RBrace) { if (T0->getKind() == NodeKind::RBrace) {
break; break;
} }
if (T0->getKind() == NodeKind::DotDot) {
Tokens.get();
auto DotDot = static_cast<class DotDot*>(T0);
auto T1 = Tokens.peek();
if (T1->getKind() == NodeKind::RBrace) {
Fields.push_back(std::make_tuple(new RecordPatternField(DotDot), nullptr));
break;
}
auto P = parseWidePattern();
auto T2 = Tokens.peek();
if (T2->getKind() != NodeKind::RBrace) {
DE.add<UnexpectedTokenDiagnostic>(File, T2, std::vector { NodeKind::RBrace, NodeKind::Comma });
return {};
}
Fields.push_back(std::make_tuple(new RecordPatternField(DotDot, P), nullptr));
break;
}
auto Name = expectToken<Identifier>(); auto Name = expectToken<Identifier>();
Equals* Equals = nullptr; Equals* Equals = nullptr;
Pattern* Pattern = nullptr; Pattern* Pattern = nullptr;
@ -194,6 +211,19 @@ finish:
case NodeKind::Identifier: case NodeKind::Identifier:
Tokens.get(); Tokens.get();
return new BindPattern(static_cast<Identifier*>(T0)); return new BindPattern(static_cast<Identifier*>(T0));
case NodeKind::LBrace:
{
Tokens.get();
auto LBrace = static_cast<class LBrace*>(T0);
auto Fields = parseRecordPatternFields();
if (!Fields) {
LBrace->unref();
skipToRBrace();
return nullptr;
}
auto RBrace = static_cast<class RBrace*>(Tokens.get());
return new RecordPattern(LBrace, *Fields, RBrace);
}
case NodeKind::IdentifierAlt: case NodeKind::IdentifierAlt:
{ {
Tokens.get(); Tokens.get();
@ -207,6 +237,7 @@ finish:
Tokens.get(); Tokens.get();
auto Fields = parseRecordPatternFields(); auto Fields = parseRecordPatternFields();
if (!Fields) { if (!Fields) {
LBrace->unref();
skipToRBrace(); skipToRBrace();
return nullptr; return nullptr;
} }