Add basic support for if-statements

This commit is contained in:
Sam Vervaeck 2022-08-25 23:04:09 +02:00
parent b4d54f025c
commit 43301a3a44
8 changed files with 365 additions and 23 deletions

View file

@ -29,6 +29,9 @@ namespace bolt {
ReturnKeyword,
ModKeyword,
StructKeyword,
ElifKeyword,
IfKeyword,
ElseKeyword,
Invalid,
EndOfFile,
BlockStart,
@ -50,6 +53,8 @@ namespace bolt {
UnaryExpression,
ExpressionStatement,
ReturnStatement,
IfStatement,
IfStatementPart,
TypeAssert,
Param,
LetBlockBody,
@ -338,6 +343,42 @@ namespace bolt {
};
class ElseKeyword : public Token {
public:
ElseKeyword(TextLoc StartLoc):
Token(NodeType::ElseKeyword, StartLoc) {}
std::string getText() const override;
~ElseKeyword();
};
class ElifKeyword : public Token {
public:
ElifKeyword(TextLoc StartLoc):
Token(NodeType::ElifKeyword, StartLoc) {}
std::string getText() const override;
~ElifKeyword();
};
class IfKeyword : public Token {
public:
IfKeyword(TextLoc StartLoc):
Token(NodeType::IfKeyword, StartLoc) {}
std::string getText() const override;
~IfKeyword();
};
class ModKeyword : public Token {
public:
@ -731,6 +772,51 @@ namespace bolt {
};
class IfStatementPart : public Node {
public:
Token* Keyword;
Expression* Test;
BlockStart* BlockStart;
std::vector<Node*> Elements;
inline IfStatementPart(
Token* Keyword,
Expression* Test,
class BlockStart* BlockStart,
std::vector<Node*> Elements
): Node(NodeType::IfStatementPart),
Keyword(Keyword),
Test(Test),
BlockStart(BlockStart),
Elements(Elements) {}
void setParents() override;
Token* getFirstToken() override;
Token* getLastToken() override;
~IfStatementPart();
};
class IfStatement : public Statement {
public:
std::vector<IfStatementPart*> Parts;
inline IfStatement(std::vector<IfStatementPart*> Parts):
Statement(NodeType::IfStatement), Parts(Parts) {}
void setParents() override;
Token* getFirstToken() override;
Token* getLastToken() override;
~IfStatement();
};
class ReturnStatement : public Statement {
public:

View file

@ -1,11 +1,11 @@
#pragma once
#include "bolt/Diagnostics.hpp"
#include "zen/config.hpp"
#include "bolt/ByteString.hpp"
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <vector>
@ -13,8 +13,11 @@
namespace bolt {
class DiagnosticEngine;
class Node;
class Expression;
class TypeExpression;
class Pattern;
class SourceFile;
class Type;
@ -28,6 +31,7 @@ namespace bolt {
Con,
Arrow,
Any,
Tuple,
};
class Type {
@ -88,6 +92,16 @@ namespace bolt {
};
class TTuple : public Type {
public:
std::vector<Type*> ElementTypes;
inline TTuple(std::vector<Type*> ElementTypes):
Type(TypeKind::Tuple), ElementTypes(ElementTypes) {}
};
class TAny : public Type {
public:
@ -115,7 +129,7 @@ namespace bolt {
Type* Type;
inline Forall(class Type* Type):
TVs(nullptr), Constraints(nullptr), Type(Type) {}
TVs(new TVSet), Constraints(new ConstraintSet), Type(Type) {}
inline Forall(
TVSet& TVs,
@ -228,10 +242,10 @@ namespace bolt {
class CMany : public Constraint {
public:
ConstraintSet& Constraints;
ConstraintSet& Elements;
inline CMany(ConstraintSet& Constraints):
Constraint(ConstraintKind::Many), Constraints(Constraints) {}
Constraint(ConstraintKind::Many), Elements(Constraints) {}
};
@ -249,14 +263,15 @@ namespace bolt {
TVSet TVs;
ConstraintSet Constraints;
TypeEnv Env;
Type* ReturnType;
InferContext* Parent;
inline InferContext(InferContext* Parent, TVSet& TVs, ConstraintSet& Constraints, TypeEnv& Env):
Parent(Parent), TVs(TVs), Constraints(Constraints), Env(Env) {}
inline InferContext(InferContext* Parent, TVSet& TVs, ConstraintSet& Constraints, TypeEnv& Env, Type* ReturnType):
Parent(Parent), TVs(TVs), Constraints(Constraints), Env(Env), ReturnType(ReturnType) {}
inline InferContext(InferContext* Parent = nullptr):
Parent(Parent) {}
Parent(Parent), ReturnType(nullptr) {}
void addConstraint(Constraint* C);
@ -275,6 +290,14 @@ namespace bolt {
size_t nextConTypeId = 0;
size_t nextTypeVarId = 0;
Type* BoolType;
Type* IntType;
Type* StringType;
std::stack<InferContext> Contexts;
void addConstraint(Constraint* Constraint);
Type* inferExpression(Expression* Expression, InferContext& Ctx);
Type* inferTypeExpression(TypeExpression* TE, InferContext& Ctx);

View file

@ -68,6 +68,8 @@ namespace bolt {
Token* peekFirstTokenAfterModifiers();
Token* expectToken(NodeType Ty);
Expression* parseInfixOperatorAfterExpression(Expression* LHS, int MinPrecedence);
TypeExpression* parsePrimitiveTypeExpression();
@ -94,6 +96,10 @@ namespace bolt {
Expression* parseCallExpression();
IfStatement* parseIfStatement();
ReturnStatement* parseReturnStatement();
ExpressionStatement* parseExpressionStatement();
Node* parseLetBodyElement();

View file

@ -99,6 +99,26 @@ namespace bolt {
Expression->setParents();
}
void IfStatementPart::setParents() {
Keyword->Parent = this;
if (Test) {
Test->Parent = this;
Test->setParents();
}
BlockStart->Parent = this;
for (auto Element: Elements) {
Element->Parent = this;
Element->setParents();
}
}
void IfStatement::setParents() {
for (auto Part: Parts) {
Part->Parent = this;
Part->setParents();
}
}
void TypeAssert::setParents() {
Colon->Parent = this;
TypeExpression->Parent = this;
@ -230,6 +250,15 @@ namespace bolt {
ReturnKeyword::~ReturnKeyword() {
}
IfKeyword::~IfKeyword() {
}
ElifKeyword::~ElifKeyword() {
}
ElseKeyword::~ElseKeyword() {
}
ModKeyword::~ModKeyword() {
}
@ -335,6 +364,23 @@ namespace bolt {
Expression->unref();
}
IfStatementPart::~IfStatementPart() {
Keyword->unref();
if (Test) {
Test->unref();
}
BlockStart->unref();
for (auto Element: Elements) {
Element->unref();
}
}
IfStatement::~IfStatement() {
for (auto Part: Parts) {
Part->unref();
}
}
TypeAssert::~TypeAssert() {
Colon->unref();
TypeExpression->unref();
@ -501,6 +547,27 @@ namespace bolt {
return ReturnKeyword;
}
Token* IfStatementPart::getFirstToken() {
return Keyword;
}
Token* IfStatementPart::getLastToken() {
if (Elements.size()) {
return Elements.back()->getLastToken();
}
return BlockStart;
}
Token* IfStatement::getFirstToken() {
ZEN_ASSERT(Parts.size());
return Parts.front()->getFirstToken();
}
Token* IfStatement::getLastToken() {
ZEN_ASSERT(Parts.size());
return Parts.back()->getLastToken();
}
Token* TypeAssert::getFirstToken() {
return Colon;
}
@ -655,6 +722,18 @@ namespace bolt {
return "return";
}
std::string IfKeyword::getText() const {
return "if";
}
std::string ElseKeyword::getText() const {
return "else";
}
std::string ElifKeyword::getText() const {
return "elif";
}
std::string ModKeyword::getText() const {
return "mod";
}

View file

@ -35,6 +35,16 @@ namespace bolt {
}
return false;
}
case TypeKind::Tuple:
{
auto Y = static_cast<TTuple*>(this);
for (auto Ty: Y->ElementTypes) {
if (Ty->hasTypeVar(TV)) {
return true;
}
}
return false;
}
case TypeKind::Any:
return false;
}
@ -69,6 +79,15 @@ namespace bolt {
}
return new TCon(Y->Id, NewArgs, Y->DisplayName);
}
case TypeKind::Tuple:
{
auto Y = static_cast<TTuple*>(this);
std::vector<Type*> NewElementTypes;
for (auto Ty: Y->ElementTypes) {
NewElementTypes.push_back(Ty->substitute(Sub));
}
return new TTuple(NewElementTypes);
}
}
}
@ -126,7 +145,11 @@ namespace bolt {
}
Checker::Checker(DiagnosticEngine& DE):
DE(DE) {}
DE(DE) {
BoolType = new TCon(nextConTypeId++, {}, "Bool");
IntType = new TCon(nextConTypeId++, {}, "Int");
StringType = new TCon(nextConTypeId++, {}, "String");
}
void Checker::infer(Node* X, InferContext& Ctx) {
@ -141,6 +164,20 @@ namespace bolt {
break;
}
case NodeType::IfStatement:
{
auto Y = static_cast<IfStatement*>(X);
for (auto Part: Y->Parts) {
if (Part->Test != nullptr) {
Ctx.addConstraint(new CEqual { BoolType, inferExpression(Part->Test, Ctx), Part->Test });
}
for (auto Element: Part->Elements) {
infer(Element, Ctx);
}
}
break;
}
case NodeType::LetDeclaration:
{
auto Y = static_cast<LetDeclaration*>(X);
@ -178,6 +215,7 @@ namespace bolt {
{
auto Z = static_cast<LetBlockBody*>(Y->Body);
RetType = createTypeVar(*NewCtx);
NewCtx->ReturnType = RetType;
for (auto Element: Z->Elements) {
infer(Element, *NewCtx);
}
@ -197,6 +235,19 @@ namespace bolt {
break;
}
case NodeType::ReturnStatement:
{
auto Y = static_cast<ReturnStatement*>(X);
Type* ReturnType;
if (Y->Expression) {
ReturnType = inferExpression(Y->Expression, Ctx);
} else {
ReturnType = new TTuple({});
}
ZEN_ASSERT(Ctx.ReturnType != nullptr);
Ctx.addConstraint(new CEqual { ReturnType, Ctx.ReturnType, X });
break;
}
case NodeType::ExpressionStatement:
{
@ -375,16 +426,15 @@ namespace bolt {
void Checker::check(SourceFile *SF) {
InferContext Toplevel;
auto StringTy = new TCon(nextConTypeId++, {}, "String");
auto IntTy = new TCon(nextConTypeId++, {}, "Int");
auto BoolTy = new TCon(nextConTypeId++, {}, "Bool");
Toplevel.addBinding("String", Forall(StringTy));
Toplevel.addBinding("Int", Forall(IntTy));
Toplevel.addBinding("Bool", Forall(BoolTy));
Toplevel.addBinding("+", Forall(new TArrow({ IntTy, IntTy }, IntTy)));
Toplevel.addBinding("-", Forall(new TArrow({ IntTy, IntTy }, IntTy)));
Toplevel.addBinding("*", Forall(new TArrow({ IntTy, IntTy }, IntTy)));
Toplevel.addBinding("/", Forall(new TArrow({ IntTy, IntTy }, IntTy)));
Toplevel.addBinding("String", Forall(StringType));
Toplevel.addBinding("Int", Forall(IntType));
Toplevel.addBinding("Bool", Forall(BoolType));
Toplevel.addBinding("True", Forall(BoolType));
Toplevel.addBinding("False", Forall(BoolType));
Toplevel.addBinding("+", Forall(new TArrow({ IntType, IntType }, IntType)));
Toplevel.addBinding("-", Forall(new TArrow({ IntType, IntType }, IntType)));
Toplevel.addBinding("*", Forall(new TArrow({ IntType, IntType }, IntType)));
Toplevel.addBinding("/", Forall(new TArrow({ IntType, IntType }, IntType)));
infer(SF, Toplevel);
solve(new CMany(Toplevel.Constraints));
}
@ -480,6 +530,22 @@ namespace bolt {
return unify(Y->ReturnType, Z->ReturnType, Solution);
}
if (A->getKind() == TypeKind::Tuple && B->getKind() == TypeKind::Tuple) {
auto Y = static_cast<TTuple*>(A);
auto Z = static_cast<TTuple*>(B);
if (Y->ElementTypes.size() != Z->ElementTypes.size()) {
return false;
}
auto Count = Y->ElementTypes.size();
bool Success = true;
for (size_t I = 0; I < Count; I++) {
if (!unify(Y->ElementTypes[I], Z->ElementTypes[I], Solution)) {
Success = false;
}
}
return Success;
}
if (A->getKind() == TypeKind::Con && B->getKind() == TypeKind::Con) {
auto Y = static_cast<TCon*>(A);
auto Z = static_cast<TCon*>(B);

View file

@ -129,6 +129,21 @@ namespace bolt {
}
return Out.str();
}
case TypeKind::Tuple:
{
std::ostringstream Out;
auto Y = static_cast<const TTuple*>(Ty);
Out << "(";
if (Y->ElementTypes.size()) {
auto Iter = Y->ElementTypes.begin();
Out << describe(*Iter++);
while (Iter != Y->ElementTypes.end()) {
Out << ", " << describe(*Iter++);
}
}
Out << ")";
return Out.str();
}
}
}

View file

@ -3,6 +3,7 @@
#include "bolt/Scanner.hpp"
#include "bolt/Parser.hpp"
#include "bolt/Diagnostics.hpp"
#include <exception>
#include <vector>
namespace bolt {
@ -74,6 +75,14 @@ namespace bolt {
} \
}
Token* Parser::expectToken(NodeType Type) {
auto T = Tokens.get();
if (T->Type != Type) {
throw UnexpectedTokenDiagnostic(File, T, std::vector<NodeType> { Type }); \
}
return T;
}
Pattern* Parser::parsePattern() {
auto T0 = Tokens.peek();
switch (T0->Type) {
@ -87,10 +96,7 @@ namespace bolt {
QualifiedName* Parser::parseQualifiedName() {
std::vector<Identifier*> ModulePath;
auto Name = Tokens.get();
if (Name->Type != NodeType::Identifier) {
throw UnexpectedTokenDiagnostic(File, Name, std::vector { NodeType::Identifier });
}
auto Name = expectToken(NodeType::Identifier);
for (;;) {
auto T1 = Tokens.peek();
if (T1->Type != NodeType::Dot) {
@ -156,7 +162,7 @@ namespace bolt {
std::vector<Expression*> Args;
for (;;) {
auto T1 = Tokens.peek();
if (T1->Type == NodeType::LineFoldEnd || ExprOperators.isInfix(T1)) {
if (T1->Type == NodeType::LineFoldEnd || T1->Type == NodeType::BlockStart || ExprOperators.isInfix(T1)) {
break;
}
Args.push_back(parsePrimitiveExpression());
@ -216,6 +222,52 @@ namespace bolt {
return new ExpressionStatement(E);
}
ReturnStatement* Parser::parseReturnStatement() {
auto T0 = static_cast<ReturnKeyword*>(expectToken(NodeType::ReturnKeyword));
Expression* Expression = nullptr;
auto T1 = Tokens.peek();
if (T1->Type != NodeType::LineFoldEnd) {
Expression = parseExpression();
}
BOLT_EXPECT_TOKEN(LineFoldEnd);
return new ReturnStatement(static_cast<ReturnKeyword*>(T0), Expression);
}
IfStatement* Parser::parseIfStatement() {
std::vector<IfStatementPart*> Parts;
auto T0 = expectToken(NodeType::IfKeyword);
auto Test = parseExpression();
auto T1 = static_cast<BlockStart*>(expectToken(NodeType::BlockStart));
std::vector<Node*> Then;
for (;;) {
auto T2 = Tokens.peek();
if (T2->Type == NodeType::BlockEnd) {
Tokens.get();
break;
}
Then.push_back(parseLetBodyElement());
}
Parts.push_back(new IfStatementPart(T0, Test, T1, Then));
BOLT_EXPECT_TOKEN(LineFoldEnd)
auto T3 = Tokens.peek();
if (T3->Type == NodeType::ElseKeyword) {
Tokens.get();
auto T4 = static_cast<BlockStart*>(expectToken(NodeType::BlockStart));
std::vector<Node*> Else;
for (;;) {
auto T5 = Tokens.peek();
if (T5->Type == NodeType::BlockEnd) {
Tokens.get();
break;
}
Else.push_back(parseLetBodyElement());
}
Parts.push_back(new IfStatementPart(T3, nullptr, T4, Else));
BOLT_EXPECT_TOKEN(LineFoldEnd)
}
return new IfStatement(Parts);
}
LetDeclaration* Parser::parseLetDeclaration() {
PubKeyword* Pub = nullptr;
@ -316,6 +368,10 @@ after_params:
switch (T0->Type) {
case NodeType::LetKeyword:
return parseLetDeclaration();
case NodeType::ReturnKeyword:
return parseReturnStatement();
case NodeType::IfKeyword:
return parseIfStatement();
default:
return parseExpressionStatement();
}
@ -326,6 +382,8 @@ after_params:
switch (T0->Type) {
case NodeType::LetKeyword:
return parseLetDeclaration();
case NodeType::IfKeyword:
return parseIfStatement();
default:
return parseExpressionStatement();
}

View file

@ -64,6 +64,9 @@ namespace bolt {
{ "return", NodeType::ReturnKeyword },
{ "type", NodeType::TypeKeyword },
{ "mod", NodeType::ModKeyword },
{ "if", NodeType::IfKeyword },
{ "else", NodeType::ElseKeyword },
{ "elif", NodeType::ElifKeyword },
};
Scanner::Scanner(TextFile& File, Stream<Char>& Chars):
@ -209,6 +212,12 @@ digit_finish:
return new TypeKeyword(StartLoc);
case NodeType::ReturnKeyword:
return new ReturnKeyword(StartLoc);
case NodeType::IfKeyword:
return new IfKeyword(StartLoc);
case NodeType::ElifKeyword:
return new ElifKeyword(StartLoc);
case NodeType::ElseKeyword:
return new ElseKeyword(StartLoc);
default:
ZEN_UNREACHABLE
}