Improve checking recursive functions and some minor fixes
This commit is contained in:
parent
2448a70c76
commit
06127ff624
9 changed files with 413 additions and 317 deletions
2
.vscode/tasks.json
vendored
2
.vscode/tasks.json
vendored
|
@ -6,7 +6,7 @@
|
|||
"label": "CMake: build",
|
||||
"command": "build",
|
||||
"targets": [
|
||||
"all"
|
||||
"bolt"
|
||||
],
|
||||
"group": {
|
||||
"kind": "build",
|
||||
|
|
|
@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 3.10)
|
|||
|
||||
project(Bolt CXX)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD 20)
|
||||
|
||||
add_subdirectory(deps/zen EXCLUDE_FROM_ALL)
|
||||
|
||||
|
@ -22,7 +22,6 @@ add_library(
|
|||
src/Parser.cc
|
||||
src/Types.cc
|
||||
src/Checker.cc
|
||||
src/IPRGraph.cc
|
||||
)
|
||||
target_link_directories(
|
||||
BoltCore
|
||||
|
|
|
@ -1530,6 +1530,7 @@ namespace bolt {
|
|||
|
||||
public:
|
||||
|
||||
bool IsCycleActive = false;
|
||||
InferContext* Ctx;
|
||||
class Type* Ty;
|
||||
|
||||
|
@ -1718,6 +1719,10 @@ namespace bolt {
|
|||
return TheScope;
|
||||
}
|
||||
|
||||
static bool classof(const Node* N) {
|
||||
return N->getKind() == NodeKind::SourceFile;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template<> inline NodeKind getNodeType<Equals>() { return NodeKind::Equals; }
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
#include "bolt/Common.hpp"
|
||||
#include "bolt/CST.hpp"
|
||||
#include "bolt/Type.hpp"
|
||||
#include "bolt/Support/Graph.hpp"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
@ -171,13 +172,17 @@ namespace bolt {
|
|||
size_t NextConTypeId = 0;
|
||||
size_t NextTypeVarId = 0;
|
||||
|
||||
Type* BoolType;
|
||||
Type* IntType;
|
||||
Type* StringType;
|
||||
|
||||
Graph<Node*> RefGraph;
|
||||
|
||||
std::unordered_map<Node*, InferContext*> CallGraph;
|
||||
|
||||
std::unordered_map<ByteString, std::vector<InstanceDeclaration*>> InstanceMap;
|
||||
|
||||
Type* BoolType;
|
||||
Type* IntType;
|
||||
Type* StringType;
|
||||
std::vector<InferContext*> Contexts;
|
||||
|
||||
TVSub Solution;
|
||||
|
||||
|
@ -191,14 +196,13 @@ namespace bolt {
|
|||
*/
|
||||
CEqual* C;
|
||||
|
||||
std::vector<InferContext*> Contexts;
|
||||
|
||||
InferContext& getContext();
|
||||
|
||||
void addConstraint(Constraint* Constraint);
|
||||
void addClass(TypeclassSignature Sig);
|
||||
|
||||
void forwardDeclare(Node* Node);
|
||||
void forwardDeclareLetDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints);
|
||||
|
||||
Type* inferExpression(Expression* Expression);
|
||||
Type* inferTypeExpression(TypeExpression* TE);
|
||||
|
@ -208,13 +212,14 @@ namespace bolt {
|
|||
void inferBindings(Pattern* Pattern, Type* T);
|
||||
|
||||
void infer(Node* node);
|
||||
void inferLetDeclaration(LetDeclaration* N);
|
||||
|
||||
Constraint* convertToConstraint(ConstraintExpression* C);
|
||||
|
||||
TCon* createPrimConType();
|
||||
TVar* createTypeVar();
|
||||
TVarRigid* createRigidVar(ByteString Name);
|
||||
InferContext* createInferContext();
|
||||
InferContext* createInferContext(TVSet* TVs = new TVSet, ConstraintSet* Constraints = new ConstraintSet);
|
||||
|
||||
void addBinding(ByteString Name, Scheme* Scm);
|
||||
|
||||
|
@ -276,6 +281,8 @@ namespace bolt {
|
|||
|
||||
void solve(Constraint* Constraint, TVSub& Solution);
|
||||
|
||||
void populate(SourceFile* SF);
|
||||
|
||||
/**
|
||||
* Verifies that type class signatures on type asserts in let-declarations
|
||||
* correctly declare the right type classes.
|
||||
|
|
|
@ -1,29 +0,0 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
namespace bolt {
|
||||
|
||||
class Node;
|
||||
class ReferenceExpression;
|
||||
|
||||
/**
|
||||
* An inter-procedural reference graph.
|
||||
*
|
||||
* This graph keeps track of the references made to other procedures in the
|
||||
* same program.
|
||||
*/
|
||||
class IPRGraph {
|
||||
|
||||
std::unordered_map<Node*, Node*> Edges;
|
||||
|
||||
public:
|
||||
|
||||
void populate(Node* , Node* Decl = nullptr);
|
||||
|
||||
bool isRecursive(ReferenceExpression* Node);
|
||||
|
||||
};
|
||||
|
||||
}
|
145
include/bolt/Support/Graph.hpp
Normal file
145
include/bolt/Support/Graph.hpp
Normal file
|
@ -0,0 +1,145 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <stack>
|
||||
#include <optional>
|
||||
|
||||
#include "zen/range.hpp"
|
||||
|
||||
namespace bolt {
|
||||
|
||||
template<typename V>
|
||||
class Graph {
|
||||
|
||||
std::unordered_set<V> Vertices;
|
||||
std::unordered_multimap<V, V> Edges;
|
||||
|
||||
public:
|
||||
|
||||
void addVertex(V Vert) {
|
||||
Vertices.emplace(Vert);
|
||||
}
|
||||
|
||||
void addEdge(V A, V B) {
|
||||
Vertices.emplace(A);
|
||||
Vertices.emplace(B);
|
||||
Edges.emplace(A, B);
|
||||
}
|
||||
|
||||
std::size_t countVertices() const {
|
||||
return Vertices.size();
|
||||
}
|
||||
|
||||
bool hasVertex(const V& Vert) const {
|
||||
return Vertices.count(Vert);
|
||||
}
|
||||
|
||||
bool hasEdge(const V& From) const {
|
||||
return Edges.count(From);
|
||||
}
|
||||
|
||||
bool hasEdge(const V& From, const V& To) const {
|
||||
for (auto X: Edges.equal_range(From)) {
|
||||
if (X == To) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto getTargetVertices(const V& From) const {
|
||||
return zen::make_iterator_range(Edges.equal_range(From)).map_second();
|
||||
}
|
||||
|
||||
auto getVertices() const {
|
||||
return zen::make_iterator_range(Vertices);
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
struct TarjanVertexData {
|
||||
std::optional<std::size_t> Index;
|
||||
std::size_t LowLink;
|
||||
bool OnStack = false;
|
||||
};
|
||||
|
||||
class TarjanSolver {
|
||||
public:
|
||||
|
||||
std::vector<std::vector<V>> SCCs;
|
||||
|
||||
private:
|
||||
|
||||
const Graph& G;
|
||||
std::unordered_map<V, TarjanVertexData> Map;
|
||||
std::size_t Index = 0;
|
||||
std::stack<V> Stack;
|
||||
|
||||
TarjanVertexData& getData(V From) {
|
||||
return Map.emplace(From, TarjanVertexData {}).first->second;
|
||||
}
|
||||
|
||||
void visitCycle(const V& From) {
|
||||
|
||||
auto& DataFrom = getData(From);
|
||||
DataFrom.Index = Index;
|
||||
DataFrom.LowLink = Index;
|
||||
Index++;
|
||||
Stack.push(From);
|
||||
DataFrom.OnStack = true;
|
||||
|
||||
for (const auto& To: G.getTargetVertices(From)) {
|
||||
auto& DataTo = getData(To);
|
||||
if (!DataTo.Index) {
|
||||
visitCycle(To);
|
||||
DataFrom.LowLink = std::min(DataFrom.LowLink, DataTo.LowLink);
|
||||
} else if (DataTo.OnStack) {
|
||||
DataFrom.LowLink = std::min(DataFrom.LowLink, *DataTo.Index);
|
||||
}
|
||||
}
|
||||
|
||||
if (DataFrom.LowLink == DataFrom.Index) {
|
||||
std::vector<V> SCC;
|
||||
for (;;) {
|
||||
auto& X = Stack.top();
|
||||
Stack.pop();
|
||||
auto& DataX = getData(X);
|
||||
DataX.OnStack = false;
|
||||
SCC.push_back(X);
|
||||
if (X == From) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
SCCs.push_back(SCC);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
TarjanSolver(const Graph& G):
|
||||
G(G) {}
|
||||
|
||||
void solve() {
|
||||
for (auto From: G.Vertices) {
|
||||
if (!Map.count(From)) {
|
||||
visitCycle(From);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
public:
|
||||
|
||||
std::vector<std::vector<V>> strongconnect() const {
|
||||
TarjanSolver S { *this };
|
||||
S.solve();
|
||||
return S.SCCs;
|
||||
}
|
||||
|
||||
};
|
||||
}
|
||||
|
416
src/Checker.cc
416
src/Checker.cc
|
@ -230,112 +230,8 @@ namespace bolt {
|
|||
}
|
||||
|
||||
case NodeKind::LetDeclaration:
|
||||
{
|
||||
auto Let = static_cast<LetDeclaration*>(X);
|
||||
bool IsFunc = !Let->Params.empty();
|
||||
bool IsInstance = llvm::isa<InstanceDeclaration>(Let->Parent);
|
||||
bool IsClass = llvm::isa<ClassDeclaration>(Let->Parent);
|
||||
bool HasContext = IsFunc || IsInstance || IsClass;
|
||||
|
||||
if (HasContext) {
|
||||
Let->Ctx = createInferContext();
|
||||
Contexts.push_back(Let->Ctx);
|
||||
}
|
||||
|
||||
// If declaring a let-declaration inside a type class declaration,
|
||||
// we need to mark that the let-declaration requires this class.
|
||||
// This marking is set on the rigid type variables of the class, which
|
||||
// are then added to this local type environment.
|
||||
if (IsClass) {
|
||||
auto Class = static_cast<ClassDeclaration*>(Let->Parent);
|
||||
for (auto TE: Class->TypeVars) {
|
||||
auto TV = llvm::cast<TVar>(TE->getType());
|
||||
Let->Ctx->Env.emplace(TE->Name->getCanonicalText(), new Forall(TV));
|
||||
Let->Ctx->TVs->emplace(TV);
|
||||
}
|
||||
}
|
||||
|
||||
// Here we infer the primary type of the let declaration. If there's a
|
||||
// type assert, that assert should be authoritative so we use that.
|
||||
// Otherwise, the type is not further specified and we create a new
|
||||
// unification variable.
|
||||
Type* Ty;
|
||||
if (Let->TypeAssert) {
|
||||
Ty = inferTypeExpression(Let->TypeAssert->TypeExpression);
|
||||
} else {
|
||||
Ty = createTypeVar();
|
||||
}
|
||||
Let->Ty = Ty;
|
||||
|
||||
// If declaring a let-declaration inside a type instance declaration,
|
||||
// we need to perform some work to make sure the type asserts of the
|
||||
// corresponding let-declaration in the type class declaration are
|
||||
// accounted for.
|
||||
if (IsInstance) {
|
||||
|
||||
auto Instance = static_cast<InstanceDeclaration*>(Let->Parent);
|
||||
auto Class = llvm::cast<ClassDeclaration>(Instance->getScope()->lookup({ {}, Instance->Name->getCanonicalText() }, SymbolKind::Class));
|
||||
|
||||
// The type asserts in the type class declaration might make use of
|
||||
// the type parameters of the type class declaration, so it is
|
||||
// important to make them available in the type environment. Moreover,
|
||||
// we will be unifying them with the actual types declared in the
|
||||
// instance declaration, so we keep track of them.
|
||||
std::vector<TVar *> Params;
|
||||
TVSub Sub;
|
||||
for (auto TE: Class->TypeVars) {
|
||||
auto TV = createTypeVar();
|
||||
Sub.emplace(llvm::cast<TVar>(TE->getType()), TV);
|
||||
Params.push_back(TV);
|
||||
}
|
||||
|
||||
auto SigLet = llvm::cast<LetDeclaration>(Class->getScope()->lookupDirect({ {}, llvm::cast<BindPattern>(Let->Pattern)->Name->getCanonicalText() }, SymbolKind::Var));
|
||||
|
||||
// It would be very strange if there was no type assert in the type
|
||||
// class let-declaration but we rather not let the compiler crash if that happens.
|
||||
if (SigLet->TypeAssert) {
|
||||
addConstraint(new CEqual(Ty, inferTypeExpression(SigLet->TypeAssert->TypeExpression)->substitute(Sub), Let));
|
||||
}
|
||||
|
||||
// Here we do the actual unification of e.g. Eq a with Eq Bool. The
|
||||
// unification variables we created previously will be unified with
|
||||
// e.g. Bool, which causes the type assert to also collapse to e.g.
|
||||
// Bool -> Bool -> Bool.
|
||||
for (auto [Param, TE] : zen::zip(Params, Instance->TypeExps)) {
|
||||
addConstraint(new CEqual(Param, TE->getType()));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if (Let->Body) {
|
||||
switch (Let->Body->getKind()) {
|
||||
case NodeKind::LetExprBody:
|
||||
break;
|
||||
case NodeKind::LetBlockBody:
|
||||
{
|
||||
auto Block = static_cast<LetBlockBody*>(Let->Body);
|
||||
if (IsFunc) {
|
||||
Let->Ctx->ReturnType = createTypeVar();
|
||||
}
|
||||
for (auto Element: Block->Elements) {
|
||||
forwardDeclare(Element);
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
ZEN_UNREACHABLE
|
||||
}
|
||||
}
|
||||
|
||||
if (HasContext) {
|
||||
Contexts.pop_back();
|
||||
inferBindings(Let->Pattern, Ty, Let->Ctx->Constraints, Let->Ctx->TVs);
|
||||
} else {
|
||||
inferBindings(Let->Pattern, Ty);
|
||||
}
|
||||
|
||||
// These declarations will be handled separately in check()
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
ZEN_UNREACHABLE
|
||||
|
@ -344,6 +240,173 @@ namespace bolt {
|
|||
|
||||
}
|
||||
|
||||
void Checker::forwardDeclareLetDeclaration(LetDeclaration* N, TVSet* TVs, ConstraintSet* Constraints) {
|
||||
|
||||
auto Let = static_cast<LetDeclaration*>(N);
|
||||
bool IsFunc = !Let->Params.empty();
|
||||
bool IsInstance = llvm::isa<InstanceDeclaration>(Let->Parent);
|
||||
bool IsClass = llvm::isa<ClassDeclaration>(Let->Parent);
|
||||
bool HasContext = IsFunc || IsInstance || IsClass;
|
||||
|
||||
if (HasContext) {
|
||||
Let->Ctx = createInferContext(TVs, Constraints);
|
||||
Contexts.push_back(Let->Ctx);
|
||||
}
|
||||
|
||||
// If declaring a let-declaration inside a type class declaration,
|
||||
// we need to mark that the let-declaration requires this class.
|
||||
// This marking is set on the rigid type variables of the class, which
|
||||
// are then added to this local type environment.
|
||||
if (IsClass) {
|
||||
auto Class = static_cast<ClassDeclaration*>(Let->Parent);
|
||||
for (auto TE: Class->TypeVars) {
|
||||
auto TV = llvm::cast<TVar>(TE->getType());
|
||||
Let->Ctx->Env.emplace(TE->Name->getCanonicalText(), new Forall(TV));
|
||||
Let->Ctx->TVs->emplace(TV);
|
||||
}
|
||||
}
|
||||
|
||||
// Here we infer the primary type of the let declaration. If there's a
|
||||
// type assert, that assert should be authoritative so we use that.
|
||||
// Otherwise, the type is not further specified and we create a new
|
||||
// unification variable.
|
||||
Type* Ty;
|
||||
if (Let->TypeAssert) {
|
||||
Ty = inferTypeExpression(Let->TypeAssert->TypeExpression);
|
||||
} else {
|
||||
Ty = createTypeVar();
|
||||
}
|
||||
Let->Ty = Ty;
|
||||
|
||||
// If declaring a let-declaration inside a type instance declaration,
|
||||
// we need to perform some work to make sure the type asserts of the
|
||||
// corresponding let-declaration in the type class declaration are
|
||||
// accounted for.
|
||||
if (IsInstance) {
|
||||
|
||||
auto Instance = static_cast<InstanceDeclaration*>(Let->Parent);
|
||||
auto Class = llvm::cast<ClassDeclaration>(Instance->getScope()->lookup({ {}, Instance->Name->getCanonicalText() }, SymbolKind::Class));
|
||||
|
||||
// The type asserts in the type class declaration might make use of
|
||||
// the type parameters of the type class declaration, so it is
|
||||
// important to make them available in the type environment. Moreover,
|
||||
// we will be unifying them with the actual types declared in the
|
||||
// instance declaration, so we keep track of them.
|
||||
std::vector<TVar *> Params;
|
||||
TVSub Sub;
|
||||
for (auto TE: Class->TypeVars) {
|
||||
auto TV = createTypeVar();
|
||||
Sub.emplace(llvm::cast<TVar>(TE->getType()), TV);
|
||||
Params.push_back(TV);
|
||||
}
|
||||
|
||||
auto SigLet = llvm::cast<LetDeclaration>(Class->getScope()->lookupDirect({ {}, llvm::cast<BindPattern>(Let->Pattern)->Name->getCanonicalText() }, SymbolKind::Var));
|
||||
|
||||
// It would be very strange if there was no type assert in the type
|
||||
// class let-declaration but we rather not let the compiler crash if that happens.
|
||||
if (SigLet->TypeAssert) {
|
||||
addConstraint(new CEqual(Ty, inferTypeExpression(SigLet->TypeAssert->TypeExpression)->substitute(Sub), Let));
|
||||
}
|
||||
|
||||
// Here we do the actual unification of e.g. Eq a with Eq Bool. The
|
||||
// unification variables we created previously will be unified with
|
||||
// e.g. Bool, which causes the type assert to also collapse to e.g.
|
||||
// Bool -> Bool -> Bool.
|
||||
for (auto [Param, TE] : zen::zip(Params, Instance->TypeExps)) {
|
||||
addConstraint(new CEqual(Param, TE->getType()));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if (Let->Body) {
|
||||
switch (Let->Body->getKind()) {
|
||||
case NodeKind::LetExprBody:
|
||||
break;
|
||||
case NodeKind::LetBlockBody:
|
||||
{
|
||||
auto Block = static_cast<LetBlockBody*>(Let->Body);
|
||||
if (IsFunc) {
|
||||
Let->Ctx->ReturnType = createTypeVar();
|
||||
}
|
||||
for (auto Element: Block->Elements) {
|
||||
forwardDeclare(Element);
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
ZEN_UNREACHABLE
|
||||
}
|
||||
}
|
||||
|
||||
if (HasContext) {
|
||||
Contexts.pop_back();
|
||||
inferBindings(Let->Pattern, Ty, Let->Ctx->Constraints, Let->Ctx->TVs);
|
||||
} else {
|
||||
inferBindings(Let->Pattern, Ty);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void Checker::inferLetDeclaration(LetDeclaration* N) {
|
||||
|
||||
auto Decl = static_cast<LetDeclaration*>(N);
|
||||
bool IsFunc = !Decl->Params.empty();
|
||||
bool IsInstance = llvm::isa<InstanceDeclaration>(Decl->Parent);
|
||||
bool IsClass = llvm::isa<ClassDeclaration>(Decl->Parent);
|
||||
bool HasContext = IsFunc || IsInstance || IsClass;
|
||||
|
||||
if (HasContext) {
|
||||
Contexts.push_back(Decl->Ctx);
|
||||
}
|
||||
|
||||
std::vector<Type*> ParamTypes;
|
||||
Type* RetType;
|
||||
|
||||
for (auto Param: Decl->Params) {
|
||||
// TODO incorporate Param->TypeAssert or make it a kind of pattern
|
||||
TVar* TV = createTypeVar();
|
||||
inferBindings(Param->Pattern, TV);
|
||||
ParamTypes.push_back(TV);
|
||||
}
|
||||
|
||||
if (Decl->Body) {
|
||||
switch (Decl->Body->getKind()) {
|
||||
case NodeKind::LetExprBody:
|
||||
{
|
||||
auto Expr = static_cast<LetExprBody*>(Decl->Body);
|
||||
RetType = inferExpression(Expr->Expression);
|
||||
break;
|
||||
}
|
||||
case NodeKind::LetBlockBody:
|
||||
{
|
||||
auto Block = static_cast<LetBlockBody*>(Decl->Body);
|
||||
ZEN_ASSERT(HasContext);
|
||||
RetType = Decl->Ctx->ReturnType;
|
||||
for (auto Element: Block->Elements) {
|
||||
infer(Element);
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
ZEN_UNREACHABLE
|
||||
}
|
||||
} else {
|
||||
RetType = createTypeVar();
|
||||
}
|
||||
|
||||
if (HasContext) {
|
||||
Contexts.pop_back();
|
||||
}
|
||||
|
||||
if (IsFunc) {
|
||||
addConstraint(new CEqual { Decl->Ty, new TArrow(ParamTypes, RetType), N });
|
||||
} else {
|
||||
// Declaration is a plain (typed) variable
|
||||
addConstraint(new CEqual { Decl->Ty, RetType, N });
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void Checker::infer(Node* N) {
|
||||
|
||||
switch (N->getKind()) {
|
||||
|
@ -369,11 +432,9 @@ namespace bolt {
|
|||
case NodeKind::InstanceDeclaration:
|
||||
{
|
||||
auto Decl = static_cast<InstanceDeclaration*>(N);
|
||||
|
||||
for (auto Element: Decl->Elements) {
|
||||
infer(Element);
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -392,71 +453,18 @@ namespace bolt {
|
|||
}
|
||||
|
||||
case NodeKind::LetDeclaration:
|
||||
{
|
||||
auto Decl = static_cast<LetDeclaration*>(N);
|
||||
bool HasContext = !Decl->Params.empty();
|
||||
|
||||
if (HasContext) {
|
||||
Contexts.push_back(Decl->Ctx);
|
||||
}
|
||||
|
||||
std::vector<Type*> ParamTypes;
|
||||
Type* RetType;
|
||||
|
||||
for (auto Param: Decl->Params) {
|
||||
// TODO incorporate Param->TypeAssert or make it a kind of pattern
|
||||
TVar* TV = createTypeVar();
|
||||
inferBindings(Param->Pattern, TV);
|
||||
ParamTypes.push_back(TV);
|
||||
}
|
||||
|
||||
if (Decl->Body) {
|
||||
switch (Decl->Body->getKind()) {
|
||||
case NodeKind::LetExprBody:
|
||||
{
|
||||
auto Expr = static_cast<LetExprBody*>(Decl->Body);
|
||||
RetType = inferExpression(Expr->Expression);
|
||||
break;
|
||||
}
|
||||
case NodeKind::LetBlockBody:
|
||||
{
|
||||
auto Block = static_cast<LetBlockBody*>(Decl->Body);
|
||||
ZEN_ASSERT(HasContext);
|
||||
RetType = Decl->Ctx->ReturnType;
|
||||
for (auto Element: Block->Elements) {
|
||||
infer(Element);
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
ZEN_UNREACHABLE
|
||||
}
|
||||
} else {
|
||||
RetType = createTypeVar();
|
||||
}
|
||||
|
||||
if (HasContext) {
|
||||
// Declaration is a function
|
||||
addConstraint(new CEqual { Decl->Ty, new TArrow(ParamTypes, RetType), N });
|
||||
Contexts.pop_back();
|
||||
} else {
|
||||
// Declaration is a plain (typed) variable
|
||||
addConstraint(new CEqual { Decl->Ty, RetType, N });
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case NodeKind::ReturnStatement:
|
||||
{
|
||||
auto RetStmt = static_cast<ReturnStatement*>(N);
|
||||
Type* ReturnType;
|
||||
if (RetStmt->Expression) {
|
||||
ReturnType = inferExpression(RetStmt->Expression);
|
||||
addConstraint(new CEqual { inferExpression(RetStmt->Expression), getReturnType(), RetStmt->Expression });
|
||||
} else {
|
||||
ReturnType = new TTuple({});
|
||||
addConstraint(new CEqual { new TTuple({}), getReturnType(), N });
|
||||
}
|
||||
addConstraint(new CEqual { ReturnType, getReturnType(), N });
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -486,7 +494,7 @@ namespace bolt {
|
|||
return TV;
|
||||
}
|
||||
|
||||
InferContext* Checker::createInferContext() {
|
||||
InferContext* Checker::createInferContext(TVSet* TVs, ConstraintSet* Constraints) {
|
||||
auto Ctx = new InferContext;
|
||||
Ctx->TVs = new TVSet;
|
||||
Ctx->Constraints = new ConstraintSet;
|
||||
|
@ -677,11 +685,12 @@ namespace bolt {
|
|||
{
|
||||
auto Ref = static_cast<ReferenceExpression*>(X);
|
||||
ZEN_ASSERT(Ref->ModulePath.empty());
|
||||
auto Ctx = lookupCall(Ref, Ref->getSymbolPath());
|
||||
if (Ctx) {
|
||||
/* std::cerr << "recursive call!\n"; */
|
||||
ZEN_ASSERT(Ctx->ReturnType != nullptr);
|
||||
return Ctx->ReturnType;
|
||||
auto Target = Ref->getScope()->lookup(Ref->getSymbolPath());
|
||||
if (Target && llvm::isa<LetDeclaration>(Target)) {
|
||||
auto Let = static_cast<LetDeclaration*>(Target);
|
||||
if (Let->IsCycleActive) {
|
||||
return Let->Ty;
|
||||
}
|
||||
}
|
||||
auto Scm = lookup(Ref->Name->getCanonicalText());
|
||||
if (Scm == nullptr) {
|
||||
|
@ -821,6 +830,43 @@ namespace bolt {
|
|||
return Ty;
|
||||
}
|
||||
|
||||
void Checker::populate(SourceFile* SF) {
|
||||
|
||||
struct Visitor : public CSTVisitor<Visitor> {
|
||||
|
||||
Graph<Node*>& RefGraph;
|
||||
|
||||
std::stack<Node*> Stack;
|
||||
|
||||
void visitLetDeclaration(LetDeclaration* N) {
|
||||
RefGraph.addVertex(N);
|
||||
Stack.push(N);
|
||||
visitEachChild(N);
|
||||
Stack.pop();
|
||||
}
|
||||
|
||||
void visitReferenceExpression(ReferenceExpression* N) {
|
||||
auto Y = static_cast<ReferenceExpression*>(N);
|
||||
auto Def = Y->getScope()->lookup(Y->getSymbolPath());
|
||||
// Name lookup failures will be reported directly in inferExpression().
|
||||
// Parameters are clearly no let-decarations. They never have their own
|
||||
// inference context, so we have to skip them.
|
||||
if (Def == nullptr || Def->getKind() == NodeKind::Parameter) {
|
||||
return;
|
||||
}
|
||||
ZEN_ASSERT(Def->getKind() == NodeKind::LetDeclaration || Def->getKind() == NodeKind::SourceFile);
|
||||
RefGraph.addEdge(Stack.top(), Def);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
RefGraph.addVertex(SF);
|
||||
Visitor V { {}, RefGraph };
|
||||
V.Stack.push(SF);
|
||||
V.visit(SF);
|
||||
|
||||
}
|
||||
|
||||
void Checker::checkTypeclassSigs(Node* N) {
|
||||
|
||||
struct LetVisitor : CSTVisitor<LetVisitor> {
|
||||
|
@ -965,7 +1011,37 @@ namespace bolt {
|
|||
addBinding("-", new Forall(new TArrow({ IntType, IntType }, IntType)));
|
||||
addBinding("*", new Forall(new TArrow({ IntType, IntType }, IntType)));
|
||||
addBinding("/", new Forall(new TArrow({ IntType, IntType }, IntType)));
|
||||
populate(SF);
|
||||
forwardDeclare(SF);
|
||||
auto SCCs = RefGraph.strongconnect();
|
||||
for (auto Nodes: SCCs) {
|
||||
if (Nodes.size() == 1 && llvm::isa<SourceFile>(Nodes[0])) {
|
||||
continue;
|
||||
}
|
||||
auto TVs = new TVSet;
|
||||
auto Constraints = new ConstraintSet;
|
||||
for (auto N: Nodes) {
|
||||
auto Decl = static_cast<LetDeclaration*>(N);
|
||||
forwardDeclareLetDeclaration(Decl, TVs, Constraints);
|
||||
}
|
||||
}
|
||||
for (auto Nodes: SCCs) {
|
||||
if (Nodes.size() == 1 && llvm::isa<SourceFile>(Nodes[0])) {
|
||||
continue;
|
||||
}
|
||||
for (auto N: Nodes) {
|
||||
auto Decl = static_cast<LetDeclaration*>(N);
|
||||
Decl->IsCycleActive = true;
|
||||
}
|
||||
for (auto N: Nodes) {
|
||||
auto Decl = static_cast<LetDeclaration*>(N);
|
||||
inferLetDeclaration(Decl);
|
||||
}
|
||||
for (auto N: Nodes) {
|
||||
auto Decl = static_cast<LetDeclaration*>(N);
|
||||
Decl->IsCycleActive = false;
|
||||
}
|
||||
}
|
||||
infer(SF);
|
||||
Contexts.pop_back();
|
||||
solve(new CMany(*RootContext->Constraints), Solution);
|
||||
|
@ -1280,7 +1356,9 @@ namespace bolt {
|
|||
To = Var2;
|
||||
From = Var1;
|
||||
}
|
||||
join(From, To);
|
||||
if (From->Id != To->Id) {
|
||||
join(From, To);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -442,7 +442,7 @@ namespace bolt {
|
|||
{
|
||||
auto E = static_cast<const BindingNotFoundDiagnostic&>(D);
|
||||
writePrefix(E);
|
||||
write("binding '");
|
||||
write("binding ");
|
||||
writeBinding(E.Name);
|
||||
write(" was not found\n\n");
|
||||
if (E.Initiator != nullptr) {
|
||||
|
|
109
src/IPRGraph.cc
109
src/IPRGraph.cc
|
@ -1,109 +0,0 @@
|
|||
|
||||
#include <stack>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "zen/config.hpp"
|
||||
|
||||
#include "bolt/CST.hpp"
|
||||
#include "bolt/IPRGraph.hpp"
|
||||
|
||||
namespace bolt {
|
||||
|
||||
void IPRGraph::populate(Node* X, Node* Decl) {
|
||||
|
||||
switch (X->getKind()) {
|
||||
|
||||
case NodeKind::SourceFile:
|
||||
{
|
||||
auto Y = static_cast<SourceFile*>(X);
|
||||
for (auto Element: Y->Elements) {
|
||||
populate(Element, Decl);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case NodeKind::IfStatement:
|
||||
{
|
||||
auto Y = static_cast<IfStatement*>(X);
|
||||
for (auto Part: Y->Parts) {
|
||||
for (auto Element: Part->Elements) {
|
||||
populate(Element, Decl);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case NodeKind::LetDeclaration:
|
||||
{
|
||||
auto Y = static_cast<LetDeclaration*>(X);
|
||||
if (Y->Body) {
|
||||
switch (Y->Body->getKind()) {
|
||||
case NodeKind::LetBlockBody:
|
||||
{
|
||||
auto Z = static_cast<LetBlockBody*>(Y->Body);
|
||||
for (auto Element: Z->Elements) {
|
||||
populate(Element, Y);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case NodeKind::LetExprBody:
|
||||
{
|
||||
auto Z = static_cast<LetExprBody*>(Y->Body);
|
||||
populate(Z->Expression, Y);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
ZEN_UNREACHABLE
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case NodeKind::ConstantExpression:
|
||||
break;
|
||||
|
||||
case NodeKind::CallExpression:
|
||||
{
|
||||
auto Y = static_cast<CallExpression*>(X);
|
||||
populate(Y->Function, Decl);
|
||||
for (auto Arg: Y->Args) {
|
||||
populate(Arg, Decl);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case NodeKind::ReferenceExpression:
|
||||
{
|
||||
auto Y = static_cast<ReferenceExpression*>(X);
|
||||
auto Def = Y->getScope()->lookup(Y->getSymbolPath());
|
||||
ZEN_ASSERT(Def != nullptr);
|
||||
if (Decl != nullptr) {
|
||||
Edges.emplace(Decl, Y);
|
||||
}
|
||||
Edges.emplace(Y, Def);
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
ZEN_UNREACHABLE
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/* bool IPRGraph::isRecursive(ReferenceExpression* From) { */
|
||||
/* std::unordered_set<Node*> Visited; */
|
||||
/* std::stack<Node*> Queue; */
|
||||
/* while (Queue.size()) { */
|
||||
/* auto A = Queue.top(); */
|
||||
/* Queue.pop(); */
|
||||
/* if (Visited.count(A)) { */
|
||||
/* return true; */
|
||||
/* } */
|
||||
/* for (auto B: getOutEdges(A)) { */
|
||||
/* } */
|
||||
/* } */
|
||||
/* return false; */
|
||||
/* } */
|
||||
|
||||
}
|
Loading…
Reference in a new issue