Add support for record patterns and improve implicit forall

This commit is contained in:
Sam Vervaeck 2024-01-22 01:11:06 +01:00
parent 34b7693229
commit de29a77cd3
Signed by: samvv
SSH key fingerprint: SHA256:dIg0ywU1OP+ZYifrYxy8c5esO72cIKB+4/9wkZj1VaY
6 changed files with 185 additions and 21 deletions

View file

@ -145,6 +145,7 @@ namespace bolt {
BindPattern,
LiteralPattern,
RecordPatternField,
RecordPattern,
NamedRecordPattern,
NamedTuplePattern,
TuplePattern,
@ -1383,22 +1384,61 @@ namespace bolt {
class RecordPatternField : public Node {
public:
DotDot* DotDot;
Identifier* Name;
Equals* Equals;
Pattern* Pattern;
inline RecordPatternField(
class DotDot* DotDot,
Identifier* Name,
class Equals* Equals,
class Pattern* Pattern
): Node(NodeKind::RecordPatternField),
DotDot(DotDot),
Name(Name),
Equals(Equals),
Pattern(Pattern) {}
inline RecordPatternField(
Identifier* Name,
class Equals* Equals,
class Pattern* Pattern
): RecordPatternField(nullptr, Name, Equals, Pattern) {}
inline RecordPatternField(
class DotDot* DotDot
): RecordPatternField(DotDot, nullptr, nullptr, nullptr) {}
inline RecordPatternField(
class DotDot* DotDot,
class Pattern* Pattern
): RecordPatternField(DotDot, nullptr, nullptr, Pattern) {}
inline RecordPatternField(
Identifier* Name
): RecordPatternField(Name, nullptr, nullptr) {}
): RecordPatternField(nullptr, Name, nullptr, nullptr) {}
Token* getFirstToken() const override;
Token* getLastToken() const override;
};
class RecordPattern : public Pattern {
public:
LBrace* LBrace;
std::vector<std::tuple<RecordPatternField*, Comma*>> Fields;
RBrace* RBrace;
inline RecordPattern(
class LBrace* LBrace,
std::vector<std::tuple<RecordPatternField*, Comma*>> Fields,
class RBrace* RBrace
): Pattern(NodeKind::RecordPattern),
LBrace(LBrace),
Fields(Fields),
RBrace(RBrace) {}
Token* getFirstToken() const override;
Token* getLastToken() const override;

View file

@ -77,6 +77,7 @@ namespace bolt {
BOLT_GEN_CASE(BindPattern)
BOLT_GEN_CASE(LiteralPattern)
BOLT_GEN_CASE(RecordPatternField)
BOLT_GEN_CASE(RecordPattern)
BOLT_GEN_CASE(NamedRecordPattern)
BOLT_GEN_CASE(NamedTuplePattern)
BOLT_GEN_CASE(TuplePattern)
@ -372,6 +373,10 @@ namespace bolt {
static_cast<D*>(this)->visitNode(N);
}
void visitRecordPattern(RecordPattern* N) {
static_cast<D*>(this)->visitPattern(N);
}
void visitNamedRecordPattern(NamedRecordPattern* N) {
static_cast<D*>(this)->visitPattern(N);
}
@ -592,6 +597,7 @@ namespace bolt {
BOLT_GEN_CHILD_CASE(BindPattern)
BOLT_GEN_CHILD_CASE(LiteralPattern)
BOLT_GEN_CHILD_CASE(RecordPatternField)
BOLT_GEN_CHILD_CASE(RecordPattern)
BOLT_GEN_CHILD_CASE(NamedRecordPattern)
BOLT_GEN_CHILD_CASE(NamedTuplePattern)
BOLT_GEN_CHILD_CASE(TuplePattern)
@ -867,7 +873,12 @@ namespace bolt {
}
void visitEachChild(RecordPatternField* N) {
BOLT_VISIT(N->Name);
if (N->DotDot) {
BOLT_VISIT(N->DotDot);
}
if (N->Name) {
BOLT_VISIT(N->Name);
}
if (N->Equals) {
BOLT_VISIT(N->Equals);
}
@ -876,6 +887,17 @@ namespace bolt {
}
}
void visitEachChild(RecordPattern* N) {
BOLT_VISIT(N->LBrace);
for (auto [Field, Comma]: N->Fields) {
BOLT_VISIT(Field);
if (Comma) {
BOLT_VISIT(Comma);
}
}
BOLT_VISIT(N->RBrace);
}
void visitEachChild(NamedRecordPattern* N) {
for (auto [Name, Dot]: N->ModulePath) {
BOLT_VISIT(Name);

View file

@ -244,7 +244,7 @@ namespace bolt {
void forwardDeclareFunctionDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints);
Type* inferExpression(Expression* Expression);
Type* inferTypeExpression(TypeExpression* TE, bool IsPoly = true);
Type* inferTypeExpression(TypeExpression* TE, bool AutoVars = true);
Type* inferLiteral(Literal* Lit);
Type* inferPattern(Pattern* Pattern, ConstraintSet* Constraints = new ConstraintSet, TVSet* TVs = new TVSet);

View file

@ -174,13 +174,25 @@ namespace bolt {
addSymbol(Y->Name->Text, Decl, SymbolKind::Var);
break;
}
case NodeKind::RecordPattern:
{
auto Y = static_cast<RecordPattern*>(X);
for (auto [Field, Comma]: Y->Fields) {
if (Field->Pattern) {
visitPattern(Field->Pattern, Decl);
} else if (Field->Name) {
addSymbol(Field->Name->Text, Decl, SymbolKind::Var);
}
}
break;
}
case NodeKind::NamedRecordPattern:
{
auto Y = static_cast<NamedRecordPattern*>(X);
for (auto [Field, Comma]: Y->Fields) {
if (Field->Pattern) {
visitPattern(Field->Pattern, Decl);
} else {
} else if (Field->Name) {
addSymbol(Field->Name->Text, Decl, SymbolKind::Var);
}
}
@ -507,6 +519,14 @@ namespace bolt {
return Name;
}
Token* RecordPattern::getFirstToken() const {
return LBrace;
}
Token* RecordPattern::getLastToken() const {
return RBrace;
}
Token* NamedRecordPattern::getFirstToken() const {
if (!ModulePath.empty()) {
return std::get<0>(ModulePath.back());

View file

@ -266,14 +266,14 @@ namespace bolt {
case NodeKind::LetDeclaration:
{
// Function declarations are handled separately in forwardDeclareLetDeclaration()
// Function declarations are handled separately in forwardDeclareLetDeclaration() and inferExpression()
auto Decl = static_cast<LetDeclaration*>(X);
if (!Decl->isVariable()) {
break;
}
Type* Ty;
if (Decl->TypeAssert) {
Ty = inferTypeExpression(Decl->TypeAssert->TypeExpression, false);
Ty = inferTypeExpression(Decl->TypeAssert->TypeExpression);
} else {
Ty = createTypeVar();
}
@ -291,6 +291,7 @@ namespace bolt {
for (auto TE: Decl->TVs) {
auto TV = createRigidVar(TE->Name->getCanonicalText());
Decl->Ctx->TVs->emplace(TV);
Decl->Ctx->Env.add(TE->Name->getCanonicalText(), new Forall(TV), SymKind::Type);
Vars.push_back(TV);
}
@ -313,7 +314,7 @@ namespace bolt {
std::vector<Type*> ParamTypes;
for (auto Element: TupleMember->Elements) {
// inferTypeExpression will look up any TVars that were part of the signature of Decl
ParamTypes.push_back(inferTypeExpression(Element));
ParamTypes.push_back(inferTypeExpression(Element, false));
}
Decl->Ctx->Parent->Env.add(
TupleMember->Name->getCanonicalText(),
@ -350,6 +351,8 @@ namespace bolt {
std::vector<Type*> Vars;
for (auto TE: Decl->Vars) {
auto TV = createRigidVar(TE->Name->getCanonicalText());
Decl->Ctx->TVs->emplace(TV);
Decl->Ctx->Env.add(TE->Name->getCanonicalText(), new Forall(TV), SymKind::Type);
Vars.push_back(TV);
}
@ -370,7 +373,7 @@ namespace bolt {
FieldsTy = new Type(
TField(
Field->Name->getCanonicalText(),
new Type(TPresent(inferTypeExpression(Field->TypeExpression))),
new Type(TPresent(inferTypeExpression(Field->TypeExpression, false))),
FieldsTy
)
);
@ -811,7 +814,7 @@ namespace bolt {
}
}
Type* Checker::inferTypeExpression(TypeExpression* N, bool IsPoly) {
Type* Checker::inferTypeExpression(TypeExpression* N, bool AutoVars) {
switch (N->getKind()) {
@ -833,9 +836,9 @@ namespace bolt {
case NodeKind::AppTypeExpression:
{
auto AppTE = static_cast<AppTypeExpression*>(N);
Type* Ty = inferTypeExpression(AppTE->Op, IsPoly);
Type* Ty = inferTypeExpression(AppTE->Op, AutoVars);
for (auto Arg: AppTE->Args) {
Ty = new Type(TApp(Ty, inferTypeExpression(Arg, IsPoly)));
Ty = new Type(TApp(Ty, inferTypeExpression(Arg, AutoVars)));
}
N->setType(Ty);
return Ty;
@ -846,10 +849,10 @@ namespace bolt {
auto VarTE = static_cast<VarTypeExpression*>(N);
auto Ty = lookupMono(VarTE->Name->getCanonicalText(), SymKind::Type);
if (Ty == nullptr) {
if (IsPoly && Config.typeVarsRequireForall()) {
if (!AutoVars || Config.typeVarsRequireForall()) {
DE.add<BindingNotFoundDiagnostic>(VarTE->Name->getCanonicalText(), VarTE->Name);
}
Ty = IsPoly ? createRigidVar(VarTE->Name->getCanonicalText()) : createTypeVar();
Ty = createRigidVar(VarTE->Name->getCanonicalText());
addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty), SymKind::Type);
}
ZEN_ASSERT(Ty->isVar());
@ -860,9 +863,9 @@ namespace bolt {
case NodeKind::RecordTypeExpression:
{
auto RecTE = static_cast<RecordTypeExpression*>(N);
auto Ty = RecTE->Rest ? inferTypeExpression(RecTE->Rest, IsPoly) : new Type(TNil());
auto Ty = RecTE->Rest ? inferTypeExpression(RecTE->Rest, AutoVars) : new Type(TNil());
for (auto [Field, Comma]: RecTE->Fields) {
Ty = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(inferTypeExpression(Field->TE, IsPoly))), Ty));
Ty = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(inferTypeExpression(Field->TE, AutoVars))), Ty));
}
N->setType(Ty);
return Ty;
@ -873,7 +876,7 @@ namespace bolt {
auto TupleTE = static_cast<TupleTypeExpression*>(N);
std::vector<Type*> ElementTypes;
for (auto [TE, Comma]: TupleTE->Elements) {
ElementTypes.push_back(inferTypeExpression(TE, IsPoly));
ElementTypes.push_back(inferTypeExpression(TE, AutoVars));
}
auto Ty = new Type(TTuple(ElementTypes));
N->setType(Ty);
@ -883,7 +886,7 @@ namespace bolt {
case NodeKind::NestedTypeExpression:
{
auto NestedTE = static_cast<NestedTypeExpression*>(N);
auto Ty = inferTypeExpression(NestedTE->TE, IsPoly);
auto Ty = inferTypeExpression(NestedTE->TE, AutoVars);
N->setType(Ty);
return Ty;
}
@ -893,9 +896,9 @@ namespace bolt {
auto ArrowTE = static_cast<ArrowTypeExpression*>(N);
std::vector<Type*> ParamTypes;
for (auto ParamType: ArrowTE->ParamTypes) {
ParamTypes.push_back(inferTypeExpression(ParamType, IsPoly));
ParamTypes.push_back(inferTypeExpression(ParamType, AutoVars));
}
auto ReturnType = inferTypeExpression(ArrowTE->ReturnType, IsPoly);
auto ReturnType = inferTypeExpression(ArrowTE->ReturnType, AutoVars);
auto Ty = Type::buildArrow(ParamTypes, ReturnType);
N->setType(Ty);
return Ty;
@ -907,7 +910,7 @@ namespace bolt {
for (auto [C, Comma]: QTE->Constraints) {
inferConstraintExpression(C);
}
auto Ty = inferTypeExpression(QTE->TE, IsPoly);
auto Ty = inferTypeExpression(QTE->TE, AutoVars);
N->setType(Ty);
return Ty;
}
@ -1111,6 +1114,15 @@ namespace bolt {
return Ty;
}
RecordPatternField* getRestField(std::vector<std::tuple<RecordPatternField*, Comma*>> Fields) {
for (auto [Field, Comma]: Fields) {
if (Field->DotDot) {
return Field;
}
}
return nullptr;
}
Type* Checker::inferPattern(
Pattern* Pattern,
ConstraintSet* Constraints,
@ -1145,6 +1157,34 @@ namespace bolt {
return RetTy;
}
case NodeKind::RecordPattern:
{
auto P = static_cast<RecordPattern*>(Pattern);
auto RestField = getRestField(P->Fields);
Type* RecordTy;
if (RestField == nullptr) {
RecordTy = new Type(TNil());
} else if (RestField->Pattern) {
RecordTy = inferPattern(RestField->Pattern);
} else {
RecordTy = createTypeVar();
}
for (auto [Field, Comma]: P->Fields) {
if (Field->DotDot) {
continue;
}
Type* FieldTy;
if (Field->Pattern) {
FieldTy = inferPattern(Field->Pattern, Constraints, TVs);
} else {
FieldTy = createTypeVar();
addBinding(Field->Name->getCanonicalText(), new Forall(TVs, Constraints, FieldTy), SymKind::Var);
}
RecordTy = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(FieldTy)), RecordTy));
}
return RecordTy;
}
case NodeKind::NamedRecordPattern:
{
auto P = static_cast<NamedRecordPattern*>(Pattern);
@ -1153,8 +1193,19 @@ namespace bolt {
DE.add<BindingNotFoundDiagnostic>(P->Name->getCanonicalText(), P->Name);
return createTypeVar();
}
auto RecordTy = new Type(TNil());
auto RestField = getRestField(P->Fields);
Type* RecordTy;
if (RestField == nullptr) {
RecordTy = new Type(TNil());
} else if (RestField->Pattern) {
RecordTy = inferPattern(RestField->Pattern);
} else {
RecordTy = createTypeVar();
}
for (auto [Field, Comma]: P->Fields) {
if (Field->DotDot) {
continue;
}
Type* FieldTy;
if (Field->Pattern) {
FieldTy = inferPattern(Field->Pattern, Constraints, TVs);

View file

@ -158,6 +158,23 @@ finish:
if (T0->getKind() == NodeKind::RBrace) {
break;
}
if (T0->getKind() == NodeKind::DotDot) {
Tokens.get();
auto DotDot = static_cast<class DotDot*>(T0);
auto T1 = Tokens.peek();
if (T1->getKind() == NodeKind::RBrace) {
Fields.push_back(std::make_tuple(new RecordPatternField(DotDot), nullptr));
break;
}
auto P = parseWidePattern();
auto T2 = Tokens.peek();
if (T2->getKind() != NodeKind::RBrace) {
DE.add<UnexpectedTokenDiagnostic>(File, T2, std::vector { NodeKind::RBrace, NodeKind::Comma });
return {};
}
Fields.push_back(std::make_tuple(new RecordPatternField(DotDot, P), nullptr));
break;
}
auto Name = expectToken<Identifier>();
Equals* Equals = nullptr;
Pattern* Pattern = nullptr;
@ -194,6 +211,19 @@ finish:
case NodeKind::Identifier:
Tokens.get();
return new BindPattern(static_cast<Identifier*>(T0));
case NodeKind::LBrace:
{
Tokens.get();
auto LBrace = static_cast<class LBrace*>(T0);
auto Fields = parseRecordPatternFields();
if (!Fields) {
LBrace->unref();
skipToRBrace();
return nullptr;
}
auto RBrace = static_cast<class RBrace*>(Tokens.get());
return new RecordPattern(LBrace, *Fields, RBrace);
}
case NodeKind::IdentifierAlt:
{
Tokens.get();
@ -207,6 +237,7 @@ finish:
Tokens.get();
auto Fields = parseRecordPatternFields();
if (!Fields) {
LBrace->unref();
skipToRBrace();
return nullptr;
}