diff --git a/warp_binja/src/matcher.rs b/warp_binja/src/matcher.rs index 97930c3..2a15fe6 100644 --- a/warp_binja/src/matcher.rs +++ b/warp_binja/src/matcher.rs @@ -1,3 +1,4 @@ +use std::cmp::Ordering; use dashmap::DashMap; use fastbloom::BloomFilter; use std::collections::{HashMap, HashSet}; @@ -305,9 +306,6 @@ impl Matcher { function: &BNFunction, matched_functions: &'a [Function], ) -> Option<&'a Function> { - let mut matched_func = None; - let mut highest_score = 0.0; - // TODO: To prevent invoking adjacent constraint function analysis, we must call call_site constraints specifically. let call_sites = cached_call_site_constraints(function); @@ -317,6 +315,8 @@ impl Matcher { } // Check call site guids + let mut highest_guid_count = 0; + let mut matched_guid_func = None; let call_site_guids = call_sites .iter() .filter_map(|c| c.guid) @@ -331,14 +331,22 @@ impl Matcher { let common_guid_count = call_site_guids .intersection(&matched_call_site_guids) .count(); - let score = common_guid_count as f64; - if score > highest_score { - highest_score = score; - matched_func = Some(matched); + match common_guid_count.cmp(&highest_guid_count) { + Ordering::Equal => { + // Multiple matches with same count, don't match on ONE of them. + matched_guid_func = None; + } + Ordering::Greater => { + highest_guid_count = common_guid_count; + matched_guid_func = Some(matched); + } + Ordering::Less => {} } } // Check call site symbol names + let mut highest_symbol_count = 0; + let mut matched_symbol_func = None; let call_site_symbol_names = call_sites .into_iter() .filter_map(|c| Some(c.symbol?.name)) @@ -353,14 +361,24 @@ impl Matcher { let common_symbol_count = call_site_symbol_names .intersection(&matched_call_site_symbol_names) .count(); - let score = common_symbol_count as f64; - if score > highest_score { - highest_score = score; - matched_func = Some(matched); + match common_symbol_count.cmp(&highest_symbol_count) { + Ordering::Equal => { + // Multiple matches with same count, don't match on ONE of them. + matched_symbol_func = None; + } + Ordering::Greater => { + highest_symbol_count = common_symbol_count; + matched_symbol_func = Some(matched); + } + Ordering::Less => {} } } - matched_func + match highest_guid_count.cmp(&highest_symbol_count) { + Ordering::Less => matched_symbol_func, + Ordering::Greater => matched_guid_func, + Ordering::Equal => None, + } } }