diff --git a/include/bolt/CST.hpp b/include/bolt/CST.hpp index 1cae3f3c0..0be4f143d 100644 --- a/include/bolt/CST.hpp +++ b/include/bolt/CST.hpp @@ -76,6 +76,7 @@ namespace bolt { MatchCase, MatchExpression, MemberExpression, + TupleExpression, NestedExpression, ConstantExpression, CallExpression, @@ -830,6 +831,10 @@ namespace bolt { std::string getText() const override; + inline Integer getInteger() const noexcept { + return V; + } + Value getValue() override; static bool classof(const Node* N) { @@ -1143,6 +1148,27 @@ namespace bolt { }; + class TupleExpression : public Expression { + public: + + class LParen* LParen; + std::vector> Elements; + class RParen* RParen; + + inline TupleExpression( + class LParen* LParen, + std::vector> Elements, + class RParen* RParen + ): Expression(NodeKind::TupleExpression), + LParen(LParen), + Elements(Elements), + RParen(RParen) {} + + Token* getFirstToken() override; + Token* getLastToken() override; + + }; + class NestedExpression : public Expression { public: diff --git a/include/bolt/CSTVisitor.hpp b/include/bolt/CSTVisitor.hpp index 4d29db89c..8a9bba81d 100644 --- a/include/bolt/CSTVisitor.hpp +++ b/include/bolt/CSTVisitor.hpp @@ -111,6 +111,8 @@ namespace bolt { return static_cast(this)->visitMatchExpression(static_cast(N)); case NodeKind::MemberExpression: return static_cast(this)->visitMemberExpression(static_cast(N)); + case NodeKind::TupleExpression: + return static_cast(this)->visitTupleExpression(static_cast(N)); case NodeKind::NestedExpression: return static_cast(this)->visitNestedExpression(static_cast(N)); case NodeKind::ConstantExpression: @@ -378,6 +380,10 @@ namespace bolt { visitExpression(N); } + void visitTupleExpression(TupleExpression* N) { + visitExpression(N); + } + void visitNestedExpression(NestedExpression* N) { visitExpression(N); } @@ -616,6 +622,9 @@ namespace bolt { case NodeKind::MemberExpression: visitEachChild(static_cast(N)); break; + case NodeKind::TupleExpression: + visitEachChild(static_cast(N)); + break; case NodeKind::NestedExpression: visitEachChild(static_cast(N)); break; @@ -876,6 +885,17 @@ namespace bolt { BOLT_VISIT(N->Name); } + void visitEachChild(TupleExpression* N) { + BOLT_VISIT(N->LParen); + for (auto [E, Comma]: N->Elements) { + BOLT_VISIT(E); + if (Comma) { + BOLT_VISIT(Comma); + } + } + BOLT_VISIT(N->RParen); + } + void visitEachChild(NestedExpression* N) { BOLT_VISIT(N->LParen); BOLT_VISIT(N->Inner); diff --git a/include/bolt/Checker.hpp b/include/bolt/Checker.hpp index ecae3c3fc..a8d26582b 100644 --- a/include/bolt/Checker.hpp +++ b/include/bolt/Checker.hpp @@ -32,6 +32,7 @@ namespace bolt { Con, Arrow, Tuple, + TupleIndex, }; class Type { @@ -148,6 +149,21 @@ namespace bolt { }; + class TTupleIndex : public Type { + public: + + Type* Ty; + std::size_t I; + + inline TTupleIndex(Type* Ty, std::size_t I): + Type(TypeKind::TupleIndex), Ty(Ty), I(I) {} + + static bool classof(const Type* Ty) { + return Ty->getKind() == TypeKind::TupleIndex; + } + + }; + // template // struct DerefHash { // std::size_t operator()(const T& Value) const noexcept { @@ -403,6 +419,8 @@ namespace bolt { void checkTypeclassSigs(Node* N); + Type* simplify(Type* Ty); + void join(TVar* A, Type* B, Node* Source); bool unify(Type* A, Type* B, Node* Source); void solveCEqual(CEqual* C); diff --git a/include/bolt/Diagnostics.hpp b/include/bolt/Diagnostics.hpp index e11a1c773..dbd7f8f16 100644 --- a/include/bolt/Diagnostics.hpp +++ b/include/bolt/Diagnostics.hpp @@ -15,6 +15,7 @@ namespace bolt { class Type; class TCon; class TVar; + class TTuple; using TypeclassId = ByteString; @@ -37,6 +38,8 @@ namespace bolt { TypeclassMissing, InstanceNotFound, ClassNotFound, + TupleIndexOutOfRange, + InvalidTypeToTypeclass, }; class Diagnostic : std::runtime_error { @@ -135,6 +138,27 @@ namespace bolt { }; + class TupleIndexOutOfRangeDiagnostic : public Diagnostic { + public: + + TTuple* Tuple; + std::size_t I; + + inline TupleIndexOutOfRangeDiagnostic(TTuple* Tuple, std::size_t I): + Diagnostic(DiagnosticKind::TupleIndexOutOfRange), Tuple(Tuple), I(I) {} + + }; + + class InvalidTypeToTypeclassDiagnostic : public Diagnostic { + public: + + Type* Actual; + + inline InvalidTypeToTypeclassDiagnostic(Type* Actual): + Diagnostic(DiagnosticKind::InvalidTypeToTypeclass) {} + + }; + class DiagnosticEngine { protected: diff --git a/src/CST.cc b/src/CST.cc index ecd346060..ff4a3a574 100644 --- a/src/CST.cc +++ b/src/CST.cc @@ -283,6 +283,14 @@ namespace bolt { return Name; } + Token* TupleExpression::getFirstToken() { + return LParen; + } + + Token* TupleExpression::getLastToken() { + return RParen; + } + Token* NestedExpression::getFirstToken() { return LParen; } diff --git a/src/Checker.cc b/src/Checker.cc index af1f431dc..86e59f2c9 100644 --- a/src/Checker.cc +++ b/src/Checker.cc @@ -3,6 +3,10 @@ // TODO (maybe) make unficiation work like union-find in find() +// TODO make simplify() rewrite the types in-place such that a reference too (Bool, Int).0 becomes Bool + +// TODO Fix TVSub to use TVar.Id instead of the pointer address + #include #include #include @@ -58,6 +62,12 @@ namespace bolt { } break; } + case TypeKind::TupleIndex: + { + auto Index = static_cast(this); + Index->Ty->addTypeVars(TVs); + break; + } case TypeKind::Tuple: { auto Tuple = static_cast(this); @@ -93,6 +103,11 @@ namespace bolt { } return false; } + case TypeKind::TupleIndex: + { + auto Index = static_cast(this); + return Index->Ty->hasTypeVar(TV); + } case TypeKind::Tuple: { auto Tuple = static_cast(this); @@ -146,6 +161,12 @@ namespace bolt { } return Changed ? new TCon(Con->Id, NewArgs, Con->DisplayName) : this; } + case TypeKind::TupleIndex: + { + auto Tuple = static_cast(this); + auto NewTy = Tuple->Ty->substitute(Sub); + return NewTy != Tuple->Ty ? new TTupleIndex(NewTy, Tuple->I) : Tuple; + } case TypeKind::Tuple: { auto Tuple = static_cast(this); @@ -821,6 +842,35 @@ namespace bolt { return RetTy; } + case NodeKind::TupleExpression: + { + auto Tuple = static_cast(X); + std::vector Types; + for (auto [E, Comma]: Tuple->Elements) { + Types.push_back(inferExpression(E)); + } + return new TTuple(Types); + } + + case NodeKind::MemberExpression: + { + auto Member = static_cast(X); + switch (Member->Name->getKind()) { + case NodeKind::IntegerLiteral: + { + auto I = static_cast(Member->Name); + return new TTupleIndex(inferExpression(Member->E), I->getInteger()); + } + case NodeKind::Identifier: + { + // TODO + break; + } + default: + ZEN_UNREACHABLE + } + } + case NodeKind::NestedExpression: { auto Nested = static_cast(X); @@ -1124,9 +1174,8 @@ namespace bolt { for (auto Class: Classes) { propagateClassTycon(Class, llvm::cast(Ty), Source); } - } else { - ZEN_UNREACHABLE - // DE.add(Ty); + } else if (!Classes.empty()) { + DE.add(Ty); } }; @@ -1174,28 +1223,139 @@ namespace bolt { }; void Checker::solveCEqual(CEqual* C) { - /* std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl; */ + std::cerr << describe(C->Left) << " ~ " << describe(C->Right) << std::endl; if (!unify(C->Left, C->Right, C->Source)) { - DE.add(C->Left->substitute(Solution), C->Right->substitute(Solution), C->Source); + DE.add(simplify(C->Left), simplify(C->Right), C->Source); } } + Type* Checker::simplify(Type* Ty) { + + while (Ty->getKind() == TypeKind::Var) { + auto Match = Solution.find(static_cast(Ty)); + if (Match == Solution.end()) { + break; + } + Ty = Match->second; + } + + switch (Ty->getKind()) { + + case TypeKind::Var: + break; + + case TypeKind::Tuple: + { + auto Tuple = static_cast(Ty); + bool Changed = false; + std::vector NewElementTypes; + for (auto Ty: Tuple->ElementTypes) { + auto NewElementType = simplify(Ty); + if (NewElementType != Ty) { + Changed = true; + } + NewElementTypes.push_back(NewElementType); + } + return Changed ? new TTuple(NewElementTypes) : Ty; + } + + case TypeKind::Arrow: + { + auto Arrow = static_cast(Ty); + bool Changed = false; + std::vector NewParamTys; + for (auto ParamTy: Arrow->ParamTypes) { + auto NewParamTy = simplify(ParamTy); + if (NewParamTy != ParamTy) { + Changed = true; + } + NewParamTys.push_back(NewParamTy); + } + auto NewRetTy = simplify(Arrow->ReturnType); + if (NewRetTy != Arrow->ReturnType) { + Changed = true; + } + Ty = Changed ? new TArrow(NewParamTys, NewRetTy) : Arrow; + break; + } + + case TypeKind::Con: + { + auto Con = static_cast(Ty); + bool Changed = false; + std::vector NewArgs; + for (auto Arg: Con->Args) { + auto NewArg = simplify(Arg); + if (NewArg != Arg) { + Changed = true; + } + NewArgs.push_back(NewArg); + } + return Changed ? new TCon(Con->Id, NewArgs, Con->DisplayName) : Ty; + } + + case TypeKind::TupleIndex: + { + auto Index = static_cast(Ty); + auto MaybeTuple = simplify(Index->Ty); + if (llvm::isa(MaybeTuple)) { + auto Tuple = static_cast(MaybeTuple); + if (Index->I >= Tuple->ElementTypes.size()) { + DE.add(Tuple, Index->I); + } else { + Ty = simplify(Tuple->ElementTypes[Index->I]); + } + } + break; + } + + } + + return Ty; + } + + void Checker::join(TVar* TV, Type* Ty, Node* Source) { + + Solution[TV] = Ty; + + propagateClasses(TV->Contexts, Ty, Source); + + // This is a very specific adjustment that is critical to the + // well-functioning of the infer/unify algorithm. When addConstraint() is + // called, it may decide to solve the constraint immediately during + // inference. If this happens, a type variable might get assigned a concrete + // type such as Int. We therefore never want the variable to be polymorphic + // and be instantiated with a fresh variable, as it has already been solved. + // Should it get assigned another unification variable, that's OK too + // because then the context of that variable is what matters and not anymore + // the context of this one. + if (!Contexts.empty()) { + Contexts.back()->TVs->erase(TV); + } + + } + bool Checker::unify(Type* A, Type* B, Node* Source) { auto find = [&](auto OrigTy) { auto Ty = OrigTy; - while (Ty->getKind() == TypeKind::Var) { - auto Match = Solution.find(static_cast(Ty)); - if (Match == Solution.end()) { - break; - } - Ty = Match->second; + if (llvm::isa(Ty)) { + auto TV = static_cast(Ty); + do { + auto Match = Solution.find(static_cast(Ty)); + if (Match == Solution.end()) { + break; + } + Ty = Match->second; + } while (Ty->getKind() == TypeKind::Var); + // FIXME does this actually improove performance? + Solution[TV] = Ty; } return Ty; }; - A = find(A); - B = find(B); + A = simplify(A); + B = simplify(B); if (llvm::isa(A) && llvm::isa(B)) { auto Var1 = static_cast(A); @@ -1206,19 +1366,19 @@ namespace bolt { } return true; } - TVar* Dest; + TVar* To; TVar* From; if (Var1->getVarKind() == VarKind::Rigid && Var2->getVarKind() == VarKind::Unification) { - Dest = Var1; + To = Var1; From = Var2; } else { // Only cases left are Var1 = Unification, Var2 = Rigid and Var1 = Unification, Var2 = Unification // Either way, Var1 is a good candidate for being unified away - Dest = Var2; + To = Var2; From = Var1; } - Solution[From] = Dest; - propagateClasses(From->Contexts, Dest, Source); + join(From, To, Source); + propagateClasses(From->Contexts, To, Source); return true; } @@ -1234,10 +1394,7 @@ namespace bolt { // than obsure references to an occurs check return false; } - Solution[TV] = B; - if (!TV->Contexts.empty()) { - propagateClasses(TV->Contexts, B, Source); - } + join(TV, B, Source); return true; } @@ -1301,6 +1458,12 @@ namespace bolt { return Success; } + // if (llvm::isa(A) && llvm::isa(B)) { + // auto Index1 = static_cast(A); + // auto Index2 = static_cast(B); + // return unify(Index1->Ty, Index2->Ty, Source); + // } + if (llvm::isa(A) && llvm::isa(B)) { auto Con1 = static_cast(A); auto Con2 = static_cast(B); diff --git a/src/Diagnostics.cc b/src/Diagnostics.cc index 2148b8126..8844df560 100644 --- a/src/Diagnostics.cc +++ b/src/Diagnostics.cc @@ -162,6 +162,11 @@ namespace bolt { Out << ")"; return Out.str(); } + case TypeKind::TupleIndex: + { + auto Y = static_cast(Ty); + return describe(Y->Ty) + "." + std::to_string(Y->I); + } } } diff --git a/src/Parser.cc b/src/Parser.cc index 507321ff5..5bb71521c 100644 --- a/src/Parser.cc +++ b/src/Parser.cc @@ -237,9 +237,35 @@ after_constraints: case NodeKind::LParen: { Tokens.get(); - auto E = parseExpression(); - auto T2 = static_cast(expectToken(NodeKind::RParen)); - return new NestedExpression(static_cast(T0), E, T2); + std::vector> Elements; + auto LParen = static_cast(T0); + RParen* RParen; + for (;;) { + auto T1 = Tokens.peek(); + if (llvm::isa(T1)) { + Tokens.get(); + RParen = static_cast(T1); + break; + } + auto E = parseExpression(); + auto T2 = Tokens.get(); + switch (T2->getKind()) { + case NodeKind::RParen: + RParen = static_cast(T2); + Elements.push_back({ E, nullptr }); + goto finish; + case NodeKind::Comma: + Elements.push_back({ E, static_cast(T2) }); + break; + default: + throw UnexpectedTokenDiagnostic(File, T2, { NodeKind::RParen, NodeKind::Comma }); + } + } +finish: + if (Elements.size() == 1 && !std::get<1>(Elements.front())) { + return new NestedExpression(LParen, std::get<0>(Elements.front()), RParen); + } + return new TupleExpression { LParen, Elements, RParen }; } case NodeKind::MatchKeyword: { @@ -307,7 +333,7 @@ finish: std::vector Args; for (;;) { auto T1 = Tokens.peek(); - if (T1->getKind() == NodeKind::LineFoldEnd || T1->getKind() == NodeKind::RParen || T1->getKind() == NodeKind::BlockStart || ExprOperators.isInfix(T1)) { + if (T1->getKind() == NodeKind::LineFoldEnd || T1->getKind() == NodeKind::RParen || T1->getKind() == NodeKind::BlockStart || T1->getKind() == NodeKind::Comma || ExprOperators.isInfix(T1)) { break; } Args.push_back(parsePrimitiveExpression());