diff --git a/corelib/src/ops.cairo b/corelib/src/ops.cairo index eb34c097c59..4efd6830abc 100644 --- a/corelib/src/ops.cairo +++ b/corelib/src/ops.cairo @@ -15,3 +15,4 @@ use range::RangeOp; mod function; pub use function::FnOnce; +pub use function::Fn; diff --git a/corelib/src/ops/function.cairo b/corelib/src/ops/function.cairo index 7ab7c66d011..beb32082390 100644 --- a/corelib/src/ops/function.cairo +++ b/corelib/src/ops/function.cairo @@ -13,3 +13,24 @@ pub trait FnOnce { /// Performs the call operation. fn call(self: T, args: Args) -> Self::Output; } + +/// An implementation of `FnOnce` when `Fn` is implemented. +/// Makes sure we can always pass an `Fn` to a function that expects an `FnOnce`. +impl FnOnceImpl, +Fn> of FnOnce { + type Output = Fn::::Output; + fn call(self: T, args: Args) -> Self::Output { + Fn::call(@self, args) + } +} + +/// The version of the call operator that takes a by-snapshot receiver. +/// +/// Instances of `Fn` can be called multiple times. +/// +/// `Fn` is implemented automatically by closures that capture only copyable variables. +pub trait Fn { + /// The returned type after the call operator is used. + type Output; + /// Performs the call operation. + fn call(self: @T, args: Args) -> Self::Output; +} diff --git a/corelib/src/test/language_features/closure_test.cairo b/corelib/src/test/language_features/closure_test.cairo index b5e9dc50b3c..96ef3ff94ad 100644 --- a/corelib/src/test/language_features/closure_test.cairo +++ b/corelib/src/test/language_features/closure_test.cairo @@ -57,3 +57,15 @@ fn option_map_test() { assert_eq!(option_map(Option::Some(2), |x| Option::Some(x)), Option::Some(Option::Some(2))); } +fn array_map, +Drop, +Drop, +Drop>( + arr: [T; 2], f: F, +) -> [core::ops::Fn::::Output; 2] { + let [a, b] = arr; + [f(a), f(b)] +} + +#[test] +fn array_map_test() { + assert_eq!(array_map([2, 3], |x| x + 3), [5, 6]); +} + diff --git a/crates/cairo-lang-lowering/src/borrow_check/test_data/closure b/crates/cairo-lang-lowering/src/borrow_check/test_data/closure index 3c0b3fe2ef4..d4e2dfa9c55 100644 --- a/crates/cairo-lang-lowering/src/borrow_check/test_data/closure +++ b/crates/cairo-lang-lowering/src/borrow_check/test_data/closure @@ -272,12 +272,13 @@ blk0 (root): Statements: (v0: core::felt252) <- 8 (v1: {closure@lib.cairo:3:13: 3:16}) <- struct_construct(v0{`x`}) - (v2: core::felt252) <- 2 - (v3: (core::felt252,)) <- struct_construct(v2{`2`}) - (v4: core::felt252) <- Generated core::ops::function::FnOnce::<{closure@lib.cairo:3:13: 3:16}, (core::felt252,)>::call(v1{`c`}, v3{`c(2)`}) - (v5: core::felt252) <- core::Felt252Add::add(v4{`y`}, v0{`x`}) + (v2: {closure@lib.cairo:3:13: 3:16}, v3: @{closure@lib.cairo:3:13: 3:16}) <- snapshot(v1{`c`}) + (v4: core::felt252) <- 2 + (v5: (core::felt252,)) <- struct_construct(v4{`2`}) + (v6: core::felt252) <- Generated core::ops::function::Fn::<{closure@lib.cairo:3:13: 3:16}, (core::felt252,)>::call(v3{`c`}, v5{`c(2)`}) + (v7: core::felt252) <- core::Felt252Add::add(v6{`y`}, v0{`x`}) End: - Return(v5) + Return(v7) //! > ========================================================================== @@ -314,15 +315,16 @@ Statements: (v3: core::array::Array::, v2: ()) <- core::array::ArrayImpl::::append(v0{`__array_builder_macro_result__`}, v1{`99_felt252`}) (v4: core::array::Array::, v5: @core::array::Array::) <- snapshot(v3{`x`}) (v6: {closure@lib.cairo:3:13: 3:16}) <- struct_construct(v5{`|a| { (@x).len() * (a + 3) }`}) - (v7: core::integer::u32) <- 2 - (v8: (core::integer::u32,)) <- struct_construct(v7{`2`}) - (v9: core::integer::u32) <- Generated core::ops::function::FnOnce::<{closure@lib.cairo:3:13: 3:16}, (core::integer::u32,)>::call(v6{`c`}, v8{`c(2)`}) - (v10: core::array::Array::, v11: @core::array::Array::) <- snapshot(v4{`x`}) - (v12: core::integer::u32) <- 0 - (v13: @core::felt252) <- core::ops::index::DeprecatedIndexViewImpl::, core::integer::u32, @core::felt252, core::array::ArrayIndex::>::index(v11{`x`}, v12{`0`}) - (v14: core::felt252) <- desnap(v13{`0`}) + (v7: {closure@lib.cairo:3:13: 3:16}, v8: @{closure@lib.cairo:3:13: 3:16}) <- snapshot(v6{`c`}) + (v9: core::integer::u32) <- 2 + (v10: (core::integer::u32,)) <- struct_construct(v9{`2`}) + (v11: core::integer::u32) <- Generated core::ops::function::Fn::<{closure@lib.cairo:3:13: 3:16}, (core::integer::u32,)>::call(v8{`c`}, v10{`c(2)`}) + (v12: core::array::Array::, v13: @core::array::Array::) <- snapshot(v4{`x`}) + (v14: core::integer::u32) <- 0 + (v15: @core::felt252) <- core::ops::index::DeprecatedIndexViewImpl::, core::integer::u32, @core::felt252, core::array::ArrayIndex::>::index(v13{`x`}, v14{`0`}) + (v16: core::felt252) <- desnap(v15{`0`}) End: - Return(v14) + Return(v16) //! > ========================================================================== diff --git a/crates/cairo-lang-lowering/src/ids.rs b/crates/cairo-lang-lowering/src/ids.rs index 54c618afbe9..792b64720cc 100644 --- a/crates/cairo-lang-lowering/src/ids.rs +++ b/crates/cairo-lang-lowering/src/ids.rs @@ -267,6 +267,7 @@ impl FunctionLongId { semantic::corelib::destruct_trait_fn(semantic_db), semantic::corelib::panic_destruct_trait_fn(semantic_db), semantic::corelib::fn_once_call_trait_fn(semantic_db), + semantic::corelib::fn_call_trait_fn(semantic_db), ] .contains(&function) ); diff --git a/crates/cairo-lang-lowering/src/lower/mod.rs b/crates/cairo-lang-lowering/src/lower/mod.rs index 07729ccee30..3d4e85470c7 100644 --- a/crates/cairo-lang-lowering/src/lower/mod.rs +++ b/crates/cairo-lang-lowering/src/lower/mod.rs @@ -1,4 +1,4 @@ -use std::{iter, vec}; +use std::vec; use block_builder::BlockBuilder; use cairo_lang_debug::DebugWithDb; @@ -7,7 +7,7 @@ use cairo_lang_diagnostics::{Diagnostics, Maybe}; use cairo_lang_semantic::corelib::{ErrorPropagationType, unwrap_error_propagation_type}; use cairo_lang_semantic::db::SemanticGroup; use cairo_lang_semantic::items::functions::{GenericFunctionId, ImplGenericFunctionId}; -use cairo_lang_semantic::items::imp::{GeneratedImplItems, GeneratedImplLongId, ImplLongId}; +use cairo_lang_semantic::items::imp::ImplLongId; use cairo_lang_semantic::usage::MemberPath; use cairo_lang_semantic::{ ConcreteFunction, ConcreteTraitLongId, ExprVar, LocalVariable, VarId, corelib, @@ -1780,11 +1780,11 @@ fn add_closure_call_function( encapsulated_ctx: &mut LoweringContext<'_, '_>, expr: &semantic::ExprClosure, closure_info: &ClosureInfo, + trait_id: cairo_lang_defs::ids::TraitId, ) -> Maybe<()> { - let semantic_db = encapsulated_ctx.db.upcast(); + let semantic_db: &dyn SemanticGroup = encapsulated_ctx.db.upcast(); let closure_ty = extract_matches!(expr.ty.lookup_intern(semantic_db), TypeLongId::Closure); let expr_location = encapsulated_ctx.get_location(expr.stable_ptr.untyped()); - let trait_id = semantic::corelib::fn_once_trait(semantic_db); let parameters_ty = TypeLongId::Tuple(closure_ty.param_tys.clone()).intern(semantic_db); let concrete_trait = ConcreteTraitLongId { trait_id, @@ -1794,18 +1794,26 @@ fn add_closure_call_function( ], } .intern(semantic_db); - let trait_function = semantic::corelib::fn_once_call_trait_fn(semantic_db); - - let ret_ty = semantic_db.trait_type_by_name(trait_id, "Output".into()).unwrap().unwrap(); - let impl_id = ImplLongId::GeneratedImpl( - GeneratedImplLongId { - concrete_trait, - generic_params: vec![], - impl_items: GeneratedImplItems(iter::once((ret_ty, closure_ty.ret_ty)).collect()), - } - .intern(semantic_db), - ) - .intern(semantic_db); + let Ok(impl_id) = semantic::types::get_impl_at_context( + semantic_db, + encapsulated_ctx.variables.lookup_context.clone(), + concrete_trait, + None, + ) else { + // If the impl doesn't exist, there won't be a call to the call-function, so we don't need + // to generate it. + return Ok(()); + }; + if !matches!(impl_id.lookup_intern(semantic_db), ImplLongId::GeneratedImpl(_)) { + // If the impl is not generated, we don't need to generate a lowering for it. + return Ok(()); + } + + let trait_function: cairo_lang_defs::ids::TraitFunctionId = semantic_db + .trait_function_by_name(trait_id, "call".into()) + .unwrap() + .expect("Call function must exist for an Fn trait."); + let generic_function = GenericFunctionId::Impl(ImplGenericFunctionId { impl_id, function: trait_function }); let function = semantic::FunctionLongId { @@ -1827,49 +1835,65 @@ fn add_closure_call_function( let root_block_id = alloc_empty_block(&mut ctx); let mut builder = BlockBuilder::root(&mut ctx, root_block_id); + let (closure_param_var_id, closure_var) = if trait_id + == semantic::corelib::fn_once_trait(semantic_db) + { + // If the closure is FnOnce, the closure is passed by value. + let closure_param_var = ctx.new_var(VarRequest { ty: expr.ty, location: expr_location }); + let closure_var = VarUsage { var_id: closure_param_var, location: expr_location }; + (closure_param_var, closure_var) + } else { + // If the closure is Fn the closure argument will be a snapshot, so we need to desnap it. + let closure_param_var = ctx.new_var(VarRequest { + ty: wrap_in_snapshots(semantic_db, expr.ty, 1), + location: expr_location, + }); + + let closure_var = generators::Desnap { + input: VarUsage { var_id: closure_param_var, location: expr_location }, + location: expr_location, + } + .add(&mut ctx, &mut builder.statements); + (closure_param_var, closure_var) + }; let parameters: Vec = [ - ctx.new_var(VarRequest { ty: expr.ty, location: expr_location }), + closure_param_var_id, ctx.new_var(VarRequest { ty: parameters_ty, location: expr_location }), ] .into(); - - let root_ok = { - let captured_vars = generators::StructDestructure { - input: VarUsage { var_id: parameters[0], location: expr_location }, - var_reqs: chain!(closure_info.members.iter(), closure_info.snapshots.iter()) - .map(|(_, ty)| VarRequest { ty: *ty, location: expr_location }) - .collect_vec(), - } - .add(&mut ctx, &mut builder.statements); - for (i, (param, _)) in closure_info.members.iter().enumerate() { - builder.semantics.introduce(param.clone(), captured_vars[i]); - } - for (i, (param, _)) in closure_info.snapshots.iter().enumerate() { - builder - .snapped_semantics - .insert(param.clone(), captured_vars[i + closure_info.members.len()]); - } - let param_vars = generators::StructDestructure { - input: VarUsage { var_id: parameters[1], location: expr_location }, - var_reqs: closure_ty - .param_tys - .iter() - .map(|ty| VarRequest { ty: *ty, location: expr_location }) - .collect_vec(), - } - .add(&mut ctx, &mut builder.statements); - for (param_var, param) in param_vars.into_iter().zip(expr.params.iter()) { - builder - .semantics - .introduce((¶meter_as_member_path(param.clone())).into(), param_var); - } - let lowered_expr = lower_expr(&mut ctx, &mut builder, expr.body); - let maybe_sealed_block = lowered_expr_to_block_scope_end(&mut ctx, builder, lowered_expr); - maybe_sealed_block.and_then(|block_sealed| { - wrap_sealed_block_as_function(&mut ctx, block_sealed, expr.stable_ptr.untyped())?; - Ok(root_block_id) - }) - }; + let captured_vars = generators::StructDestructure { + input: closure_var, + var_reqs: chain!(closure_info.members.iter(), closure_info.snapshots.iter()) + .map(|(_, ty)| VarRequest { ty: *ty, location: expr_location }) + .collect_vec(), + } + .add(&mut ctx, &mut builder.statements); + for (i, (param, _)) in closure_info.members.iter().enumerate() { + builder.semantics.introduce(param.clone(), captured_vars[i]); + } + for (i, (param, _)) in closure_info.snapshots.iter().enumerate() { + builder + .snapped_semantics + .insert(param.clone(), captured_vars[i + closure_info.members.len()]); + } + let param_vars = generators::StructDestructure { + input: VarUsage { var_id: parameters[1], location: expr_location }, + var_reqs: closure_ty + .param_tys + .iter() + .map(|ty| VarRequest { ty: *ty, location: expr_location }) + .collect_vec(), + } + .add(&mut ctx, &mut builder.statements); + for (param_var, param) in param_vars.into_iter().zip(expr.params.iter()) { + builder.semantics.introduce((¶meter_as_member_path(param.clone())).into(), param_var); + } + let lowered_expr = lower_expr(&mut ctx, &mut builder, expr.body); + let maybe_sealed_block = lowered_expr_to_block_scope_end(&mut ctx, builder, lowered_expr); + let root_ok = maybe_sealed_block.and_then(|block_sealed| { + wrap_sealed_block_as_function(&mut ctx, block_sealed, expr.stable_ptr.untyped())?; + Ok(root_block_id) + }); let blocks = root_ok .map(|_| ctx.blocks.build().expect("Root block must exist.")) .unwrap_or_else(FlatBlocks::new_errored); @@ -1910,8 +1934,14 @@ fn lower_expr_closure( ctx, expr, builder.semantics.closures.get(&capture_var_usage.var_id).unwrap(), + if ctx.variables[capture_var_usage.var_id].copyable.is_ok() { + semantic::corelib::fn_trait(ctx.db.upcast()) + } else { + semantic::corelib::fn_once_trait(ctx.db.upcast()) + }, ) .map_err(LoweringFlowError::Failed)?; + Ok(closure_variable) } diff --git a/crates/cairo-lang-lowering/src/lower/test_data/closure b/crates/cairo-lang-lowering/src/lower/test_data/closure index 066c4c53929..8cca0c3e3ec 100644 --- a/crates/cairo-lang-lowering/src/lower/test_data/closure +++ b/crates/cairo-lang-lowering/src/lower/test_data/closure @@ -277,23 +277,25 @@ End: Return() -Generated core::ops::function::FnOnce::call lowering for source location: +Generated core::ops::function::Fn::call lowering for source location: let c = || a; ^^ -Parameters: v0: {closure@lib.cairo:6:14: 6:16}, v1: () +Parameters: v0: @{closure@lib.cairo:6:14: 6:16}, v2: () blk0 (root): Statements: - (v2: core::integer::u32) <- struct_destructure(v0) - () <- struct_destructure(v1) + (v1: {closure@lib.cairo:6:14: 6:16}) <- desnap(v0) + (v3: core::integer::u32) <- struct_destructure(v1) + () <- struct_destructure(v2) End: - Return(v2) + Return(v3) Final lowering: -Parameters: v0: {closure@lib.cairo:6:14: 6:16}, v1: () +Parameters: v0: @{closure@lib.cairo:6:14: 6:16}, v1: () blk0 (root): Statements: - (v2: core::integer::u32) <- struct_destructure(v0) + (v2: {closure@lib.cairo:6:14: 6:16}) <- desnap(v0) + (v3: core::integer::u32) <- struct_destructure(v2) End: - Return(v2) + Return(v3) diff --git a/crates/cairo-lang-semantic/src/corelib.rs b/crates/cairo-lang-semantic/src/corelib.rs index f3bceee80e3..e99b632cf9e 100644 --- a/crates/cairo-lang-semantic/src/corelib.rs +++ b/crates/cairo-lang-semantic/src/corelib.rs @@ -621,10 +621,22 @@ pub fn fn_once_trait(db: &dyn SemanticGroup) -> TraitId { get_core_trait(db, CoreTraitContext::Ops, "FnOnce".into()) } +pub fn fn_trait(db: &dyn SemanticGroup) -> TraitId { + get_core_trait(db, CoreTraitContext::Ops, "Fn".into()) +} + +pub fn fn_traits(db: &dyn SemanticGroup) -> [TraitId; 2] { + [fn_trait(db), fn_once_trait(db)] +} + pub fn fn_once_call_trait_fn(db: &dyn SemanticGroup) -> TraitFunctionId { get_core_trait_fn(db, CoreTraitContext::Ops, "FnOnce".into(), "call".into()) } +pub fn fn_call_trait_fn(db: &dyn SemanticGroup) -> TraitFunctionId { + get_core_trait_fn(db, CoreTraitContext::Ops, "Fn".into(), "call".into()) +} + pub fn copy_trait(db: &dyn SemanticGroup) -> TraitId { get_core_trait(db, CoreTraitContext::TopLevel, "Copy".into()) } diff --git a/crates/cairo-lang-semantic/src/diagnostic.rs b/crates/cairo-lang-semantic/src/diagnostic.rs index 290e343897b..4121b7099e3 100644 --- a/crates/cairo-lang-semantic/src/diagnostic.rs +++ b/crates/cairo-lang-semantic/src/diagnostic.rs @@ -860,8 +860,19 @@ impl DiagnosticEntry for SemanticDiagnostic { ) } } - SemanticDiagnosticKind::TypeEqualTraitReImplementation => { - "Type equals trait should not be re-implemented.".into() + SemanticDiagnosticKind::CallExpressionRequiresFunction { ty, inference_errors } => { + if inference_errors.is_empty() { + format!("Call expression requires a function, found `{}`.", ty.format(db)) + } else { + format!( + "Call expression requires a function, found `{}`.\n{}", + ty.format(db), + inference_errors.format(db) + ) + } + } + SemanticDiagnosticKind::CompilerTraitReImplementation { trait_id } => { + format!("Trait `{}` should not be re-implemented.", trait_id.full_path(db.upcast())) } SemanticDiagnosticKind::ClosureInGlobalScope => { "Closures are not allowed in this context.".into() @@ -1189,7 +1200,12 @@ pub enum SemanticDiagnosticKind { trait_name: SmolStr, inference_errors: TraitInferenceErrors, }, + CallExpressionRequiresFunction { + ty: semantic::TypeId, + inference_errors: TraitInferenceErrors, + }, MultipleImplementationOfIndexOperator(semantic::TypeId), + UnsupportedInlineArguments, RedundantInlineAttribute, InlineAttrForExternFunctionNotAllowed, @@ -1231,7 +1247,9 @@ pub enum SemanticDiagnosticKind { DerefCycle { deref_chain: String, }, - TypeEqualTraitReImplementation, + CompilerTraitReImplementation { + trait_id: TraitId, + }, ClosureInGlobalScope, MaybeMissingColonColon, CallingShadowedFunction { diff --git a/crates/cairo-lang-semantic/src/expr/compute.rs b/crates/cairo-lang-semantic/src/expr/compute.rs index ad9559ad397..36488c13bf5 100644 --- a/crates/cairo-lang-semantic/src/expr/compute.rs +++ b/crates/cairo-lang-semantic/src/expr/compute.rs @@ -845,21 +845,22 @@ fn compute_expr_function_call_semantic( if ctx.are_closures_in_context { // TODO(TomerStarkware): find the correct trait based on captured variables. let fn_once_trait = crate::corelib::fn_once_trait(db); + let fn_trait = crate::corelib::fn_trait(db); let self_expr = ExprAndId { expr: var.clone(), id: ctx.arenas.exprs.alloc(var) }; - let (call_function_id, _, fixed_closure, closure_mutability) = + let mut closure_call_data = |call_trait| { compute_method_function_call_data( ctx, - &[fn_once_trait], + &[call_trait], "call".into(), - self_expr, + self_expr.clone(), syntax.into(), None, |ty, _, inference_errors| { - Some(NoImplementationOfTrait { - ty, - inference_errors, - trait_name: "FnOnce".into(), - }) + if call_trait == fn_once_trait { + Some(CallExpressionRequiresFunction { ty, inference_errors }) + } else { + None + } }, |_, _, _| { unreachable!( @@ -869,7 +870,10 @@ fn compute_expr_function_call_semantic( NoImplementationOfTrait function." ) }, - )?; + ) + }; + let (call_function_id, _, fixed_closure, closure_mutability) = + closure_call_data(fn_trait).or_else(|_| closure_call_data(fn_once_trait))?; let args_iter = args_syntax.elements(syntax_db).into_iter(); // Normal parameters @@ -1049,7 +1053,7 @@ pub fn compute_root_expr( }; let ConcreteTraitLongId { trait_id, generic_args } = concrete_trait_id.lookup_intern(ctx.db); - if trait_id == crate::corelib::fn_once_trait(ctx.db) { + if crate::corelib::fn_traits(ctx.db).contains(&trait_id) { ctx.are_closures_in_context = true; } if trait_id != get_core_trait(ctx.db, CoreTraitContext::MetaProgramming, "TypeEqual".into()) @@ -1910,7 +1914,7 @@ fn compute_method_function_call_data( self_expr: ExprAndId, method_syntax: SyntaxStablePtrId, generic_args_syntax: Option>, - no_implementation_diagnostic: fn( + no_implementation_diagnostic: impl Fn( TypeId, SmolStr, TraitInferenceErrors, diff --git a/crates/cairo-lang-semantic/src/expr/inference.rs b/crates/cairo-lang-semantic/src/expr/inference.rs index fde34cc15e9..c8c7818dff7 100644 --- a/crates/cairo-lang-semantic/src/expr/inference.rs +++ b/crates/cairo-lang-semantic/src/expr/inference.rs @@ -307,11 +307,11 @@ impl Hash for ImplVarTraitItemMappings { pub struct InferenceData { pub inference_id: InferenceId, /// Current inferred assignment for type variables. - pub type_assignment: HashMap, + pub type_assignment: OrderedHashMap, /// Current inferred assignment for const variables. - pub const_assignment: HashMap, + pub const_assignment: OrderedHashMap, /// Current inferred assignment for impl variables. - pub impl_assignment: HashMap, + pub impl_assignment: OrderedHashMap, /// Unsolved impl variables mapping to a maps of trait items to a corresponding item variable. /// Upon solution of the trait conforms the fully known item to the variable. pub impl_vars_trait_item_mappings: HashMap, @@ -346,9 +346,9 @@ impl InferenceData { pub fn new(inference_id: InferenceId) -> Self { Self { inference_id, - type_assignment: HashMap::new(), - impl_assignment: HashMap::new(), - const_assignment: HashMap::new(), + type_assignment: OrderedHashMap::default(), + impl_assignment: OrderedHashMap::default(), + const_assignment: OrderedHashMap::default(), impl_vars_trait_item_mappings: HashMap::new(), type_vars: Vec::new(), impl_vars: Vec::new(), @@ -981,14 +981,16 @@ impl<'db> Inference<'db> { let GenericArgumentId::Type(ty) = garg else { continue; }; - + let ty = inference.rewrite(ty).no_err(); // If the negative impl has a generic argument that is not fully // concrete we can't tell if we should rule out the candidate impl. // For example if we have -TypeEqual we can't tell if S and // T are going to be assigned the same concrete type. // We return `SolutionSet::Ambiguous` here to indicate that more // information is needed. - if !ty.is_fully_concrete(inference.db) { + if !matches!(ty.lookup_intern(inference.db), TypeLongId::Closure(_)) + && !ty.is_fully_concrete(inference.db) + { // TODO(ilya): Try to detect the ambiguity earlier in the // inference process. return Ok(SolutionSet::Ambiguous( diff --git a/crates/cairo-lang-semantic/src/expr/inference/solver.rs b/crates/cairo-lang-semantic/src/expr/inference/solver.rs index 8867d55f0ae..ec8af92bdf8 100644 --- a/crates/cairo-lang-semantic/src/expr/inference/solver.rs +++ b/crates/cairo-lang-semantic/src/expr/inference/solver.rs @@ -227,6 +227,13 @@ impl CandidateSolver { let mut inference_data: InferenceData = InferenceData::new(InferenceId::Canonical); let mut inference = inference_data.inference(db); let (canonical_trait, canonical_embedding) = canonical_trait.embed(&mut inference); + + // If the closure params are not var free, we cannot infer the negative impl. + // We use the canonical trait concretize the closure params. + if let UninferredImpl::GeneratedImpl(imp) = candidate { + inference.conform_traits(imp.lookup_intern(db).concrete_trait, canonical_trait.id)?; + } + // Add the defining module of the candidate to the lookup. let mut lookup_context = lookup_context.clone(); lookup_context.insert_lookup_scope(db, &candidate); diff --git a/crates/cairo-lang-semantic/src/expr/semantic_test_data/closure b/crates/cairo-lang-semantic/src/expr/semantic_test_data/closure index 38893f18aff..4270d0f9e3d 100644 --- a/crates/cairo-lang-semantic/src/expr/semantic_test_data/closure +++ b/crates/cairo-lang-semantic/src/expr/semantic_test_data/closure @@ -142,7 +142,7 @@ Block( ), expr: FunctionCall( ExprFunctionCall { - function: test::bar::<{closure@lib.cairo:6:13: 6:16}, Generated core::ops::function::FnOnce::<{closure@lib.cairo:6:13: 6:16}, (core::integer::u8,)>>, + function: test::bar::<{closure@lib.cairo:6:13: 6:16}, core::ops::function::FnOnceImpl::<{closure@lib.cairo:6:13: 6:16}, (core::integer::u8,), core::traits::DestructFromDrop::<{closure@lib.cairo:6:13: 6:16}, Generated core::traits::Drop::<{closure@lib.cairo:6:13: 6:16}>>, Generated core::ops::function::Fn::<{closure@lib.cairo:6:13: 6:16}, (core::integer::u8,)>>>, args: [ Value( Var( diff --git a/crates/cairo-lang-semantic/src/expr/test_data/closure b/crates/cairo-lang-semantic/src/expr/test_data/closure index 06ffc68f8db..de181b04d9c 100644 --- a/crates/cairo-lang-semantic/src/expr/test_data/closure +++ b/crates/cairo-lang-semantic/src/expr/test_data/closure @@ -235,7 +235,7 @@ fn bar>(c: T) -> core::ops::FnOnce:: } //! > expected_diagnostics -error: Type mismatch: `core::integer::u32` and `core::integer::u64`. +error: Trait has no implementation in context: core::ops::function::FnOnce::<{closure@lib.cairo:5:13: 5:16}, (core::integer::u32,)>. --> lib.cairo:10:23 let _k: felt252 = bar(c); ^*^ @@ -266,10 +266,10 @@ fn bar>(c: T) -> core::ops::FnOnce:: } //! > expected_diagnostics -error: Type mismatch: `core::felt252` and `core::integer::u128`. - --> lib.cairo:10:23 - let _k: felt252 = bar(c); - ^*^ +error: Type mismatch: `core::integer::u128` and `core::felt252`. + --> lib.cairo:9:39 + let _f: u128 = core::ops::FnOnce::call(c, (2,)); + ^**^ //! > ========================================================================== @@ -375,7 +375,7 @@ fn baz>(c: T) -> core::ops::FnOnce:: } //! > expected_diagnostics -error: Trait has no implementation in context: core::ops::function::FnOnce::. +error: Trait has no implementation in context: core::ops::function::Fn::. --> lib.cairo:7:19 let _x: u32 = bar(); ^***^ @@ -409,10 +409,10 @@ fn bar(a: felt252) -> u32 { } //! > expected_diagnostics -error: Type mismatch: `(?6,)` and `(?0, ?1, ?2)`. - --> lib.cairo:8:19 - let _f: u32 = bar(2); - ^****^ +error: Type annotations needed. Failed to infer ?0. + --> lib.cairo:5:18 + let bar = |_a, _b, _c| -> u32 { + ^ //! > ========================================================================== @@ -646,3 +646,60 @@ foo fn bar>(a: T) {} //! > expected_diagnostics + +//! > ========================================================================== + +//! > Calling a non closure type. + +//! > test_runner_name +test_function_diagnostics(expect_diagnostics: true) + +//! > function +fn foo() { + let _y = || 2; + let x: u32 = 5; + x() +} + +//! > function_name +foo + +//! > module_code + +//! > expected_diagnostics +error: Call expression requires a function, found `core::integer::u32`. +Candidate `core::ops::function::FnOnce::call` inference failed with: Trait has no implementation in context: core::ops::function::FnOnce::. + --> lib.cairo:4:5 + x() + ^*^ + +//! > ========================================================================== + +//! > Implementing a closure trait. + +//! > test_runner_name +test_function_diagnostics(expect_diagnostics: true) + +//! > function +fn foo() -> u32 { + let _y = || 2; + let x: u32 = 5; + x() +} + +//! > function_name +foo + +//! > module_code +impl MyImpl of core::ops::FnOnce { + type Output = u32; + fn call(self: u32, args: ()) -> u32 { + self + } +} + +//! > expected_diagnostics +error: Trait `core::ops::function::FnOnce` should not be re-implemented. + --> lib.cairo:1:16 +impl MyImpl of core::ops::FnOnce { + ^************************^ diff --git a/crates/cairo-lang-semantic/src/items/imp.rs b/crates/cairo-lang-semantic/src/items/imp.rs index e93bc1a2e86..012d406bc9f 100644 --- a/crates/cairo-lang-semantic/src/items/imp.rs +++ b/crates/cairo-lang-semantic/src/items/imp.rs @@ -61,8 +61,9 @@ use super::type_aliases::{ }; use super::{TraitOrImplContext, resolve_trait_path}; use crate::corelib::{ - CoreTraitContext, concrete_destruct_trait, concrete_drop_trait, copy_trait, core_submodule, - deref_trait, destruct_trait, drop_trait, fn_once_trait, get_core_trait, panic_destruct_trait, + CoreTraitContext, concrete_destruct_trait, concrete_drop_trait, copy_trait, core_crate, + deref_trait, destruct_trait, drop_trait, fn_once_trait, fn_trait, get_core_trait, + panic_destruct_trait, }; use crate::db::{SemanticGroup, get_resolver_data_options}; use crate::diagnostic::SemanticDiagnosticKind::{self, *}; @@ -662,13 +663,19 @@ pub fn priv_impl_declaration_data_inner( Err(diagnostics.report(&trait_path_syntax, ImplRequirementCycle)) }; - // Check for reimplementation of TypeEqual Trait. + // Check for reimplementation of compilers' Traits. if let Ok(concrete_trait) = concrete_trait { - if concrete_trait.trait_id(db) - == get_core_trait(db, CoreTraitContext::MetaProgramming, "TypeEqual".into()) - && impl_def_id.module_file_id(db.upcast()).0 != core_submodule(db, "metaprogramming") + if [ + get_core_trait(db, CoreTraitContext::MetaProgramming, "TypeEqual".into()), + fn_trait(db), + fn_once_trait(db), + ] + .contains(&concrete_trait.trait_id(db)) + && impl_def_id.parent_module(db.upcast()).owning_crate(db.upcast()) != core_crate(db) { - diagnostics.report(&trait_path_syntax, TypeEqualTraitReImplementation); + diagnostics.report(&trait_path_syntax, CompilerTraitReImplementation { + trait_id: concrete_trait.trait_id(db), + }); } } @@ -676,8 +683,9 @@ pub fn priv_impl_declaration_data_inner( let inference = &mut resolver.inference(); inference.finalize(&mut diagnostics, impl_ast.stable_ptr().untyped()); - let concrete_trait = inference.rewrite(concrete_trait).no_err(); - let generic_params = inference.rewrite(generic_params).no_err(); + let concrete_trait: Result = + inference.rewrite(concrete_trait).no_err(); + let generic_params: Vec = inference.rewrite(generic_params).no_err(); let attributes = impl_ast.attributes(syntax_db).structurize(syntax_db); let mut resolver_data = resolver.data; @@ -1725,30 +1733,29 @@ pub fn find_closure_generated_candidate( }; // Handles the special cases of `Copy`, `Drop`, `Destruct` and `PanicDestruct`. - let handle_mem_trait = |trait_id, neg_impl_trait: Option<_>| { + let mem_trait_generic_params = |trait_id, neg_impl_trait: Option<_>| { let id = db.trait_generic_params(trait_id).unwrap().first().unwrap().id(); - ( - concrete_trait_id, - chain!( - closure_type_long.captured_types.iter().map(|ty| { - GenericParam::Impl(GenericParamImpl { - id, - concrete_trait: Maybe::Ok(db.intern_concrete_trait(ConcreteTraitLongId { - trait_id, - generic_args: vec![GenericArgumentId::Type(*ty)], - })), - }) - }), - neg_impl_trait.map(|neg_impl_trait| { - GenericParam::NegImpl(GenericParamImpl { - id, - concrete_trait: Maybe::Ok(neg_impl_trait), - }) + chain!( + closure_type_long.captured_types.iter().map(|ty| { + GenericParam::Impl(GenericParamImpl { + id, + concrete_trait: Maybe::Ok(db.intern_concrete_trait(ConcreteTraitLongId { + trait_id, + generic_args: vec![GenericArgumentId::Type(*ty)], + })), }) - ) - .collect(), - [].into(), + }), + neg_impl_trait.map(|neg_impl_trait| { + GenericParam::NegImpl(GenericParamImpl { + id, + concrete_trait: Maybe::Ok(neg_impl_trait), + }) + }) ) + .collect() + }; + let handle_mem_trait = |trait_id, neg_impl_trait: Option<_>| { + (concrete_trait_id, mem_trait_generic_params(trait_id, neg_impl_trait), [].into()) }; let (concrete_trait, generic_params, impl_items) = match concrete_trait_id.trait_id(db) { trait_id if trait_id == fn_once_trait(db) => { @@ -1763,7 +1770,43 @@ pub fn find_closure_generated_candidate( } .intern(db); let ret_ty = db.trait_type_by_name(trait_id, "Output".into()).unwrap().unwrap(); - (concrete_trait, vec![], [(ret_ty, closure_type_long.ret_ty)].into()) + + let id = db.trait_generic_params(trait_id).unwrap().first().unwrap().id(); + let param: GenericParam = GenericParam::NegImpl(GenericParamImpl { + id, + concrete_trait: Maybe::Ok( + ConcreteTraitLongId { + trait_id: fn_trait(db), + generic_args: vec![ + GenericArgumentId::Type(closure_type), + GenericArgumentId::Type( + TypeLongId::Tuple(closure_type_long.param_tys.clone()).intern(db), + ), + ], + } + .intern(db), + ), + }); + (concrete_trait, vec![param], [(ret_ty, closure_type_long.ret_ty)].into()) + } + trait_id if trait_id == fn_trait(db) => { + let concrete_trait = ConcreteTraitLongId { + trait_id, + generic_args: vec![ + GenericArgumentId::Type(closure_type), + GenericArgumentId::Type( + TypeLongId::Tuple(closure_type_long.param_tys.clone()).intern(db), + ), + ], + } + .intern(db); + let ret_ty = db.trait_type_by_name(trait_id, "Output".into()).unwrap().unwrap(); + + ( + concrete_trait, + mem_trait_generic_params(copy_trait(db), None), + [(ret_ty, closure_type_long.ret_ty)].into(), + ) } trait_id if trait_id == drop_trait(db) => handle_mem_trait(trait_id, None), trait_id if trait_id == destruct_trait(db) => { diff --git a/crates/cairo-lang-semantic/src/items/tests/trait_type b/crates/cairo-lang-semantic/src/items/tests/trait_type index 524536fbb95..6dac1ee7814 100644 --- a/crates/cairo-lang-semantic/src/items/tests/trait_type +++ b/crates/cairo-lang-semantic/src/items/tests/trait_type @@ -2561,7 +2561,7 @@ foo impl I of core::metaprogramming::TypeEqual; //! > expected_diagnostics -error: Type equals trait should not be re-implemented. +error: Trait `core::metaprogramming::TypeEqual` should not be re-implemented. --> lib.cairo:1:11 impl I of core::metaprogramming::TypeEqual; ^********************************************^