Skip to content

Commit

Permalink
zcash_client_sqlite: Modify Progress::scan to be non-optional & fix…
Browse files Browse the repository at this point in the history
… tests.
  • Loading branch information
nuttycom committed Oct 10, 2024
1 parent eaf020e commit a56de09
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 56 deletions.
4 changes: 2 additions & 2 deletions zcash_client_sqlite/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ pub mod error;
pub mod wallet;
use wallet::{
commitment_tree::{self, put_shard_roots},
SubtreeScanProgress,
SubtreeProgressEstimator,
};

#[cfg(test)]
Expand Down Expand Up @@ -461,7 +461,7 @@ impl<C: Borrow<rusqlite::Connection>, P: consensus::Parameters> WalletRead for W
&self.conn.borrow().unchecked_transaction()?,
&self.params,
min_confirmations,
&SubtreeScanProgress,
&SubtreeProgressEstimator,
)
}

Expand Down
143 changes: 90 additions & 53 deletions zcash_client_sqlite/src/wallet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -806,20 +806,34 @@ pub(crate) fn get_derived_account<P: consensus::Parameters>(

#[derive(Debug)]
pub(crate) struct Progress {
scan: Option<Ratio<u64>>,
recover: Option<Ratio<u64>>,
scan: Ratio<u64>,
recovery: Option<Ratio<u64>>,
}

pub(crate) trait ScanProgress {
impl Progress {
pub(crate) fn new(scan: Ratio<u64>, recovery: Option<Ratio<u64>>) -> Self {
Self { scan, recovery }
}

pub(crate) fn scan(&self) -> Ratio<u64> {
self.scan

Check warning on line 819 in zcash_client_sqlite/src/wallet.rs

View check run for this annotation

Codecov / codecov/patch

zcash_client_sqlite/src/wallet.rs#L818-L819

Added lines #L818 - L819 were not covered by tests
}

pub(crate) fn recovery(&self) -> Option<Ratio<u64>> {
self.recovery

Check warning on line 823 in zcash_client_sqlite/src/wallet.rs

View check run for this annotation

Codecov / codecov/patch

zcash_client_sqlite/src/wallet.rs#L822-L823

Added lines #L822 - L823 were not covered by tests
}
}

pub(crate) trait ProgressEstimator {
fn sapling_scan_progress<P: consensus::Parameters>(
&self,
conn: &rusqlite::Connection,
params: &P,
birthday_height: BlockHeight,
recover_until_height: Option<BlockHeight>,
fully_scanned_height: BlockHeight,
fully_scanned_height: Option<BlockHeight>,
chain_tip_height: BlockHeight,
) -> Result<Progress, SqliteClientError>;
) -> Result<Option<Progress>, SqliteClientError>;

#[cfg(feature = "orchard")]
fn orchard_scan_progress<P: consensus::Parameters>(
Expand All @@ -828,13 +842,13 @@ pub(crate) trait ScanProgress {
params: &P,
birthday_height: BlockHeight,
recover_until_height: Option<BlockHeight>,
fully_scanned_height: BlockHeight,
fully_scanned_height: Option<BlockHeight>,
chain_tip_height: BlockHeight,
) -> Result<Progress, SqliteClientError>;
) -> Result<Option<Progress>, SqliteClientError>;
}

#[derive(Debug)]
pub(crate) struct SubtreeScanProgress;
pub(crate) struct SubtreeProgressEstimator;

fn table_constants(
shielded_protocol: ShieldedProtocol,
Expand Down Expand Up @@ -1005,8 +1019,22 @@ fn estimate_tree_size<P: consensus::Parameters>(
}))
}
} else {
// We don't have subtree information, so give up. We'll get it soon.
Ok(None)
// If there are no completed subtrees, but we have scanned some blocks, we can still
// interpolate based upon the tree size as of the last scanned block. Here, since we
// don't have any subtree data to draw on, we will interpolate based on the number of
// blocks since the pool activation height
Ok(
last_scanned.and_then(|(last_scanned_height, last_scanned_tree_size)| {
let subtree_range = u64::from(last_scanned_height - pool_activation_height);
let unscanned_range = u64::from(chain_tip_height - last_scanned_height);

(last_scanned_tree_size * unscanned_range)
.checked_div(subtree_range)
.map(|extrapolated_incomplete_subtree_notes| {
last_scanned_tree_size + extrapolated_incomplete_subtree_notes
})
}),
)
}
}

Expand All @@ -1018,9 +1046,9 @@ fn subtree_scan_progress<P: consensus::Parameters>(
pool_activation_height: BlockHeight,
birthday_height: BlockHeight,
recover_until_height: Option<BlockHeight>,
fully_scanned_height: BlockHeight,
fully_scanned_height: Option<BlockHeight>,
chain_tip_height: BlockHeight,
) -> Result<Progress, SqliteClientError> {
) -> Result<Option<Progress>, SqliteClientError> {
let (table_prefix, output_count_col, shard_height) = table_constants(shielded_protocol)?;

let mut stmt_scanned_count_until = conn.prepare_cached(&format!(
Expand All @@ -1044,7 +1072,7 @@ fn subtree_scan_progress<P: consensus::Parameters>(
WHERE height = :height",
))?;

if fully_scanned_height == chain_tip_height {
if fully_scanned_height == Some(chain_tip_height) {
// Compute the total blocks scanned since the wallet birthday on either side of
// the recover-until height.
let recover = recover_until_height
Expand Down Expand Up @@ -1077,7 +1105,8 @@ fn subtree_scan_progress<P: consensus::Parameters>(
Ok(scanned.map(|n| Ratio::new(n, n)))
},
)?;
Ok(Progress { scan, recover })

Ok(scan.map(|scan| Progress::new(scan, recover)))
} else {
// In case we didn't have information about the tree size at the recover-until
// height, get the tree size from a nearby subtree. It's fine for this to be
Expand Down Expand Up @@ -1153,8 +1182,9 @@ fn subtree_scan_progress<P: consensus::Parameters>(
)
})
.transpose()?;
// If we've scanned the block at the chain tip, we know how many notes are
// currently in the tree.

// If we've scanned the block at the chain tip, we know how many notes are currently in the
// tree.
let tip_tree_size = match stmt_end_tree_size_at
.query_row(
named_params! {":height": u32::from(chain_tip_height)},
Expand Down Expand Up @@ -1201,21 +1231,21 @@ fn subtree_scan_progress<P: consensus::Parameters>(
})
};

Ok(Progress { scan, recover })
Ok(scan.map(|scan| Progress::new(scan, recover)))
}
}

impl ScanProgress for SubtreeScanProgress {
impl ProgressEstimator for SubtreeProgressEstimator {
#[tracing::instrument(skip(conn, params))]
fn sapling_scan_progress<P: consensus::Parameters>(
&self,
conn: &rusqlite::Connection,
params: &P,
birthday_height: BlockHeight,
recover_until_height: Option<BlockHeight>,
fully_scanned_height: BlockHeight,
fully_scanned_height: Option<BlockHeight>,
chain_tip_height: BlockHeight,
) -> Result<Progress, SqliteClientError> {
) -> Result<Option<Progress>, SqliteClientError> {
subtree_scan_progress(
conn,
params,
Expand All @@ -1238,9 +1268,9 @@ impl ScanProgress for SubtreeScanProgress {
params: &P,
birthday_height: BlockHeight,
recover_until_height: Option<BlockHeight>,
fully_scanned_height: BlockHeight,
fully_scanned_height: Option<BlockHeight>,
chain_tip_height: BlockHeight,
) -> Result<Progress, SqliteClientError> {
) -> Result<Option<Progress>, SqliteClientError> {
subtree_scan_progress(
conn,
params,
Expand Down Expand Up @@ -1268,7 +1298,7 @@ pub(crate) fn get_wallet_summary<P: consensus::Parameters>(
tx: &rusqlite::Transaction,
params: &P,
min_confirmations: u32,
progress: &impl ScanProgress,
progress: &impl ProgressEstimator,
) -> Result<Option<WalletSummary<AccountId>>, SqliteClientError> {
let chain_tip_height = match chain_tip_height(tx)? {
Some(h) => h,
Expand All @@ -1277,12 +1307,16 @@ pub(crate) fn get_wallet_summary<P: consensus::Parameters>(
}
};

let birthday_height =
wallet_birthday(tx)?.expect("If a scan range exists, we know the wallet birthday.");
let birthday_height = match wallet_birthday(tx)? {
Some(h) => h,
None => {
return Ok(None);

Check warning on line 1313 in zcash_client_sqlite/src/wallet.rs

View check run for this annotation

Codecov / codecov/patch

zcash_client_sqlite/src/wallet.rs#L1310-L1313

Added lines #L1310 - L1313 were not covered by tests
}
};

let recover_until_height = recover_until_height(tx)?;

let fully_scanned_height =
block_fully_scanned(tx, params)?.map_or(birthday_height - 1, |m| m.block_height());
let fully_scanned_height = block_fully_scanned(tx, params)?.map(|m| m.block_height());
let summary_height = (chain_tip_height + 1).saturating_sub(std::cmp::max(min_confirmations, 1));

let sapling_progress = progress.sapling_scan_progress(
Expand All @@ -1304,34 +1338,37 @@ pub(crate) fn get_wallet_summary<P: consensus::Parameters>(
chain_tip_height,
)?;
#[cfg(not(feature = "orchard"))]
let orchard_progress: Progress = Progress {
scan: None,
recover: None,
};
let orchard_progress: Option<Progress> = None;

Check warning on line 1341 in zcash_client_sqlite/src/wallet.rs

View check run for this annotation

Codecov / codecov/patch

zcash_client_sqlite/src/wallet.rs#L1341

Added line #L1341 was not covered by tests

// Treat Sapling and Orchard outputs as having the same cost to scan.
let scan_progress = sapling_progress
.scan
.zip(orchard_progress.scan)
.map(|(s, o)| {
Ratio::new(
s.numerator() + o.numerator(),
s.denominator() + o.denominator(),
)
})
.or(sapling_progress.scan)
.or(orchard_progress.scan);
let recover_progress = sapling_progress
.recover
.zip(orchard_progress.recover)
let progress = sapling_progress

Check warning on line 1344 in zcash_client_sqlite/src/wallet.rs

View check run for this annotation

Codecov / codecov/patch

zcash_client_sqlite/src/wallet.rs#L1344

Added line #L1344 was not covered by tests
.as_ref()
.zip(orchard_progress.as_ref())

Check warning on line 1346 in zcash_client_sqlite/src/wallet.rs

View check run for this annotation

Codecov / codecov/patch

zcash_client_sqlite/src/wallet.rs#L1346

Added line #L1346 was not covered by tests
.map(|(s, o)| {
Ratio::new(
s.numerator() + o.numerator(),
s.denominator() + o.denominator(),
Progress::new(
Ratio::new(
s.scan().numerator() + o.scan().numerator(),
s.scan().denominator() + o.scan().denominator(),

Check warning on line 1351 in zcash_client_sqlite/src/wallet.rs

View check run for this annotation

Codecov / codecov/patch

zcash_client_sqlite/src/wallet.rs#L1348-L1351

Added lines #L1348 - L1351 were not covered by tests
),
s.recovery()
.zip(o.recovery())
.map(|(s, o)| {
Ratio::new(
s.numerator() + o.numerator(),
s.denominator() + o.denominator(),

Check warning on line 1358 in zcash_client_sqlite/src/wallet.rs

View check run for this annotation

Codecov / codecov/patch

zcash_client_sqlite/src/wallet.rs#L1353-L1358

Added lines #L1353 - L1358 were not covered by tests
)
})
.or_else(|| s.recovery())
.or_else(|| o.recovery()),

Check warning on line 1362 in zcash_client_sqlite/src/wallet.rs

View check run for this annotation

Codecov / codecov/patch

zcash_client_sqlite/src/wallet.rs#L1361-L1362

Added lines #L1361 - L1362 were not covered by tests
)
})
.or(sapling_progress.recover)
.or(orchard_progress.recover);
.or(sapling_progress)
.or(orchard_progress);

Check warning on line 1366 in zcash_client_sqlite/src/wallet.rs

View check run for this annotation

Codecov / codecov/patch

zcash_client_sqlite/src/wallet.rs#L1365-L1366

Added lines #L1365 - L1366 were not covered by tests

let progress = match progress {
Some(p) => p,
None => return Ok(None),

Check warning on line 1370 in zcash_client_sqlite/src/wallet.rs

View check run for this annotation

Codecov / codecov/patch

zcash_client_sqlite/src/wallet.rs#L1368-L1370

Added lines #L1368 - L1370 were not covered by tests
};

let mut stmt_accounts = tx.prepare_cached("SELECT id FROM accounts")?;
let mut account_balances = stmt_accounts
Expand Down Expand Up @@ -1552,9 +1589,9 @@ pub(crate) fn get_wallet_summary<P: consensus::Parameters>(
let summary = WalletSummary::new(
account_balances,
chain_tip_height,
fully_scanned_height,
scan_progress,
recover_progress,
fully_scanned_height.unwrap_or(birthday_height - 1),
Some(progress.scan),
progress.recovery,

Check warning on line 1594 in zcash_client_sqlite/src/wallet.rs

View check run for this annotation

Codecov / codecov/patch

zcash_client_sqlite/src/wallet.rs#L1592-L1594

Added lines #L1592 - L1594 were not covered by tests
next_sapling_subtree_index,
#[cfg(feature = "orchard")]
next_orchard_subtree_index,
Expand Down
5 changes: 4 additions & 1 deletion zcash_client_sqlite/src/wallet/scanning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1303,7 +1303,10 @@ pub(crate) mod tests {
summary.as_ref().and_then(|s| s.recovery_progress()),
no_recovery,
);
assert_eq!(summary.and_then(|s| s.scan_progress()), None);
assert_matches!(
summary.and_then(|s| s.scan_progress()),
Some(progress) if progress.numerator() == &0
);

// Set up prior chain state. This simulates us having imported a wallet
// with a birthday 520 blocks below the chain tip.
Expand Down

0 comments on commit a56de09

Please sign in to comment.