diff --git a/bootstrap/cxx/include/bolt/CST.hpp b/bootstrap/cxx/include/bolt/CST.hpp index bac763ec1..0555658e6 100644 --- a/bootstrap/cxx/include/bolt/CST.hpp +++ b/bootstrap/cxx/include/bolt/CST.hpp @@ -145,6 +145,7 @@ namespace bolt { BindPattern, LiteralPattern, RecordPatternField, + RecordPattern, NamedRecordPattern, NamedTuplePattern, TuplePattern, @@ -1383,22 +1384,61 @@ namespace bolt { class RecordPatternField : public Node { public: + DotDot* DotDot; Identifier* Name; Equals* Equals; Pattern* Pattern; inline RecordPatternField( + class DotDot* DotDot, Identifier* Name, class Equals* Equals, class Pattern* Pattern ): Node(NodeKind::RecordPatternField), + DotDot(DotDot), Name(Name), Equals(Equals), Pattern(Pattern) {} + inline RecordPatternField( + Identifier* Name, + class Equals* Equals, + class Pattern* Pattern + ): RecordPatternField(nullptr, Name, Equals, Pattern) {} + + inline RecordPatternField( + class DotDot* DotDot + ): RecordPatternField(DotDot, nullptr, nullptr, nullptr) {} + + inline RecordPatternField( + class DotDot* DotDot, + class Pattern* Pattern + ): RecordPatternField(DotDot, nullptr, nullptr, Pattern) {} + inline RecordPatternField( Identifier* Name - ): RecordPatternField(Name, nullptr, nullptr) {} + ): RecordPatternField(nullptr, Name, nullptr, nullptr) {} + + Token* getFirstToken() const override; + Token* getLastToken() const override; + + }; + + class RecordPattern : public Pattern { + public: + + LBrace* LBrace; + std::vector> Fields; + RBrace* RBrace; + + inline RecordPattern( + class LBrace* LBrace, + std::vector> Fields, + class RBrace* RBrace + ): Pattern(NodeKind::RecordPattern), + LBrace(LBrace), + Fields(Fields), + RBrace(RBrace) {} Token* getFirstToken() const override; Token* getLastToken() const override; diff --git a/bootstrap/cxx/include/bolt/CSTVisitor.hpp b/bootstrap/cxx/include/bolt/CSTVisitor.hpp index 532f2a671..15353c007 100644 --- a/bootstrap/cxx/include/bolt/CSTVisitor.hpp +++ b/bootstrap/cxx/include/bolt/CSTVisitor.hpp @@ -77,6 +77,7 @@ namespace bolt { BOLT_GEN_CASE(BindPattern) BOLT_GEN_CASE(LiteralPattern) BOLT_GEN_CASE(RecordPatternField) + BOLT_GEN_CASE(RecordPattern) BOLT_GEN_CASE(NamedRecordPattern) BOLT_GEN_CASE(NamedTuplePattern) BOLT_GEN_CASE(TuplePattern) @@ -372,6 +373,10 @@ namespace bolt { static_cast(this)->visitNode(N); } + void visitRecordPattern(RecordPattern* N) { + static_cast(this)->visitPattern(N); + } + void visitNamedRecordPattern(NamedRecordPattern* N) { static_cast(this)->visitPattern(N); } @@ -592,6 +597,7 @@ namespace bolt { BOLT_GEN_CHILD_CASE(BindPattern) BOLT_GEN_CHILD_CASE(LiteralPattern) BOLT_GEN_CHILD_CASE(RecordPatternField) + BOLT_GEN_CHILD_CASE(RecordPattern) BOLT_GEN_CHILD_CASE(NamedRecordPattern) BOLT_GEN_CHILD_CASE(NamedTuplePattern) BOLT_GEN_CHILD_CASE(TuplePattern) @@ -867,7 +873,12 @@ namespace bolt { } void visitEachChild(RecordPatternField* N) { - BOLT_VISIT(N->Name); + if (N->DotDot) { + BOLT_VISIT(N->DotDot); + } + if (N->Name) { + BOLT_VISIT(N->Name); + } if (N->Equals) { BOLT_VISIT(N->Equals); } @@ -876,6 +887,17 @@ namespace bolt { } } + void visitEachChild(RecordPattern* N) { + BOLT_VISIT(N->LBrace); + for (auto [Field, Comma]: N->Fields) { + BOLT_VISIT(Field); + if (Comma) { + BOLT_VISIT(Comma); + } + } + BOLT_VISIT(N->RBrace); + } + void visitEachChild(NamedRecordPattern* N) { for (auto [Name, Dot]: N->ModulePath) { BOLT_VISIT(Name); diff --git a/bootstrap/cxx/include/bolt/Checker.hpp b/bootstrap/cxx/include/bolt/Checker.hpp index 4d8f96e1b..985e8e6f5 100644 --- a/bootstrap/cxx/include/bolt/Checker.hpp +++ b/bootstrap/cxx/include/bolt/Checker.hpp @@ -244,7 +244,7 @@ namespace bolt { void forwardDeclareFunctionDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints); Type* inferExpression(Expression* Expression); - Type* inferTypeExpression(TypeExpression* TE, bool IsPoly = true); + Type* inferTypeExpression(TypeExpression* TE, bool AutoVars = true); Type* inferLiteral(Literal* Lit); Type* inferPattern(Pattern* Pattern, ConstraintSet* Constraints = new ConstraintSet, TVSet* TVs = new TVSet); diff --git a/bootstrap/cxx/src/CST.cc b/bootstrap/cxx/src/CST.cc index 6a1d4aafd..43fea60b7 100644 --- a/bootstrap/cxx/src/CST.cc +++ b/bootstrap/cxx/src/CST.cc @@ -174,13 +174,25 @@ namespace bolt { addSymbol(Y->Name->Text, Decl, SymbolKind::Var); break; } + case NodeKind::RecordPattern: + { + auto Y = static_cast(X); + for (auto [Field, Comma]: Y->Fields) { + if (Field->Pattern) { + visitPattern(Field->Pattern, Decl); + } else if (Field->Name) { + addSymbol(Field->Name->Text, Decl, SymbolKind::Var); + } + } + break; + } case NodeKind::NamedRecordPattern: { auto Y = static_cast(X); for (auto [Field, Comma]: Y->Fields) { if (Field->Pattern) { visitPattern(Field->Pattern, Decl); - } else { + } else if (Field->Name) { addSymbol(Field->Name->Text, Decl, SymbolKind::Var); } } @@ -507,6 +519,14 @@ namespace bolt { return Name; } + Token* RecordPattern::getFirstToken() const { + return LBrace; + } + + Token* RecordPattern::getLastToken() const { + return RBrace; + } + Token* NamedRecordPattern::getFirstToken() const { if (!ModulePath.empty()) { return std::get<0>(ModulePath.back()); diff --git a/bootstrap/cxx/src/Checker.cc b/bootstrap/cxx/src/Checker.cc index 59c213052..511c4e67f 100644 --- a/bootstrap/cxx/src/Checker.cc +++ b/bootstrap/cxx/src/Checker.cc @@ -266,14 +266,14 @@ namespace bolt { case NodeKind::LetDeclaration: { - // Function declarations are handled separately in forwardDeclareLetDeclaration() + // Function declarations are handled separately in forwardDeclareLetDeclaration() and inferExpression() auto Decl = static_cast(X); if (!Decl->isVariable()) { break; } Type* Ty; if (Decl->TypeAssert) { - Ty = inferTypeExpression(Decl->TypeAssert->TypeExpression, false); + Ty = inferTypeExpression(Decl->TypeAssert->TypeExpression); } else { Ty = createTypeVar(); } @@ -291,6 +291,7 @@ namespace bolt { for (auto TE: Decl->TVs) { auto TV = createRigidVar(TE->Name->getCanonicalText()); Decl->Ctx->TVs->emplace(TV); + Decl->Ctx->Env.add(TE->Name->getCanonicalText(), new Forall(TV), SymKind::Type); Vars.push_back(TV); } @@ -313,7 +314,7 @@ namespace bolt { std::vector ParamTypes; for (auto Element: TupleMember->Elements) { // inferTypeExpression will look up any TVars that were part of the signature of Decl - ParamTypes.push_back(inferTypeExpression(Element)); + ParamTypes.push_back(inferTypeExpression(Element, false)); } Decl->Ctx->Parent->Env.add( TupleMember->Name->getCanonicalText(), @@ -350,6 +351,8 @@ namespace bolt { std::vector Vars; for (auto TE: Decl->Vars) { auto TV = createRigidVar(TE->Name->getCanonicalText()); + Decl->Ctx->TVs->emplace(TV); + Decl->Ctx->Env.add(TE->Name->getCanonicalText(), new Forall(TV), SymKind::Type); Vars.push_back(TV); } @@ -370,7 +373,7 @@ namespace bolt { FieldsTy = new Type( TField( Field->Name->getCanonicalText(), - new Type(TPresent(inferTypeExpression(Field->TypeExpression))), + new Type(TPresent(inferTypeExpression(Field->TypeExpression, false))), FieldsTy ) ); @@ -811,7 +814,7 @@ namespace bolt { } } - Type* Checker::inferTypeExpression(TypeExpression* N, bool IsPoly) { + Type* Checker::inferTypeExpression(TypeExpression* N, bool AutoVars) { switch (N->getKind()) { @@ -833,9 +836,9 @@ namespace bolt { case NodeKind::AppTypeExpression: { auto AppTE = static_cast(N); - Type* Ty = inferTypeExpression(AppTE->Op, IsPoly); + Type* Ty = inferTypeExpression(AppTE->Op, AutoVars); for (auto Arg: AppTE->Args) { - Ty = new Type(TApp(Ty, inferTypeExpression(Arg, IsPoly))); + Ty = new Type(TApp(Ty, inferTypeExpression(Arg, AutoVars))); } N->setType(Ty); return Ty; @@ -846,10 +849,10 @@ namespace bolt { auto VarTE = static_cast(N); auto Ty = lookupMono(VarTE->Name->getCanonicalText(), SymKind::Type); if (Ty == nullptr) { - if (IsPoly && Config.typeVarsRequireForall()) { + if (!AutoVars || Config.typeVarsRequireForall()) { DE.add(VarTE->Name->getCanonicalText(), VarTE->Name); } - Ty = IsPoly ? createRigidVar(VarTE->Name->getCanonicalText()) : createTypeVar(); + Ty = createRigidVar(VarTE->Name->getCanonicalText()); addBinding(VarTE->Name->getCanonicalText(), new Forall(Ty), SymKind::Type); } ZEN_ASSERT(Ty->isVar()); @@ -860,9 +863,9 @@ namespace bolt { case NodeKind::RecordTypeExpression: { auto RecTE = static_cast(N); - auto Ty = RecTE->Rest ? inferTypeExpression(RecTE->Rest, IsPoly) : new Type(TNil()); + auto Ty = RecTE->Rest ? inferTypeExpression(RecTE->Rest, AutoVars) : new Type(TNil()); for (auto [Field, Comma]: RecTE->Fields) { - Ty = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(inferTypeExpression(Field->TE, IsPoly))), Ty)); + Ty = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(inferTypeExpression(Field->TE, AutoVars))), Ty)); } N->setType(Ty); return Ty; @@ -873,7 +876,7 @@ namespace bolt { auto TupleTE = static_cast(N); std::vector ElementTypes; for (auto [TE, Comma]: TupleTE->Elements) { - ElementTypes.push_back(inferTypeExpression(TE, IsPoly)); + ElementTypes.push_back(inferTypeExpression(TE, AutoVars)); } auto Ty = new Type(TTuple(ElementTypes)); N->setType(Ty); @@ -883,7 +886,7 @@ namespace bolt { case NodeKind::NestedTypeExpression: { auto NestedTE = static_cast(N); - auto Ty = inferTypeExpression(NestedTE->TE, IsPoly); + auto Ty = inferTypeExpression(NestedTE->TE, AutoVars); N->setType(Ty); return Ty; } @@ -893,9 +896,9 @@ namespace bolt { auto ArrowTE = static_cast(N); std::vector ParamTypes; for (auto ParamType: ArrowTE->ParamTypes) { - ParamTypes.push_back(inferTypeExpression(ParamType, IsPoly)); + ParamTypes.push_back(inferTypeExpression(ParamType, AutoVars)); } - auto ReturnType = inferTypeExpression(ArrowTE->ReturnType, IsPoly); + auto ReturnType = inferTypeExpression(ArrowTE->ReturnType, AutoVars); auto Ty = Type::buildArrow(ParamTypes, ReturnType); N->setType(Ty); return Ty; @@ -907,7 +910,7 @@ namespace bolt { for (auto [C, Comma]: QTE->Constraints) { inferConstraintExpression(C); } - auto Ty = inferTypeExpression(QTE->TE, IsPoly); + auto Ty = inferTypeExpression(QTE->TE, AutoVars); N->setType(Ty); return Ty; } @@ -1111,6 +1114,15 @@ namespace bolt { return Ty; } + RecordPatternField* getRestField(std::vector> Fields) { + for (auto [Field, Comma]: Fields) { + if (Field->DotDot) { + return Field; + } + } + return nullptr; + } + Type* Checker::inferPattern( Pattern* Pattern, ConstraintSet* Constraints, @@ -1145,6 +1157,34 @@ namespace bolt { return RetTy; } + case NodeKind::RecordPattern: + { + auto P = static_cast(Pattern); + auto RestField = getRestField(P->Fields); + Type* RecordTy; + if (RestField == nullptr) { + RecordTy = new Type(TNil()); + } else if (RestField->Pattern) { + RecordTy = inferPattern(RestField->Pattern); + } else { + RecordTy = createTypeVar(); + } + for (auto [Field, Comma]: P->Fields) { + if (Field->DotDot) { + continue; + } + Type* FieldTy; + if (Field->Pattern) { + FieldTy = inferPattern(Field->Pattern, Constraints, TVs); + } else { + FieldTy = createTypeVar(); + addBinding(Field->Name->getCanonicalText(), new Forall(TVs, Constraints, FieldTy), SymKind::Var); + } + RecordTy = new Type(TField(Field->Name->getCanonicalText(), new Type(TPresent(FieldTy)), RecordTy)); + } + return RecordTy; + } + case NodeKind::NamedRecordPattern: { auto P = static_cast(Pattern); @@ -1153,8 +1193,19 @@ namespace bolt { DE.add(P->Name->getCanonicalText(), P->Name); return createTypeVar(); } - auto RecordTy = new Type(TNil()); + auto RestField = getRestField(P->Fields); + Type* RecordTy; + if (RestField == nullptr) { + RecordTy = new Type(TNil()); + } else if (RestField->Pattern) { + RecordTy = inferPattern(RestField->Pattern); + } else { + RecordTy = createTypeVar(); + } for (auto [Field, Comma]: P->Fields) { + if (Field->DotDot) { + continue; + } Type* FieldTy; if (Field->Pattern) { FieldTy = inferPattern(Field->Pattern, Constraints, TVs); diff --git a/bootstrap/cxx/src/Parser.cc b/bootstrap/cxx/src/Parser.cc index b89da9822..0083b92c6 100644 --- a/bootstrap/cxx/src/Parser.cc +++ b/bootstrap/cxx/src/Parser.cc @@ -158,6 +158,23 @@ finish: if (T0->getKind() == NodeKind::RBrace) { break; } + if (T0->getKind() == NodeKind::DotDot) { + Tokens.get(); + auto DotDot = static_cast(T0); + auto T1 = Tokens.peek(); + if (T1->getKind() == NodeKind::RBrace) { + Fields.push_back(std::make_tuple(new RecordPatternField(DotDot), nullptr)); + break; + } + auto P = parseWidePattern(); + auto T2 = Tokens.peek(); + if (T2->getKind() != NodeKind::RBrace) { + DE.add(File, T2, std::vector { NodeKind::RBrace, NodeKind::Comma }); + return {}; + } + Fields.push_back(std::make_tuple(new RecordPatternField(DotDot, P), nullptr)); + break; + } auto Name = expectToken(); Equals* Equals = nullptr; Pattern* Pattern = nullptr; @@ -194,6 +211,19 @@ finish: case NodeKind::Identifier: Tokens.get(); return new BindPattern(static_cast(T0)); + case NodeKind::LBrace: + { + Tokens.get(); + auto LBrace = static_cast(T0); + auto Fields = parseRecordPatternFields(); + if (!Fields) { + LBrace->unref(); + skipToRBrace(); + return nullptr; + } + auto RBrace = static_cast(Tokens.get()); + return new RecordPattern(LBrace, *Fields, RBrace); + } case NodeKind::IdentifierAlt: { Tokens.get(); @@ -207,6 +237,7 @@ finish: Tokens.get(); auto Fields = parseRecordPatternFields(); if (!Fields) { + LBrace->unref(); skipToRBrace(); return nullptr; }