Add support for record patterns and improve implicit forall
This commit is contained in:
parent
34b7693229
commit
de29a77cd3
6 changed files with 185 additions and 21 deletions
|
@ -145,6 +145,7 @@ namespace bolt {
|
|||
BindPattern,
|
||||
LiteralPattern,
|
||||
RecordPatternField,
|
||||
RecordPattern,
|
||||
NamedRecordPattern,
|
||||
NamedTuplePattern,
|
||||
TuplePattern,
|
||||
|
@ -1383,22 +1384,61 @@ namespace bolt {
|
|||
class RecordPatternField : public Node {
|
||||
public:
|
||||
|
||||
DotDot* DotDot;
|
||||
Identifier* Name;
|
||||
Equals* Equals;
|
||||
Pattern* Pattern;
|
||||
|
||||
inline RecordPatternField(
|
||||
class DotDot* DotDot,
|
||||
Identifier* Name,
|
||||
class Equals* Equals,
|
||||
class Pattern* Pattern
|
||||
): Node(NodeKind::RecordPatternField),
|
||||
DotDot(DotDot),
|
||||
Name(Name),
|
||||
Equals(Equals),
|
||||
Pattern(Pattern) {}
|
||||
|
||||
inline RecordPatternField(
|
||||
Identifier* Name,
|
||||
class Equals* Equals,
|
||||
class Pattern* Pattern
|
||||
): RecordPatternField(nullptr, Name, Equals, Pattern) {}
|
||||
|
||||
inline RecordPatternField(
|
||||
class DotDot* DotDot
|
||||
): RecordPatternField(DotDot, nullptr, nullptr, nullptr) {}
|
||||
|
||||
inline RecordPatternField(
|
||||
class DotDot* DotDot,
|
||||
class Pattern* Pattern
|
||||
): RecordPatternField(DotDot, nullptr, nullptr, Pattern) {}
|
||||
|
||||
inline RecordPatternField(
|
||||
Identifier* Name
|
||||
): RecordPatternField(Name, nullptr, nullptr) {}
|
||||
): RecordPatternField(nullptr, Name, nullptr, nullptr) {}
|
||||
|
||||
Token* getFirstToken() const override;
|
||||
Token* getLastToken() const override;
|
||||
|
||||
};
|
||||
|
||||
class RecordPattern : public Pattern {
|
||||
public:
|
||||
|
||||
LBrace* LBrace;
|
||||
std::vector<std::tuple<RecordPatternField*, Comma*>> Fields;
|
||||
RBrace* RBrace;
|
||||
|
||||
inline RecordPattern(
|
||||
class LBrace* LBrace,
|
||||
std::vector<std::tuple<RecordPatternField*, Comma*>> Fields,
|
||||
class RBrace* RBrace
|
||||
): Pattern(NodeKind::RecordPattern),
|
||||
LBrace(LBrace),
|
||||
Fields(Fields),
|
||||
RBrace(RBrace) {}
|
||||
|
||||
Token* getFirstToken() const override;
|
||||
Token* getLastToken() const override;
|
||||
|
|
|
@ -77,6 +77,7 @@ namespace bolt {
|
|||
BOLT_GEN_CASE(BindPattern)
|
||||
BOLT_GEN_CASE(LiteralPattern)
|
||||
BOLT_GEN_CASE(RecordPatternField)
|
||||
BOLT_GEN_CASE(RecordPattern)
|
||||
BOLT_GEN_CASE(NamedRecordPattern)
|
||||
BOLT_GEN_CASE(NamedTuplePattern)
|
||||
BOLT_GEN_CASE(TuplePattern)
|
||||
|
@ -372,6 +373,10 @@ namespace bolt {
|
|||
static_cast<D*>(this)->visitNode(N);
|
||||
}
|
||||
|
||||
void visitRecordPattern(RecordPattern* N) {
|
||||
static_cast<D*>(this)->visitPattern(N);
|
||||
}
|
||||
|
||||
void visitNamedRecordPattern(NamedRecordPattern* N) {
|
||||
static_cast<D*>(this)->visitPattern(N);
|
||||
}
|
||||
|
@ -592,6 +597,7 @@ namespace bolt {
|
|||
BOLT_GEN_CHILD_CASE(BindPattern)
|
||||
BOLT_GEN_CHILD_CASE(LiteralPattern)
|
||||
BOLT_GEN_CHILD_CASE(RecordPatternField)
|
||||
BOLT_GEN_CHILD_CASE(RecordPattern)
|
||||
BOLT_GEN_CHILD_CASE(NamedRecordPattern)
|
||||
BOLT_GEN_CHILD_CASE(NamedTuplePattern)
|
||||
BOLT_GEN_CHILD_CASE(TuplePattern)
|
||||
|
@ -867,7 +873,12 @@ namespace bolt {
|
|||
}
|
||||
|
||||
void visitEachChild(RecordPatternField* N) {
|
||||
BOLT_VISIT(N->Name);
|
||||
if (N->DotDot) {
|
||||
BOLT_VISIT(N->DotDot);
|
||||
}
|
||||
if (N->Name) {
|
||||
BOLT_VISIT(N->Name);
|
||||
}
|
||||
if (N->Equals) {
|
||||
BOLT_VISIT(N->Equals);
|
||||
}
|
||||
|
@ -876,6 +887,17 @@ namespace bolt {
|
|||
}
|
||||
}
|
||||
|
||||
void visitEachChild(RecordPattern* N) {
|
||||
BOLT_VISIT(N->LBrace);
|
||||
for (auto [Field, Comma]: N->Fields) {
|
||||
BOLT_VISIT(Field);
|
||||
if (Comma) {
|
||||
BOLT_VISIT(Comma);
|
||||
}
|
||||
}
|
||||
BOLT_VISIT(N->RBrace);
|
||||
}
|
||||
|
||||
void visitEachChild(NamedRecordPattern* N) {
|
||||
for (auto [Name, Dot]: N->ModulePath) {
|
||||
BOLT_VISIT(Name);
|
||||
|
|
|
@ -244,7 +244,7 @@ namespace bolt {
|
|||
void forwardDeclareFunctionDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints);
|
||||
|
||||
Type* inferExpression(Expression* Expression);
|
||||
Type* inferTypeExpression(TypeExpression* TE, bool IsPoly = true);
|
||||
Type* inferTypeExpression(TypeExpression* TE, bool AutoVars = true);
|
||||
Type* inferLiteral(Literal* Lit);
|
||||
Type* inferPattern(Pattern* Pattern, ConstraintSet* Constraints = new ConstraintSet, TVSet* TVs = new TVSet);
|
||||
|
||||
|
|
|
@ -174,13 +174,25 @@ namespace bolt {
|
|||
addSymbol(Y->Name->Text, Decl, SymbolKind::Var);
|
||||
break;
|
||||
}
|
||||
case NodeKind::RecordPattern:
|
||||
{
|
||||
auto Y = static_cast<RecordPattern*>(X);
|
||||
for (auto [Field, Comma]: Y->Fields) {
|
||||
if (Field->Pattern) {
|
||||
visitPattern(Field->Pattern, Decl);
|
||||
} else if (Field->Name) {
|
||||
addSymbol(Field->Name->Text, Decl, SymbolKind::Var);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
case NodeKind::NamedRecordPattern:
|
||||
{
|
||||
auto Y = static_cast<NamedRecordPattern*>(X);
|
||||
for (auto [Field, Comma]: Y->Fields) {
|
||||
if (Field->Pattern) {
|
||||
visitPattern(Field->Pattern, Decl);
|
||||
} else {
|
||||
} else if (Field->Name) {
|
||||
addSymbol(Field->Name->Text, Decl, SymbolKind::Var);
|
||||
}
|
||||
}
|
||||
|
@ -507,6 +519,14 @@ namespace bolt {
|
|||
return Name;
|
||||
}
|
||||
|
||||
Token* RecordPattern::getFirstToken() const {
|
||||
return LBrace;
|
||||
}
|
||||
|
||||
Token* RecordPattern::getLastToken() const {
|
||||
return RBrace;
|
||||
}
|
||||
|
||||
Token* NamedRecordPattern::getFirstToken() const {
|
||||
if (!ModulePath.empty()) {
|
||||
return std::get<0>(ModulePath.back());
|
||||
|
|
|
@ -266,14 +266,14 @@ namespace bolt {
|
|||
|
||||
case NodeKind::LetDeclaration:
|
||||
{
|
||||
// Function declarations are handled separately in forwardDeclareLetDeclaration()
|
||||
// Function declarations are handled separately in forwardDeclareLetDeclaration() and inferExpression()
|
||||
auto Decl = static_cast<LetDeclaration*>(X);
|
||||
if (!Decl->isVariable()) {
|
||||
break;
|
||||
}
|
||||
Type* Ty;
|
||||
if (Decl->TypeAssert) {
|
||||
Ty = inferTypeExpression(Decl->TypeAssert->TypeExpression, false);
|
||||
Ty = inferTypeExpression(Decl->TypeAssert->TypeExpression);
|
||||
} else {
|
||||
Ty = createTypeVar();
|
||||
}
|
||||
|
@ -291,6 +291,7 @@ namespace bolt {
|
|||
for (auto TE: Decl->TVs) {
|
||||
auto TV = createRigidVar(TE->Name->getCanonicalText());
|
||||
Decl->Ctx->TVs->emplace(TV);
|
||||
Decl->Ctx->Env.add(TE->Name->getCanonicalText(), new Forall(TV), SymKind::Type);
|
||||
Vars.push_back(TV);
|
||||
}
|
||||
|
||||
|
@ -313,7 +314,7 @@ namespace bolt {
|
|||
std::vector<Type*> ParamTypes;
|
||||
for (auto Element: TupleMember->Elements) {
|
||||
// inferTypeExpression will look up any TVars that were part of the signature of Decl
|
||||
ParamTypes.push_back(inferTypeExpression(Element));
|
||||
ParamTypes.push_back(inferTypeExpression(Element, false));
|
||||
}
|
||||
Decl->Ctx->Parent->Env.add(
|
||||
TupleMember->Name->getCanonicalText(),
|
||||
|
@ -350,6 +351,8 @@ namespace bolt {
|
|||
std::vector<Type*> Vars;
|
||||
for (auto TE: Decl->Vars) {
|
||||
auto TV = createRigidVar(TE->Name->getCanonicalText());
|
||||
Decl->Ctx->TVs->emplace(TV);
|
||||
Decl->Ctx->Env.add(TE->Name->getCanonicalText(), new Forall(TV), SymKind::Type);
|
||||
Vars.push_back(TV);
|
||||
}
|
||||
|
||||
|
@ -370,7 +373,7 @@ namespace bolt {
|
|||
FieldsTy = new Type(
|
||||
TField(
|
||||
Field->Name->getCanonicalText(),
|
||||
new Type(TPresent(inferTypeExpression(Field->TypeExpression))),
|
||||
new Type(TPresent(inferTypeExpression(Field->TypeExpression, false))),
|
||||
FieldsTy
|
||||
)
|
||||
);
|
||||
|
@ -811,7 +814,7 @@ namespace bolt {
|
|||
}
|
||||
}
|
||||
|
||||
Type* Checker::inferTypeExpression(TypeExpression* N, bool IsPoly) {
|
||||
Type* Checker::inferTypeExpression(TypeExpression* N, bool AutoVars) {
|
||||
|
||||
switch (N->getKind()) {
|
||||
|
||||
|
@ -833,9 +836,9 @@ namespace bolt {
|
|||
case NodeKind::AppTypeExpression:
|
||||
{
|
||||
auto AppTE = static_cast<AppTypeExpression*>(N);
|
||||
Type* Ty = inferTypeExpression(AppTE->Op, IsPoly);
|
||||
Type* Ty = inferTypeExpression(AppTE->Op, AutoVars);
|
||||
for (auto Arg: AppTE->Args) {
|
||||
Ty = new Type(TApp(Ty, inferTypeExpression(Arg, IsPoly)));
|
||||
Ty = new Type(TApp(Ty, inferTypeExpression(Arg, AutoVars)));
|
||||
}
|
||||
N->setType(Ty);
|
||||
return Ty;
|
||||
|
@ -846,10 +849,10 @@ namespace bolt {
|
|||
auto VarTE = static_cast<VarTypeExpression*>(N);
|
||||
auto Ty = lookupMono(VarTE->Name->getCanonicalText(), SymKind::Type);
|
||||
if (Ty == nullptr) {
|
||||
if (IsPoly && Config.typeVarsRequireForall()) {
|
||||
if (!AutoVars || Config.typeVarsRequireForall()) {
|
||||
DE.add<BindingNotFoundDiagnostic>(VarTE->Name->getCanonicalText(), VarTE->Name);
|
||||
}
|
||||
Ty = IsPoly ? createRigidVar(VarTE->Name->getCanonicalText()) : createTypeVar();
|
||||
Ty = createRigidVar(VarTE->Name->getCanonicalText());
|
||||
addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty), SymKind::Type);
|
||||
}
|
||||
ZEN_ASSERT(Ty->isVar());
|
||||
|
@ -860,9 +863,9 @@ namespace bolt {
|
|||
case NodeKind::RecordTypeExpression:
|
||||
{
|
||||
auto RecTE = static_cast<RecordTypeExpression*>(N);
|
||||
auto Ty = RecTE->Rest ? inferTypeExpression(RecTE->Rest, IsPoly) : new Type(TNil());
|
||||
auto Ty = RecTE->Rest ? inferTypeExpression(RecTE->Rest, AutoVars) : new Type(TNil());
|
||||
for (auto [Field, Comma]: RecTE->Fields) {
|
||||
Ty = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(inferTypeExpression(Field->TE, IsPoly))), Ty));
|
||||
Ty = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(inferTypeExpression(Field->TE, AutoVars))), Ty));
|
||||
}
|
||||
N->setType(Ty);
|
||||
return Ty;
|
||||
|
@ -873,7 +876,7 @@ namespace bolt {
|
|||
auto TupleTE = static_cast<TupleTypeExpression*>(N);
|
||||
std::vector<Type*> ElementTypes;
|
||||
for (auto [TE, Comma]: TupleTE->Elements) {
|
||||
ElementTypes.push_back(inferTypeExpression(TE, IsPoly));
|
||||
ElementTypes.push_back(inferTypeExpression(TE, AutoVars));
|
||||
}
|
||||
auto Ty = new Type(TTuple(ElementTypes));
|
||||
N->setType(Ty);
|
||||
|
@ -883,7 +886,7 @@ namespace bolt {
|
|||
case NodeKind::NestedTypeExpression:
|
||||
{
|
||||
auto NestedTE = static_cast<NestedTypeExpression*>(N);
|
||||
auto Ty = inferTypeExpression(NestedTE->TE, IsPoly);
|
||||
auto Ty = inferTypeExpression(NestedTE->TE, AutoVars);
|
||||
N->setType(Ty);
|
||||
return Ty;
|
||||
}
|
||||
|
@ -893,9 +896,9 @@ namespace bolt {
|
|||
auto ArrowTE = static_cast<ArrowTypeExpression*>(N);
|
||||
std::vector<Type*> ParamTypes;
|
||||
for (auto ParamType: ArrowTE->ParamTypes) {
|
||||
ParamTypes.push_back(inferTypeExpression(ParamType, IsPoly));
|
||||
ParamTypes.push_back(inferTypeExpression(ParamType, AutoVars));
|
||||
}
|
||||
auto ReturnType = inferTypeExpression(ArrowTE->ReturnType, IsPoly);
|
||||
auto ReturnType = inferTypeExpression(ArrowTE->ReturnType, AutoVars);
|
||||
auto Ty = Type::buildArrow(ParamTypes, ReturnType);
|
||||
N->setType(Ty);
|
||||
return Ty;
|
||||
|
@ -907,7 +910,7 @@ namespace bolt {
|
|||
for (auto [C, Comma]: QTE->Constraints) {
|
||||
inferConstraintExpression(C);
|
||||
}
|
||||
auto Ty = inferTypeExpression(QTE->TE, IsPoly);
|
||||
auto Ty = inferTypeExpression(QTE->TE, AutoVars);
|
||||
N->setType(Ty);
|
||||
return Ty;
|
||||
}
|
||||
|
@ -1111,6 +1114,15 @@ namespace bolt {
|
|||
return Ty;
|
||||
}
|
||||
|
||||
RecordPatternField* getRestField(std::vector<std::tuple<RecordPatternField*, Comma*>> Fields) {
|
||||
for (auto [Field, Comma]: Fields) {
|
||||
if (Field->DotDot) {
|
||||
return Field;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Type* Checker::inferPattern(
|
||||
Pattern* Pattern,
|
||||
ConstraintSet* Constraints,
|
||||
|
@ -1145,6 +1157,34 @@ namespace bolt {
|
|||
return RetTy;
|
||||
}
|
||||
|
||||
case NodeKind::RecordPattern:
|
||||
{
|
||||
auto P = static_cast<RecordPattern*>(Pattern);
|
||||
auto RestField = getRestField(P->Fields);
|
||||
Type* RecordTy;
|
||||
if (RestField == nullptr) {
|
||||
RecordTy = new Type(TNil());
|
||||
} else if (RestField->Pattern) {
|
||||
RecordTy = inferPattern(RestField->Pattern);
|
||||
} else {
|
||||
RecordTy = createTypeVar();
|
||||
}
|
||||
for (auto [Field, Comma]: P->Fields) {
|
||||
if (Field->DotDot) {
|
||||
continue;
|
||||
}
|
||||
Type* FieldTy;
|
||||
if (Field->Pattern) {
|
||||
FieldTy = inferPattern(Field->Pattern, Constraints, TVs);
|
||||
} else {
|
||||
FieldTy = createTypeVar();
|
||||
addBinding(Field->Name->getCanonicalText(), new Forall(TVs, Constraints, FieldTy), SymKind::Var);
|
||||
}
|
||||
RecordTy = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(FieldTy)), RecordTy));
|
||||
}
|
||||
return RecordTy;
|
||||
}
|
||||
|
||||
case NodeKind::NamedRecordPattern:
|
||||
{
|
||||
auto P = static_cast<NamedRecordPattern*>(Pattern);
|
||||
|
@ -1153,8 +1193,19 @@ namespace bolt {
|
|||
DE.add<BindingNotFoundDiagnostic>(P->Name->getCanonicalText(), P->Name);
|
||||
return createTypeVar();
|
||||
}
|
||||
auto RecordTy = new Type(TNil());
|
||||
auto RestField = getRestField(P->Fields);
|
||||
Type* RecordTy;
|
||||
if (RestField == nullptr) {
|
||||
RecordTy = new Type(TNil());
|
||||
} else if (RestField->Pattern) {
|
||||
RecordTy = inferPattern(RestField->Pattern);
|
||||
} else {
|
||||
RecordTy = createTypeVar();
|
||||
}
|
||||
for (auto [Field, Comma]: P->Fields) {
|
||||
if (Field->DotDot) {
|
||||
continue;
|
||||
}
|
||||
Type* FieldTy;
|
||||
if (Field->Pattern) {
|
||||
FieldTy = inferPattern(Field->Pattern, Constraints, TVs);
|
||||
|
|
|
@ -158,6 +158,23 @@ finish:
|
|||
if (T0->getKind() == NodeKind::RBrace) {
|
||||
break;
|
||||
}
|
||||
if (T0->getKind() == NodeKind::DotDot) {
|
||||
Tokens.get();
|
||||
auto DotDot = static_cast<class DotDot*>(T0);
|
||||
auto T1 = Tokens.peek();
|
||||
if (T1->getKind() == NodeKind::RBrace) {
|
||||
Fields.push_back(std::make_tuple(new RecordPatternField(DotDot), nullptr));
|
||||
break;
|
||||
}
|
||||
auto P = parseWidePattern();
|
||||
auto T2 = Tokens.peek();
|
||||
if (T2->getKind() != NodeKind::RBrace) {
|
||||
DE.add<UnexpectedTokenDiagnostic>(File, T2, std::vector { NodeKind::RBrace, NodeKind::Comma });
|
||||
return {};
|
||||
}
|
||||
Fields.push_back(std::make_tuple(new RecordPatternField(DotDot, P), nullptr));
|
||||
break;
|
||||
}
|
||||
auto Name = expectToken<Identifier>();
|
||||
Equals* Equals = nullptr;
|
||||
Pattern* Pattern = nullptr;
|
||||
|
@ -194,6 +211,19 @@ finish:
|
|||
case NodeKind::Identifier:
|
||||
Tokens.get();
|
||||
return new BindPattern(static_cast<Identifier*>(T0));
|
||||
case NodeKind::LBrace:
|
||||
{
|
||||
Tokens.get();
|
||||
auto LBrace = static_cast<class LBrace*>(T0);
|
||||
auto Fields = parseRecordPatternFields();
|
||||
if (!Fields) {
|
||||
LBrace->unref();
|
||||
skipToRBrace();
|
||||
return nullptr;
|
||||
}
|
||||
auto RBrace = static_cast<class RBrace*>(Tokens.get());
|
||||
return new RecordPattern(LBrace, *Fields, RBrace);
|
||||
}
|
||||
case NodeKind::IdentifierAlt:
|
||||
{
|
||||
Tokens.get();
|
||||
|
@ -207,6 +237,7 @@ finish:
|
|||
Tokens.get();
|
||||
auto Fields = parseRecordPatternFields();
|
||||
if (!Fields) {
|
||||
LBrace->unref();
|
||||
skipToRBrace();
|
||||
return nullptr;
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue