Skip to content

Commit

Permalink
Allow defaults in function signatures with syntax (..., p : t = e, ...).
Browse files Browse the repository at this point in the history
  • Loading branch information
m-kurtenacker committed Dec 12, 2024
1 parent 70e2a43 commit f9c4068
Show file tree
Hide file tree
Showing 10 changed files with 169 additions and 5 deletions.
18 changes: 18 additions & 0 deletions include/artic/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -1640,6 +1640,24 @@ struct ImplicitParamPtrn : public Ptrn {
void print(Printer&) const override;
};

struct DefaultParamPtrn : public Ptrn {
Ptr<Ptrn> underlying;
Ptr<Expr> default_expr;

DefaultParamPtrn(const Loc& loc, Ptr<Ptrn>&& underlying, Ptr<Expr>&& default_expr)
: Ptrn(loc), underlying(std::move(underlying)), default_expr(std::move(default_expr))
{}

bool is_trivial() const override;

void emit(Emitter&, const thorin::Def*) const override;
const artic::Type* infer(TypeChecker&) override;
const artic::Type* check(TypeChecker&, const artic::Type*) override;
void bind(NameBinder&) override;
void resolve_summons(Summoner&) override;
void print(Printer&) const override;
};

/// A pattern that matches against a structure field.
struct FieldPtrn : public Ptrn {
Identifier id;
Expand Down
29 changes: 29 additions & 0 deletions include/artic/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,34 @@ struct ImplicitParamType : public Type {
friend class TypeTable;
};

struct DefaultParamType : public Type {
const Type* underlying;
const ast::Expr* expr;

void print(Printer&) const override;
bool equals(const Type*) const override;
size_t hash() const override;
bool contains(const Type*) const override;

const Type* replace(const ReplaceMap&) const override;

const thorin::Type* convert(Emitter&) const override;
std::string stringify(Emitter&) const override;

size_t order(std::unordered_set<const Type*>&) const override;
void variance(TypeVarMap<TypeVariance>&, bool) const override;
void bounds(TypeVarMap<TypeBounds>&, const Type*, bool) const override;
bool is_sized(std::unordered_set<const Type*>&) const override;
private:
DefaultParamType(TypeTable& type_table, const Type* underlying, const ast::Expr* expr)
: Type(type_table)
, underlying(underlying)
, expr(expr)
{}

friend class TypeTable;
};

/// Function type (can represent continuations when the codomain is a `NoRetType`).
struct FnType : public Type {
const Type* dom;
Expand Down Expand Up @@ -672,6 +700,7 @@ class TypeTable {
const PtrType* ptr_type(const Type*, bool, size_t);
const RefType* ref_type(const Type*, bool, size_t);
const ImplicitParamType* implicit_param_type(const Type*);
const DefaultParamType* default_param_type(const Type*, const ast::Expr*);
const FnType* fn_type(const Type*, const Type*);
const FnType* cn_type(const Type*);
const BottomType* bottom_type();
Expand Down
4 changes: 4 additions & 0 deletions src/ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,10 @@ bool ImplicitParamPtrn::is_trivial() const {
return underlying->is_trivial();
}

bool DefaultParamPtrn::is_trivial() const {
return underlying->is_trivial();
}

void FieldPtrn::collect_bound_ptrns(std::vector<const IdPtrn*>& bound_ptrns) const {
if (ptrn)
ptrn->collect_bound_ptrns(bound_ptrns);
Expand Down
5 changes: 5 additions & 0 deletions src/bind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,11 @@ void ImplicitParamPtrn::bind(artic::NameBinder& binder) {
underlying->bind(binder);
}

void DefaultParamPtrn::bind(artic::NameBinder& binder) {
underlying->bind(binder);
default_expr->bind(binder);
}

void FieldPtrn::bind(NameBinder& binder) {
if (ptrn) binder.bind(*ptrn);
}
Expand Down
22 changes: 21 additions & 1 deletion src/check.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ static bool is_unit(Ptr<ast::Expr>& expr) {

static bool is_tuple_type_with_implicits(const artic::Type* type) {
if (auto tuple_t = type->isa<artic::TupleType>(); tuple_t && !is_unit_type(tuple_t))
return std::any_of(tuple_t->args.begin(), tuple_t->args.end(), [&](auto arg){ return arg->template isa<ImplicitParamType>(); });
return std::any_of(tuple_t->args.begin(), tuple_t->args.end(), [&](auto arg){ return arg->template isa<ImplicitParamType>() || arg->template isa<DefaultParamType>(); });
return false;
}

Expand Down Expand Up @@ -198,6 +198,13 @@ const Type* TypeChecker::coerce(Ptr<ast::Expr>& expr, const Type* expected) {
args.push_back(std::move(summoned));
continue;
}
if (auto default_type = tuple_t->args[i]->isa<DefaultParamType>()) {
Ptr<ast::SummonExpr> summoned = make_ptr<ast::SummonExpr>(loc, Ptr<ast::Type>());
summoned->type = default_type->underlying;
summoned->resolved = default_type->expr;
args.push_back(std::move(summoned));
continue;
}

bad_arguments(loc, "non-implicit arguments", i, tuple_t->args.size());
}
Expand Down Expand Up @@ -1882,6 +1889,19 @@ const artic::Type * ImplicitParamPtrn::check(artic::TypeChecker& checker, const
return checker.type_table.implicit_param_type(underlying->type);
}

//TODO: we can use the default expression to infer the type of this pattern, and need to check it as well.
const artic::Type* DefaultParamPtrn::infer(artic::TypeChecker& checker) {
checker.infer(*default_expr);
checker.check(*underlying, default_expr->type);
return checker.type_table.default_param_type(default_expr->type, &*default_expr);
}

const artic::Type *DefaultParamPtrn::check(artic::TypeChecker& checker, const artic::Type* expected) {
checker.check(*underlying, expected);
checker.check(*default_expr, expected);
return checker.type_table.implicit_param_type(underlying->type);
}

const artic::Type* FieldPtrn::check(TypeChecker& checker, const artic::Type* expected) {
return checker.check(*ptrn, expected);
}
Expand Down
14 changes: 14 additions & 0 deletions src/emit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,8 @@ const thorin::Def* Emitter::down_cast(const thorin::Def* def, const Type* from,

if (to->isa<ImplicitParamType>())
return def;
if (to->isa<DefaultParamType>())
return def;

auto to_ptr_type = to->isa<PtrType>();
// Casting a value to a pointer to the type of the value effectively creates an allocation
Expand Down Expand Up @@ -1786,6 +1788,10 @@ void ImplicitParamPtrn::emit(artic::Emitter& emitter, const thorin::Def* value)
underlying->emit(emitter, value);
}

void DefaultParamPtrn::emit(artic::Emitter& emitter, const thorin::Def* value) const {
underlying->emit(emitter, value);
}

void FieldPtrn::emit(Emitter& emitter, const thorin::Def* value) const {
emitter.emit(*ptrn, value);
}
Expand Down Expand Up @@ -1900,6 +1906,10 @@ std::string ImplicitParamType::stringify(Emitter& emitter) const {
return "implicit_" + underlying->stringify(emitter);
}

std::string DefaultParamType::stringify(Emitter& emitter) const {
return "default_" + underlying->stringify(emitter);
}

std::string FnType::stringify(Emitter& emitter) const {
return "fn_" + dom->stringify(emitter) + "_" + codom->stringify(emitter);
}
Expand All @@ -1908,6 +1918,10 @@ const thorin::Type* ImplicitParamType::convert(artic::Emitter& emitter) const {
return underlying->convert(emitter);
}

const thorin::Type* DefaultParamType::convert(artic::Emitter& emitter) const {
return underlying->convert(emitter);
}

const thorin::Type* FnType::convert(Emitter& emitter) const {
if (codom->isa<BottomType>())
return emitter.continuation_type_with_mem(dom->convert(emitter));
Expand Down
16 changes: 13 additions & 3 deletions src/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ Ptr<ast::Ptrn> Parser::parse_ptrn(bool allow_types, bool allow_implicits) {
ahead().tag() == Token::LBracket ||
ahead().tag() == Token::LParen ||
ahead().tag() == Token::LBrace ||
(allow_types && ahead().tag() != Token::Colon && ahead().tag() != Token::As)) {
(allow_types && ahead().tag() != Token::Colon && ahead().tag() != Token::As && ahead().tag() != Token::Eq)) {
auto path = parse_path(std::move(id), true);
if (ahead().tag() == Token::LBrace)
ptrn = parse_record_ptrn(std::move(path));
Expand All @@ -311,8 +311,13 @@ Ptr<ast::Ptrn> Parser::parse_ptrn(bool allow_types, bool allow_implicits) {
return make_ptr<ast::TypedPtrn>(path.loc, Ptr<ast::Ptrn>(), std::move(type));
} else
ptrn = parse_ctor_ptrn(std::move(path));
} else
} else {
ptrn = parse_id_ptrn(std::move(id), false);
if (allow_implicits && accept(Token::Eq)) {
auto default_expr = parse_expr();
ptrn = make_ptr<ast::DefaultParamPtrn>(ptrn->loc, std::move(ptrn), std::move(default_expr));
}
}
}
break;
case Token::Mut:
Expand Down Expand Up @@ -349,7 +354,12 @@ Ptr<ast::Ptrn> Parser::parse_ptrn(bool allow_types, bool allow_implicits) {
ptrn = parse_error_ptrn();
break;
}
return parse_typed_ptrn(std::move(ptrn));
ptrn = parse_typed_ptrn(std::move(ptrn));
if (allow_implicits && accept(Token::Eq)) {
auto default_expr = parse_expr();
ptrn = make_ptr<ast::DefaultParamPtrn>(ptrn->loc, std::move(ptrn), std::move(default_expr));
}
return ptrn;
}

Ptr<ast::Ptrn> Parser::parse_typed_ptrn(Ptr<ast::Ptrn>&& ptrn) {
Expand Down
12 changes: 12 additions & 0 deletions src/print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,13 @@ void ImplicitParamPtrn::print(Printer& p) const {
underlying->print(p);
}

void DefaultParamPtrn::print(Printer& p) const {
p << log::keyword_style("default") << ' ';
underlying->print(p);
p << " = ";
default_expr->print(p);
}

void FieldPtrn::print(Printer& p) const {
if (is_etc()) {
p << "...";
Expand Down Expand Up @@ -772,6 +779,11 @@ void ImplicitParamType::print(artic::Printer& p) const {
underlying->print(p);
}

void DefaultParamType::print(artic::Printer& p) const {
p << "default ";
underlying->print(p);
}

void FnType::print(Printer& p) const {
p << log::keyword_style("fn") << ' ';
if (!dom->isa<TupleType>()) p << '(';
Expand Down
9 changes: 8 additions & 1 deletion src/summoner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ void TypedExpr::resolve_summons(artic::Summoner& summoner) {
}

void SummonExpr::resolve_summons(artic::Summoner& summoner) {
resolved = summoner.resolve(type, loc);
if (!resolved)
resolved = summoner.resolve(type, loc);
}

void FieldExpr::resolve_summons(artic::Summoner& summoner) {
Expand Down Expand Up @@ -223,6 +224,12 @@ void ImplicitParamPtrn::resolve_summons(artic::Summoner& summoner) {
summoner.insert(underlying->type, underlying->to_expr());
}

void DefaultParamPtrn::resolve_summons(artic::Summoner& summoner) {
default_expr->resolve_summons(summoner);

underlying->resolve_summons(summoner);
}

void FieldPtrn::resolve_summons(artic::Summoner& summoner) {
if (ptrn) ptrn->resolve_summons(summoner);
}
Expand Down
45 changes: 45 additions & 0 deletions src/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@ bool ImplicitParamType::equals(const artic::Type* other) const {
other->as<ImplicitParamType>()->underlying == underlying;
}

bool DefaultParamType::equals(const artic::Type* other) const {
return
other->isa<DefaultParamType>() &&
other->as<DefaultParamType>()->underlying == underlying &&
other->as<DefaultParamType>()->expr == expr;
}

bool FnType::equals(const Type* other) const {
return
other->isa<FnType>() &&
Expand Down Expand Up @@ -120,6 +127,13 @@ size_t ImplicitParamType::hash() const {
.combine(underlying);
}

size_t DefaultParamType::hash() const {
return fnv::Hash()
.combine(typeid(*this).hash_code())
.combine(underlying)
.combine(expr);
}

size_t FnType::hash() const {
return fnv::Hash()
.combine(typeid(*this).hash_code())
Expand Down Expand Up @@ -164,6 +178,10 @@ bool ImplicitParamType::contains(const artic::Type* type) const {
return type == this || underlying->contains(type);
}

bool DefaultParamType::contains(const artic::Type* type) const {
return type == this || underlying->contains(type);
}

bool FnType::contains(const Type* type) const {
return type == this || dom->contains(type) || codom->contains(type);
}
Expand Down Expand Up @@ -206,6 +224,10 @@ const Type* ImplicitParamType::replace(const artic::ReplaceMap& map) const {
return type_table.implicit_param_type(underlying->replace(map));
}

const Type* DefaultParamType::replace(const artic::ReplaceMap& map) const {
return type_table.default_param_type(underlying->replace(map), expr);
}

const Type* FnType::replace(const std::unordered_map<const TypeVar*, const Type*>& map) const {
return type_table.fn_type(dom->replace(map), codom->replace(map));
}
Expand Down Expand Up @@ -233,6 +255,10 @@ size_t ImplicitParamType::order(std::unordered_set<const Type*>& seen) const {
return underlying->order(seen);
}

size_t DefaultParamType::order(std::unordered_set<const Type*>& seen) const {
return underlying->order(seen);
}

size_t FnType::order(std::unordered_set<const Type*>& seen) const {
return 1 + std::max(dom->order(seen), codom->order(seen));
}
Expand Down Expand Up @@ -294,6 +320,10 @@ void ImplicitParamType::variance(TypeVarMap<artic::TypeVariance>& vars, bool dir
return underlying->variance(vars, dir);
}

void DefaultParamType::variance(TypeVarMap<artic::TypeVariance>& vars, bool dir) const {
return underlying->variance(vars, dir);
}

void TypeVar::variance(std::unordered_map<const TypeVar*, TypeVariance>& vars, bool dir) const {
if (auto it = vars.find(this); it != vars.end()) {
bool var_dir = it->second == TypeVariance::Covariant ? true : false;
Expand Down Expand Up @@ -333,6 +363,10 @@ void ImplicitParamType::bounds(TypeVarMap<artic::TypeBounds>& bounds, const arti
underlying->bounds(bounds, type, dir);
}

void DefaultParamType::bounds(TypeVarMap<artic::TypeBounds>& bounds, const artic::Type* type, bool dir) const {
underlying->bounds(bounds, type, dir);
}

void FnType::bounds(std::unordered_map<const TypeVar*, TypeBounds>& bounds, const Type* type, bool dir) const {
if (auto fn_type = type->isa<FnType>()) {
dom->bounds(bounds, fn_type->dom, !dir);
Expand Down Expand Up @@ -370,6 +404,10 @@ bool ImplicitParamType::is_sized(std::unordered_set<const Type*>& seen) const {
return underlying->is_sized(seen);
}

bool DefaultParamType::is_sized(std::unordered_set<const Type*>& seen) const {
return underlying->is_sized(seen);
}

bool FnType::is_sized(std::unordered_set<const Type*>& seen) const {
return dom->is_sized(seen) && codom->is_sized(seen);
}
Expand Down Expand Up @@ -491,6 +529,9 @@ bool Type::subtype(const Type* other) const {
if (auto implicit = other->isa<ImplicitParamType>())
return this->subtype(implicit->underlying) || is_unit_type(this);

if (auto default_type = other->isa<DefaultParamType>())
return this->subtype(default_type->underlying) || is_unit_type(this);

auto other_ptr_type = other->isa<PtrType>();

// Take the address of values automatically:
Expand Down Expand Up @@ -678,6 +719,10 @@ const ImplicitParamType* TypeTable::implicit_param_type(const Type* underlying)
return insert<ImplicitParamType>(underlying);
}

const DefaultParamType* TypeTable::default_param_type(const Type* underlying, const ast::Expr* expr) {
return insert<DefaultParamType>(underlying, expr);
}

const FnType* TypeTable::fn_type(const Type* dom, const Type* codom) {
return insert<FnType>(dom, codom);
}
Expand Down

0 comments on commit f9c4068

Please sign in to comment.