Skip to content

Commit

Permalink
added Fn trait
Browse files Browse the repository at this point in the history
  • Loading branch information
TomerStarkware committed Nov 14, 2024
1 parent 4a674b5 commit 42020d5
Show file tree
Hide file tree
Showing 16 changed files with 340 additions and 120 deletions.
1 change: 1 addition & 0 deletions corelib/src/ops.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ use range::RangeOp;

mod function;
pub use function::FnOnce;
pub use function::Fn;
22 changes: 22 additions & 0 deletions corelib/src/ops/function.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,25 @@ pub trait FnOnce<T, Args> {
/// Performs the call operation.
fn call(self: T, args: Args) -> Self::Output;
}

/// An implementation of `FnOnce` when `Fn` is implemented.
/// Makes sure we can always pass an `Fn` to a function that expects an `FnOnce`.
impl FnOnceImpl<T, Args, +Destruct<T>, +Fn<T, Args>> of FnOnce<T, Args> {
type Output = Fn::<T, Args>::Output;
fn call(self: T, args: Args) -> Self::Output {
Fn::call(@self, args)
}
}

/// The version of the call operator that takes a by-snapshot receiver.
///
/// Instances of `Fn` can be called multiple times.
///
/// `Fn` is implemented automatically by closures that capture only copyable variables.

pub trait Fn<T, Args> {
/// The returned type after the call operator is used.
type Output;
/// Performs the call operation.
fn call(self: @T, args: Args) -> Self::Output;
}
12 changes: 12 additions & 0 deletions corelib/src/test/language_features/closure_test.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,15 @@ fn option_map_test() {
assert_eq!(option_map(Option::Some(2), |x| Option::Some(x)), Option::Some(Option::Some(2)));
}

fn array_map<T, F, impl Fn: core::ops::Fn<F, (T,)>, +Drop<T>, +Drop<F>, +Drop<Fn::Output>>(
arr: [T; 2], f: F,
) -> [core::ops::Fn::<F, (T,)>::Output; 2] {
let [a, b] = arr;
[f(a), f(b)]
}

#[test]
fn array_map_test() {
assert_eq!(array_map([2, 3], |x| x + 3), [5, 6]);
}

28 changes: 15 additions & 13 deletions crates/cairo-lang-lowering/src/borrow_check/test_data/closure
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,13 @@ blk0 (root):
Statements:
(v0: core::felt252) <- 8
(v1: {[email protected]:3:13: 3:16}) <- struct_construct(v0{`x`})
(v2: core::felt252) <- 2
(v3: (core::felt252,)) <- struct_construct(v2{`2`})
(v4: core::felt252) <- Generated core::ops::function::FnOnce::<{[email protected]:3:13: 3:16}, (core::felt252,)>::call(v1{`c`}, v3{`c(2)`})
(v5: core::felt252) <- core::Felt252Add::add(v4{`y`}, v0{`x`})
(v2: {[email protected]:3:13: 3:16}, v3: @{[email protected]:3:13: 3:16}) <- snapshot(v1{`c`})
(v4: core::felt252) <- 2
(v5: (core::felt252,)) <- struct_construct(v4{`2`})
(v6: core::felt252) <- Generated core::ops::function::Fn::<{[email protected]:3:13: 3:16}, (core::felt252,)>::call(v3{`c`}, v5{`c(2)`})
(v7: core::felt252) <- core::Felt252Add::add(v6{`y`}, v0{`x`})
End:
Return(v5)
Return(v7)

//! > ==========================================================================

Expand Down Expand Up @@ -314,15 +315,16 @@ Statements:
(v3: core::array::Array::<core::felt252>, v2: ()) <- core::array::ArrayImpl::<core::felt252>::append(v0{`__array_builder_macro_result__`}, v1{`99_felt252`})
(v4: core::array::Array::<core::felt252>, v5: @core::array::Array::<core::felt252>) <- snapshot(v3{`x`})
(v6: {[email protected]:3:13: 3:16}) <- struct_construct(v5{`|a| { (@x).len() * (a + 3) }`})
(v7: core::integer::u32) <- 2
(v8: (core::integer::u32,)) <- struct_construct(v7{`2`})
(v9: core::integer::u32) <- Generated core::ops::function::FnOnce::<{[email protected]:3:13: 3:16}, (core::integer::u32,)>::call(v6{`c`}, v8{`c(2)`})
(v10: core::array::Array::<core::felt252>, v11: @core::array::Array::<core::felt252>) <- snapshot(v4{`x`})
(v12: core::integer::u32) <- 0
(v13: @core::felt252) <- core::ops::index::DeprecatedIndexViewImpl::<core::array::Array::<core::felt252>, core::integer::u32, @core::felt252, core::array::ArrayIndex::<core::felt252>>::index(v11{`x`}, v12{`0`})
(v14: core::felt252) <- desnap(v13{`0`})
(v7: {[email protected]:3:13: 3:16}, v8: @{[email protected]:3:13: 3:16}) <- snapshot(v6{`c`})
(v9: core::integer::u32) <- 2
(v10: (core::integer::u32,)) <- struct_construct(v9{`2`})
(v11: core::integer::u32) <- Generated core::ops::function::Fn::<{[email protected]:3:13: 3:16}, (core::integer::u32,)>::call(v8{`c`}, v10{`c(2)`})
(v12: core::array::Array::<core::felt252>, v13: @core::array::Array::<core::felt252>) <- snapshot(v4{`x`})
(v14: core::integer::u32) <- 0
(v15: @core::felt252) <- core::ops::index::DeprecatedIndexViewImpl::<core::array::Array::<core::felt252>, core::integer::u32, @core::felt252, core::array::ArrayIndex::<core::felt252>>::index(v13{`x`}, v14{`0`})
(v16: core::felt252) <- desnap(v15{`0`})
End:
Return(v14)
Return(v16)

//! > ==========================================================================

Expand Down
1 change: 1 addition & 0 deletions crates/cairo-lang-lowering/src/ids.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ impl FunctionLongId {
semantic::corelib::destruct_trait_fn(semantic_db),
semantic::corelib::panic_destruct_trait_fn(semantic_db),
semantic::corelib::fn_once_call_trait_fn(semantic_db),
semantic::corelib::fn_call_trait_fn(semantic_db),
]
.contains(&function)
);
Expand Down
146 changes: 91 additions & 55 deletions crates/cairo-lang-lowering/src/lower/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{iter, vec};
use std::vec;

use block_builder::BlockBuilder;
use cairo_lang_debug::DebugWithDb;
Expand All @@ -7,7 +7,7 @@ use cairo_lang_diagnostics::{Diagnostics, Maybe};
use cairo_lang_semantic::corelib::{ErrorPropagationType, unwrap_error_propagation_type};
use cairo_lang_semantic::db::SemanticGroup;
use cairo_lang_semantic::items::functions::{GenericFunctionId, ImplGenericFunctionId};
use cairo_lang_semantic::items::imp::{GeneratedImplItems, GeneratedImplLongId, ImplLongId};
use cairo_lang_semantic::items::imp::ImplLongId;
use cairo_lang_semantic::usage::MemberPath;
use cairo_lang_semantic::{
ConcreteFunction, ConcreteTraitLongId, ExprVar, LocalVariable, VarId, corelib,
Expand Down Expand Up @@ -1780,11 +1780,11 @@ fn add_closure_call_function(
encapsulated_ctx: &mut LoweringContext<'_, '_>,
expr: &semantic::ExprClosure,
closure_info: &ClosureInfo,
trait_id: cairo_lang_defs::ids::TraitId,
) -> Maybe<()> {
let semantic_db = encapsulated_ctx.db.upcast();
let semantic_db: &dyn SemanticGroup = encapsulated_ctx.db.upcast();
let closure_ty = extract_matches!(expr.ty.lookup_intern(semantic_db), TypeLongId::Closure);
let expr_location = encapsulated_ctx.get_location(expr.stable_ptr.untyped());
let trait_id = semantic::corelib::fn_once_trait(semantic_db);
let parameters_ty = TypeLongId::Tuple(closure_ty.param_tys.clone()).intern(semantic_db);
let concrete_trait = ConcreteTraitLongId {
trait_id,
Expand All @@ -1794,18 +1794,26 @@ fn add_closure_call_function(
],
}
.intern(semantic_db);
let trait_function = semantic::corelib::fn_once_call_trait_fn(semantic_db);

let ret_ty = semantic_db.trait_type_by_name(trait_id, "Output".into()).unwrap().unwrap();
let impl_id = ImplLongId::GeneratedImpl(
GeneratedImplLongId {
concrete_trait,
generic_params: vec![],
impl_items: GeneratedImplItems(iter::once((ret_ty, closure_ty.ret_ty)).collect()),
}
.intern(semantic_db),
)
.intern(semantic_db);
let Ok(impl_id) = semantic::types::get_impl_at_context(
semantic_db,
encapsulated_ctx.variables.lookup_context.clone(),
concrete_trait,
None,
) else {
// If the impl doesn't exist, there won't be a call to the call-function, so we don't need
// to generate it.
return Ok(());
};
if !matches!(impl_id.lookup_intern(semantic_db), ImplLongId::GeneratedImpl(_)) {
// If the impl is not generated, we don't need to generate a lowering for it.
return Ok(());
}

let trait_function: cairo_lang_defs::ids::TraitFunctionId = semantic_db
.trait_function_by_name(trait_id, "call".into())
.unwrap()
.expect("Call function must exist for an Fn trait.");

let generic_function =
GenericFunctionId::Impl(ImplGenericFunctionId { impl_id, function: trait_function });
let function = semantic::FunctionLongId {
Expand All @@ -1827,49 +1835,65 @@ fn add_closure_call_function(
let root_block_id = alloc_empty_block(&mut ctx);
let mut builder = BlockBuilder::root(&mut ctx, root_block_id);

let (closure_param_var_id, closure_var) = if trait_id
== semantic::corelib::fn_once_trait(semantic_db)
{
// If the closure is FnOnce, the closure is passed by value.
let closure_param_var = ctx.new_var(VarRequest { ty: expr.ty, location: expr_location });
let closure_var = VarUsage { var_id: closure_param_var, location: expr_location };
(closure_param_var, closure_var)
} else {
// If the closure is Fn the closure argument will be a snapshot, so we need to desnap it.
let closure_param_var = ctx.new_var(VarRequest {
ty: wrap_in_snapshots(semantic_db, expr.ty, 1),
location: expr_location,
});

let closure_var = generators::Desnap {
input: VarUsage { var_id: closure_param_var, location: expr_location },
location: expr_location,
}
.add(&mut ctx, &mut builder.statements);
(closure_param_var, closure_var)
};
let parameters: Vec<VariableId> = [
ctx.new_var(VarRequest { ty: expr.ty, location: expr_location }),
closure_param_var_id,
ctx.new_var(VarRequest { ty: parameters_ty, location: expr_location }),
]
.into();

let root_ok = {
let captured_vars = generators::StructDestructure {
input: VarUsage { var_id: parameters[0], location: expr_location },
var_reqs: chain!(closure_info.members.iter(), closure_info.snapshots.iter())
.map(|(_, ty)| VarRequest { ty: *ty, location: expr_location })
.collect_vec(),
}
.add(&mut ctx, &mut builder.statements);
for (i, (param, _)) in closure_info.members.iter().enumerate() {
builder.semantics.introduce(param.clone(), captured_vars[i]);
}
for (i, (param, _)) in closure_info.snapshots.iter().enumerate() {
builder
.snapped_semantics
.insert(param.clone(), captured_vars[i + closure_info.members.len()]);
}
let param_vars = generators::StructDestructure {
input: VarUsage { var_id: parameters[1], location: expr_location },
var_reqs: closure_ty
.param_tys
.iter()
.map(|ty| VarRequest { ty: *ty, location: expr_location })
.collect_vec(),
}
.add(&mut ctx, &mut builder.statements);
for (param_var, param) in param_vars.into_iter().zip(expr.params.iter()) {
builder
.semantics
.introduce((&parameter_as_member_path(param.clone())).into(), param_var);
}
let lowered_expr = lower_expr(&mut ctx, &mut builder, expr.body);
let maybe_sealed_block = lowered_expr_to_block_scope_end(&mut ctx, builder, lowered_expr);
maybe_sealed_block.and_then(|block_sealed| {
wrap_sealed_block_as_function(&mut ctx, block_sealed, expr.stable_ptr.untyped())?;
Ok(root_block_id)
})
};
let captured_vars = generators::StructDestructure {
input: closure_var,
var_reqs: chain!(closure_info.members.iter(), closure_info.snapshots.iter())
.map(|(_, ty)| VarRequest { ty: *ty, location: expr_location })
.collect_vec(),
}
.add(&mut ctx, &mut builder.statements);
for (i, (param, _)) in closure_info.members.iter().enumerate() {
builder.semantics.introduce(param.clone(), captured_vars[i]);
}
for (i, (param, _)) in closure_info.snapshots.iter().enumerate() {
builder
.snapped_semantics
.insert(param.clone(), captured_vars[i + closure_info.members.len()]);
}
let param_vars = generators::StructDestructure {
input: VarUsage { var_id: parameters[1], location: expr_location },
var_reqs: closure_ty
.param_tys
.iter()
.map(|ty| VarRequest { ty: *ty, location: expr_location })
.collect_vec(),
}
.add(&mut ctx, &mut builder.statements);
for (param_var, param) in param_vars.into_iter().zip(expr.params.iter()) {
builder.semantics.introduce((&parameter_as_member_path(param.clone())).into(), param_var);
}
let lowered_expr = lower_expr(&mut ctx, &mut builder, expr.body);
let maybe_sealed_block = lowered_expr_to_block_scope_end(&mut ctx, builder, lowered_expr);
let root_ok = maybe_sealed_block.and_then(|block_sealed| {
wrap_sealed_block_as_function(&mut ctx, block_sealed, expr.stable_ptr.untyped())?;
Ok(root_block_id)
});
let blocks = root_ok
.map(|_| ctx.blocks.build().expect("Root block must exist."))
.unwrap_or_else(FlatBlocks::new_errored);
Expand Down Expand Up @@ -1910,8 +1934,20 @@ fn lower_expr_closure(
ctx,
expr,
builder.semantics.closures.get(&capture_var_usage.var_id).unwrap(),
semantic::corelib::fn_once_trait(ctx.db.upcast()),
)
.map_err(LoweringFlowError::Failed)?;

if ctx.variables[capture_var_usage.var_id].copyable.is_ok() {
add_closure_call_function(
ctx,
expr,
builder.semantics.closures.get(&capture_var_usage.var_id).unwrap(),
semantic::corelib::fn_trait(ctx.db.upcast()),
)
.map_err(LoweringFlowError::Failed)?;
}

Ok(closure_variable)
}

Expand Down
18 changes: 10 additions & 8 deletions crates/cairo-lang-lowering/src/lower/test_data/closure
Original file line number Diff line number Diff line change
Expand Up @@ -277,23 +277,25 @@ End:
Return()


Generated core::ops::function::FnOnce::call lowering for source location:
Generated core::ops::function::Fn::call lowering for source location:
let c = || a;
^^

Parameters: v0: {[email protected]:6:14: 6:16}, v1: ()
Parameters: v0: @{[email protected]:6:14: 6:16}, v2: ()
blk0 (root):
Statements:
(v2: core::integer::u32) <- struct_destructure(v0)
() <- struct_destructure(v1)
(v1: {[email protected]:6:14: 6:16}) <- desnap(v0)
(v3: core::integer::u32) <- struct_destructure(v1)
() <- struct_destructure(v2)
End:
Return(v2)
Return(v3)


Final lowering:
Parameters: v0: {[email protected]:6:14: 6:16}, v1: ()
Parameters: v0: @{[email protected]:6:14: 6:16}, v1: ()
blk0 (root):
Statements:
(v2: core::integer::u32) <- struct_destructure(v0)
(v2: {[email protected]:6:14: 6:16}) <- desnap(v0)
(v3: core::integer::u32) <- struct_destructure(v2)
End:
Return(v2)
Return(v3)
12 changes: 12 additions & 0 deletions crates/cairo-lang-semantic/src/corelib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -621,10 +621,22 @@ pub fn fn_once_trait(db: &dyn SemanticGroup) -> TraitId {
get_core_trait(db, CoreTraitContext::Ops, "FnOnce".into())
}

pub fn fn_trait(db: &dyn SemanticGroup) -> TraitId {
get_core_trait(db, CoreTraitContext::Ops, "Fn".into())
}

pub fn fn_traits(db: &dyn SemanticGroup) -> [TraitId; 2] {
[fn_trait(db), fn_once_trait(db)]
}

pub fn fn_once_call_trait_fn(db: &dyn SemanticGroup) -> TraitFunctionId {
get_core_trait_fn(db, CoreTraitContext::Ops, "FnOnce".into(), "call".into())
}

pub fn fn_call_trait_fn(db: &dyn SemanticGroup) -> TraitFunctionId {
get_core_trait_fn(db, CoreTraitContext::Ops, "Fn".into(), "call".into())
}

pub fn copy_trait(db: &dyn SemanticGroup) -> TraitId {
get_core_trait(db, CoreTraitContext::TopLevel, "Copy".into())
}
Expand Down
24 changes: 21 additions & 3 deletions crates/cairo-lang-semantic/src/diagnostic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -860,8 +860,19 @@ impl DiagnosticEntry for SemanticDiagnostic {
)
}
}
SemanticDiagnosticKind::TypeEqualTraitReImplementation => {
"Type equals trait should not be re-implemented.".into()
SemanticDiagnosticKind::CallExpressionRequiresFunction { ty, inference_errors } => {
if inference_errors.is_empty() {
format!("Call expression requires a function, found `{}`.", ty.format(db))
} else {
format!(
"Call expression requires a function, found `{}`.\n{}",
ty.format(db),
inference_errors.format(db)
)
}
}
SemanticDiagnosticKind::CompilerTraitReImplementation { trait_id } => {
format!("Trait `{}` should not be re-implemented.", trait_id.full_path(db.upcast()))
}
SemanticDiagnosticKind::ClosureInGlobalScope => {
"Closures are not allowed in this context.".into()
Expand Down Expand Up @@ -1189,7 +1200,12 @@ pub enum SemanticDiagnosticKind {
trait_name: SmolStr,
inference_errors: TraitInferenceErrors,
},
CallExpressionRequiresFunction {
ty: semantic::TypeId,
inference_errors: TraitInferenceErrors,
},
MultipleImplementationOfIndexOperator(semantic::TypeId),

UnsupportedInlineArguments,
RedundantInlineAttribute,
InlineAttrForExternFunctionNotAllowed,
Expand Down Expand Up @@ -1231,7 +1247,9 @@ pub enum SemanticDiagnosticKind {
DerefCycle {
deref_chain: String,
},
TypeEqualTraitReImplementation,
CompilerTraitReImplementation {
trait_id: TraitId,
},
ClosureInGlobalScope,
MaybeMissingColonColon,
CallingShadowedFunction {
Expand Down
Loading

0 comments on commit 42020d5

Please sign in to comment.