diff --git a/lightning/src/routing/scoring.rs b/lightning/src/routing/scoring.rs
index 0cc5e8d27b9..94e16b96b0e 100644
--- a/lightning/src/routing/scoring.rs
+++ b/lightning/src/routing/scoring.rs
@@ -1218,14 +1218,33 @@ fn nonlinear_success_probability(
 /// Given liquidity bounds, calculates the success probability (in the form of a numerator and
 /// denominator) of an HTLC. This is a key assumption in our scoring models.
 ///
-/// Must not return a numerator or denominator greater than 2^31 for arguments less than 2^31.
-///
 /// `total_inflight_amount_msat` includes the amount of the HTLC and any HTLCs in flight over the
 /// channel.
 ///
 /// min_zero_implies_no_successes signals that a `min_liquidity_msat` of 0 means we've not
 /// (recently) seen an HTLC successfully complete over this channel.
 #[inline(always)]
+fn success_probability_float(
+	total_inflight_amount_msat: u64, min_liquidity_msat: u64, max_liquidity_msat: u64,
+	capacity_msat: u64, params: &ProbabilisticScoringFeeParameters,
+	min_zero_implies_no_successes: bool,
+) -> (f64, f64) {
+	debug_assert!(min_liquidity_msat <= total_inflight_amount_msat);
+	debug_assert!(total_inflight_amount_msat < max_liquidity_msat);
+	debug_assert!(max_liquidity_msat <= capacity_msat);
+
+	if params.linear_success_probability {
+		let (numerator, denominator) = linear_success_probability(total_inflight_amount_msat, min_liquidity_msat, max_liquidity_msat, min_zero_implies_no_successes);
+		(numerator as f64, denominator as f64)
+	} else {
+		nonlinear_success_probability(total_inflight_amount_msat, min_liquidity_msat, max_liquidity_msat, capacity_msat, min_zero_implies_no_successes)
+	}
+}
+
+/// Identical to [`success_probability_float`] but returns an integer numerator and denominator.
+///
+/// Must not return a numerator or denominator greater than 2^31 for arguments less than 2^31.
+#[inline(always)]
 fn success_probability(
 	total_inflight_amount_msat: u64, min_liquidity_msat: u64, max_liquidity_msat: u64,
 	capacity_msat: u64, params: &ProbabilisticScoringFeeParameters,
@@ -1798,7 +1817,12 @@ mod bucketed_history {
 		// Because the first thing we do is check if `total_valid_points` is sufficient to consider
 		// the data here at all, and can return early if it is not, we want this to go first to
 		// avoid hitting a second cache line load entirely in that case.
-		total_valid_points_tracked: u64,
+		//
+		// Note that we store it as an `f64` rather than a `u64` (potentially losing some
+		// precision) because we ultimately need the value as an `f64` when dividing bucket weights
+		// by it. Storing it as an `f64` avoids doing the additional int -> float conversion in the
+		// hot score-calculation path.
+		total_valid_points_tracked: f64,
 		min_liquidity_offset_history: HistoricalBucketRangeTracker,
 		max_liquidity_offset_history: HistoricalBucketRangeTracker,
 	}
@@ -1808,7 +1832,7 @@
 			HistoricalLiquidityTracker {
 				min_liquidity_offset_history: HistoricalBucketRangeTracker::new(),
 				max_liquidity_offset_history: HistoricalBucketRangeTracker::new(),
-				total_valid_points_tracked: 0,
+				total_valid_points_tracked: 0.0,
 			}
 		}
 
@@ -1819,7 +1843,7 @@
 			let mut res = HistoricalLiquidityTracker {
 				min_liquidity_offset_history,
 				max_liquidity_offset_history,
-				total_valid_points_tracked: 0,
+				total_valid_points_tracked: 0.0,
 			};
 			res.recalculate_valid_point_count();
 			res
@@ -1842,12 +1866,18 @@
 		}
 
 		fn recalculate_valid_point_count(&mut self) {
-			self.total_valid_points_tracked = 0;
+			let mut total_valid_points_tracked = 0;
 			for (min_idx, min_bucket) in self.min_liquidity_offset_history.buckets.iter().enumerate() {
 				for max_bucket in self.max_liquidity_offset_history.buckets.iter().take(32 - min_idx) {
-					self.total_valid_points_tracked += (*min_bucket as u64) * (*max_bucket as u64);
+					// In testing, raising the weights of buckets to a high power led to better
+					// scoring results. Thus, we raise the bucket weights to the 4th power here (by
+					// squaring the result of multiplying the weights).
+					let mut bucket_weight = (*min_bucket as u64) * (*max_bucket as u64);
+					bucket_weight *= bucket_weight;
+					total_valid_points_tracked += bucket_weight;
 				}
 			}
+			self.total_valid_points_tracked = total_valid_points_tracked as f64;
 		}
 
 		pub(super) fn writeable_min_offset_history(&self) -> &HistoricalBucketRangeTracker {
@@ -1933,20 +1963,23 @@
 			let mut actual_valid_points_tracked = 0;
 			for (min_idx, min_bucket) in min_liquidity_offset_history_buckets.iter().enumerate() {
 				for max_bucket in max_liquidity_offset_history_buckets.iter().take(32 - min_idx) {
-					actual_valid_points_tracked += (*min_bucket as u64) * (*max_bucket as u64);
+					let mut bucket_weight = (*min_bucket as u64) * (*max_bucket as u64);
+					bucket_weight *= bucket_weight;
+					actual_valid_points_tracked += bucket_weight;
 				}
 			}
-			assert_eq!(total_valid_points_tracked, actual_valid_points_tracked);
+			assert_eq!(total_valid_points_tracked, actual_valid_points_tracked as f64);
 		}
 
 		// If the total valid points is smaller than 1.0 (i.e. 32 in our fixed-point scheme),
 		// treat it as if we were fully decayed.
-		const FULLY_DECAYED: u16 = BUCKET_FIXED_POINT_ONE * BUCKET_FIXED_POINT_ONE;
+		const FULLY_DECAYED: f64 = BUCKET_FIXED_POINT_ONE as f64 * BUCKET_FIXED_POINT_ONE as f64 *
+			BUCKET_FIXED_POINT_ONE as f64 * BUCKET_FIXED_POINT_ONE as f64;
 		if total_valid_points_tracked < FULLY_DECAYED.into() {
 			return None;
 		}
 
-		let mut cumulative_success_prob_times_billion = 0;
+		let mut cumulative_success_prob = 0.0f64;
 		// Special-case the 0th min bucket - it generally means we failed a payment, so only
 		// consider the highest (i.e. largest-offset-from-max-capacity) max bucket for all
 		// points against the 0th min bucket. This avoids the case where we fail to route
@@ -1959,7 +1992,7 @@
 				// max-bucket with at least BUCKET_FIXED_POINT_ONE.
 				let mut highest_max_bucket_with_points = 0;
 				let mut highest_max_bucket_with_full_points = None;
-				let mut total_max_points = 0; // Total points in max-buckets to consider
+				let mut total_weight = 0;
 				for (max_idx, max_bucket) in max_liquidity_offset_history_buckets.iter().enumerate() {
 					if *max_bucket >= BUCKET_FIXED_POINT_ONE {
 						highest_max_bucket_with_full_points = Some(cmp::max(highest_max_bucket_with_full_points.unwrap_or(0), max_idx));
@@ -1967,8 +2000,14 @@
 					if *max_bucket != 0 {
 						highest_max_bucket_with_points = cmp::max(highest_max_bucket_with_points, max_idx);
 					}
-					total_max_points += *max_bucket as u64;
+					// In testing, raising the weights of buckets to a high power led to better
+					// scoring results. Thus, we raise the bucket weights to the 4th power here (by
+					// squaring the result of multiplying the weights), matching the logic in
+					// `recalculate_valid_point_count`.
+					let bucket_weight = (*max_bucket as u64) * (min_liquidity_offset_history_buckets[0] as u64);
+					total_weight += bucket_weight * bucket_weight;
 				}
+				debug_assert!(total_weight as f64 <= total_valid_points_tracked);
 				// Use the highest max-bucket with at least BUCKET_FIXED_POINT_ONE, but if none is
 				// available use the highest max-bucket with any non-zero value. This ensures that
 				// if we have substantially decayed data we don't end up thinking the highest
@@ -1977,13 +2016,10 @@
 				let selected_max = highest_max_bucket_with_full_points.unwrap_or(highest_max_bucket_with_points);
 				let max_bucket_end_pos = BUCKET_START_POS[32 - selected_max] - 1;
 				if payment_pos < max_bucket_end_pos {
-					let (numerator, denominator) = success_probability(payment_pos as u64, 0,
+					let (numerator, denominator) = success_probability_float(payment_pos as u64, 0,
 						max_bucket_end_pos as u64, POSITION_TICKS as u64 - 1, params, true);
-					let bucket_prob_times_billion =
-						(min_liquidity_offset_history_buckets[0] as u64) * total_max_points
-						* 1024 * 1024 * 1024 / total_valid_points_tracked;
-					cumulative_success_prob_times_billion += bucket_prob_times_billion *
-						numerator / denominator;
+					let bucket_prob = total_weight as f64 / total_valid_points_tracked;
+					cumulative_success_prob += bucket_prob * numerator / denominator;
 				}
 			}
 
@@ -1991,26 +2027,32 @@
 				let min_bucket_start_pos = BUCKET_START_POS[min_idx];
 				for (max_idx, max_bucket) in max_liquidity_offset_history_buckets.iter().enumerate().take(32 - min_idx) {
 					let max_bucket_end_pos = BUCKET_START_POS[32 - max_idx] - 1;
-					// Note that this multiply can only barely not overflow - two 16 bit ints plus
-					// 30 bits is 62 bits.
-					let bucket_prob_times_billion = (*min_bucket as u64) * (*max_bucket as u64)
-						* 1024 * 1024 * 1024 / total_valid_points_tracked;
 					if payment_pos >= max_bucket_end_pos {
 						// Success probability 0, the payment amount may be above the max liquidity
 						break;
-					} else if payment_pos < min_bucket_start_pos {
-						cumulative_success_prob_times_billion += bucket_prob_times_billion;
+					}
+
+					// In testing, raising the weights of buckets to a high power led to better
+					// scoring results. Thus, we raise the bucket weights to the 4th power here (by
+					// squaring the result of multiplying the weights), matching the logic in
+					// `recalculate_valid_point_count`.
+					let mut bucket_weight = (*min_bucket as u64) * (*max_bucket as u64);
+					bucket_weight *= bucket_weight;
+					debug_assert!(bucket_weight as f64 <= total_valid_points_tracked);
+					let bucket_prob = bucket_weight as f64 / total_valid_points_tracked;
+
+					if payment_pos < min_bucket_start_pos {
+						cumulative_success_prob += bucket_prob;
 					} else {
-						let (numerator, denominator) = success_probability(payment_pos as u64,
+						let (numerator, denominator) = success_probability_float(payment_pos as u64,
 							min_bucket_start_pos as u64, max_bucket_end_pos as u64,
 							POSITION_TICKS as u64 - 1, params, true);
-						cumulative_success_prob_times_billion += bucket_prob_times_billion *
-							numerator / denominator;
+						cumulative_success_prob += bucket_prob * numerator / denominator;
 					}
 				}
 			}
 
-			Some(cumulative_success_prob_times_billion)
+			Some((cumulative_success_prob * (1024.0 * 1024.0 * 1024.0)) as u64)
 		}
 	}
 }
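A note on the first hunk: `success_probability_float` and `success_probability` share one probability model and differ only in return type, so the historical-scoring hot path can stay in floating point end to end. The sketch below illustrates that twin-helper pattern only; it assumes a bare uniform-liquidity linear model, P(success) = (max - amount) / (max - min), and its helper bodies are illustrative stand-ins rather than LDK's actual implementations (which, per the docs above, also handle `min_zero_implies_no_successes` and a nonlinear mode).

```rust
// Sketch only: the integer/float twin-helper pattern, assuming a plain
// uniform-liquidity linear model rather than LDK's real helpers.
fn linear_success_probability(amount_msat: u64, min_msat: u64, max_msat: u64) -> (u64, u64) {
	debug_assert!(min_msat <= amount_msat && amount_msat < max_msat);
	// With liquidity assumed uniform over [min, max), an HTLC of `amount_msat`
	// succeeds with probability (max - amount) / (max - min).
	(max_msat - amount_msat, max_msat - min_msat)
}

// The float twin converts once, here, so callers in a hot loop never pay a
// per-iteration int -> float cast.
fn linear_success_probability_float(amount_msat: u64, min_msat: u64, max_msat: u64) -> (f64, f64) {
	let (num, den) = linear_success_probability(amount_msat, min_msat, max_msat);
	(num as f64, den as f64)
}

fn main() {
	// A 600k msat HTLC over a channel with liquidity bounded to [250k, 1M) msat.
	let (num, den) = linear_success_probability_float(600_000, 250_000, 1_000_000);
	println!("success probability ~= {:.3}", num / den); // ~0.533
}
```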
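On the weighting change in `recalculate_valid_point_count` (mirrored in the scoring loops): squaring the product of the two fixed-point bucket weights makes confident, recent data dominate decayed data far more sharply than the old linear weights did. A self-contained sketch with made-up bucket values; only `BUCKET_FIXED_POINT_ONE` is taken from the patch (weights are `u16` fixed-point where 32 represents 1.0):

```rust
// Sketch only: how squaring the (min * max) product changes relative weights.
const BUCKET_FIXED_POINT_ONE: u16 = 32;

fn main() {
	// Hypothetical pairs: one mostly decayed (8/32 = 0.25) and one fresh (1.0).
	let pairs: [(u16, u16); 2] = [(8, 8), (BUCKET_FIXED_POINT_ONE, BUCKET_FIXED_POINT_ONE)];

	let mut weights = Vec::new();
	let mut total: u64 = 0;
	for (min_bucket, max_bucket) in pairs {
		// As in the patch: multiply the two u16 weights (fits in 32 bits), then
		// square the product (fits in 64 bits), so the u64 math cannot overflow.
		let mut bucket_weight = (min_bucket as u64) * (max_bucket as u64);
		bucket_weight *= bucket_weight;
		weights.push(bucket_weight);
		total += bucket_weight;
	}

	// Linear weights would favor the fresh pair 16:1; squared, it is 256:1,
	// so stale datapoints lose influence much faster.
	for (i, w) in weights.iter().enumerate() {
		println!("pair {}: weight {} -> share {:.6}", i, w, *w as f64 / total as f64);
	}
}
```

This rescaling is also why `FULLY_DECAYED` grows from `BUCKET_FIXED_POINT_ONE` squared to its 4th power: the "one full point" threshold has to be expressed on the same scale as the new weights.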
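Finally, on swapping the `cumulative_success_prob_times_billion` integer accumulator for an `f64`: the old code truncated on every per-bucket division and had to keep `weight * 2^30` inside a `u64` (the deleted "can only barely not overflow" comment), while the new code divides in floating point and converts to the billion-scaled integer once, at the very end. A self-contained comparison using hypothetical weights and numerator/denominator pairs; none of these values come from the patch:

```rust
// Sketch only: old fixed-point accumulation vs. the new f64 accumulation.
fn main() {
	let total_valid_points_tracked: u64 = 5_000_000;
	// (bucket weight, success-probability numerator, denominator) - made up.
	let buckets: [(u64, u64, u64); 2] = [(1_048_576, 3, 4), (262_144, 1, 8)];

	// Old style: every bucket pays two truncating integer divisions, and
	// `weight * 2^30` must stay below 2^64.
	let mut cumulative_success_prob_times_billion: u64 = 0;
	for (weight, num, den) in buckets {
		let bucket_prob_times_billion =
			weight * 1024 * 1024 * 1024 / total_valid_points_tracked;
		cumulative_success_prob_times_billion += bucket_prob_times_billion * num / den;
	}

	// New style: accumulate in f64, convert to the 2^30-scaled integer once.
	let total = total_valid_points_tracked as f64;
	let mut cumulative_success_prob = 0.0f64;
	for (weight, num, den) in buckets {
		let bucket_prob = weight as f64 / total;
		cumulative_success_prob += bucket_prob * num as f64 / den as f64;
	}
	let scaled = (cumulative_success_prob * (1024.0 * 1024.0 * 1024.0)) as u64;

	// The two paths agree up to the old path's per-bucket truncation error.
	println!("integer path: {}", cumulative_success_prob_times_billion);
	println!("float path:   {}", scaled);
}
```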