Skip to content

Commit

Permalink
Fix possibly missed optimization.
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyalesokhin-starkware committed Nov 17, 2024
1 parent e9f3b1b commit 2c1eb9d
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 10 deletions.
31 changes: 21 additions & 10 deletions crates/cairo-lang-lowering/src/optimizations/match_optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ use cairo_lang_semantic::MatchArmSelector;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
use itertools::{Itertools, zip_eq};
use itertools::{zip_eq, Itertools};

use super::var_renamer::VarRenamer;
use crate::borrow_check::Demand;
use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
use crate::borrow_check::demand::EmptyDemandReporter;
use crate::borrow_check::Demand;
use crate::utils::RebuilderEx;
use crate::{
BlockId, FlatBlock, FlatBlockEnd, FlatLowered, MatchArm, MatchEnumInfo, MatchInfo, Statement,
Expand Down Expand Up @@ -322,15 +322,22 @@ impl<'a> Analyzer<'a> for MatchOptimizerContext {
return;
};

let Some(var_usage) = remapping.get(&candidate.match_variable) else {
// Revoke the candidate.
info.candidate = None;
return;
};
let orig_match_variable = candidate.match_variable;
candidate.match_variable = var_usage.var_id;

if remapping.len() > 1 {
// The term 'additional_remappings' refers to remappings for variables other than the match
// variable.
let goto_has_additional_remappings =
if let Some(var_usage) = remapping.get(&candidate.match_variable) {
candidate.match_variable = var_usage.var_id;
remapping.len() > 1
} else {
// Note that remapping.is_empty() is false here.
true
};

if goto_has_additional_remappings {
// here, we have remappings for variables other than the match variable.

if candidate.future_merge || candidate.additional_remappings.is_some() {
// TODO(ilya): Support multiple remappings with future merges.

Expand All @@ -342,7 +349,11 @@ impl<'a> Analyzer<'a> for MatchOptimizerContext {
remapping: remapping
.iter()
.filter_map(|(var, dst)| {
if *var != orig_match_variable { Some((*var, *dst)) } else { None }
if *var != orig_match_variable {
Some((*var, *dst))
} else {
None
}
})
.collect(),
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1157,3 +1157,123 @@ blk10:
Statements:
End:
Return(v8, v17)

//! > ==========================================================================

//! > Remapping where input var is not renamed.

//! > test_runner_name
test_match_optimizer

//! > function
fn foo(ref a: u32) {
let v = true;
if v {
a = 1;
}
let v = true;
if v {
a = 2;
}
}

//! > function_name
foo

//! > module_code

//! > semantic_diagnostics

//! > lowering_diagnostics

//! > before
Parameters: v0: core::integer::u32
blk0 (root):
Statements:
(v1: ()) <- struct_construct()
(v2: core::bool) <- bool::True(v1)
End:
Match(match_enum(v2) {
bool::False(v3) => blk1,
bool::True(v4) => blk2,
})

blk1:
Statements:
End:
Goto(blk3, {v0 -> v5})

blk2:
Statements:
(v6: core::integer::u32) <- 1
End:
Goto(blk3, {v6 -> v5})

blk3:
Statements:
(v7: ()) <- struct_construct()
(v8: core::bool) <- bool::True(v7)
End:
Match(match_enum(v8) {
bool::False(v9) => blk4,
bool::True(v10) => blk5,
})

blk4:
Statements:
End:
Goto(blk6, {v5 -> v11})

blk5:
Statements:
(v12: core::integer::u32) <- 2
End:
Goto(blk6, {v12 -> v11})

blk6:
Statements:
(v13: ()) <- struct_construct()
End:
Return(v11, v13)

//! > after
Parameters: v0: core::integer::u32
blk0 (root):
Statements:
(v1: ()) <- struct_construct()
End:
Goto(blk2, {v1 -> v4})

blk1:
Statements:
End:
Goto(blk3, {v0 -> v5})

blk2:
Statements:
(v6: core::integer::u32) <- 1
End:
Goto(blk3, {v6 -> v5})

blk3:
Statements:
(v7: ()) <- struct_construct()
End:
Goto(blk5, {v7 -> v10})

blk4:
Statements:
End:
Goto(blk6, {v5 -> v11})

blk5:
Statements:
(v12: core::integer::u32) <- 2
End:
Goto(blk6, {v12 -> v11})

blk6:
Statements:
(v13: ()) <- struct_construct()
End:
Return(v11, v13)

0 comments on commit 2c1eb9d

Please sign in to comment.