diff --git a/zcash_client_backend/CHANGELOG.md b/zcash_client_backend/CHANGELOG.md index bb6badaa30..8042970374 100644 --- a/zcash_client_backend/CHANGELOG.md +++ b/zcash_client_backend/CHANGELOG.md @@ -49,6 +49,8 @@ and this library adheres to Rust's notion of - `WalletOrchardOutput` - `WalletTx::{orchard_spends, orchard_outputs}` - `ReceivedNote::map_note` + - `ReceivedNote<_, sapling::Note>::note_value` + - `ReceivedNote<_, orchard::note::Note>::note_value` ### Changed - `zcash_client_backend::data_api`: @@ -64,8 +66,8 @@ and this library adheres to Rust's notion of - Added `get_orchard_nullifiers` method. - Changes to the `InputSource` trait: - `select_spendable_notes` now takes its `target_value` argument as a - `NonNegativeAmount`. Also, the values of the returned map are also - `NonNegativeAmount`s instead of `Amount`s. + `NonNegativeAmount`. Also, it now returns a `SpendableNotes` data + structure instead of a vector. - Fields of `DecryptedTransaction` are now private. Use `DecryptedTransaction::new` and the newly provided accessors instead. - Fields of `SentTransaction` are now private. Use `SentTransaction::new` diff --git a/zcash_client_backend/src/data_api.rs b/zcash_client_backend/src/data_api.rs index 16b40805eb..cb28431479 100644 --- a/zcash_client_backend/src/data_api.rs +++ b/zcash_client_backend/src/data_api.rs @@ -494,6 +494,23 @@ pub trait NoteRetention { fn should_retain_orchard(&self, note: &ReceivedNote) -> bool; } +pub(crate) struct SimpleNoteRetention { + pub(crate) sapling: bool, + #[cfg(feature = "orchard")] + pub(crate) orchard: bool, +} + +impl NoteRetention for SimpleNoteRetention { + fn should_retain_sapling(&self, _: &ReceivedNote) -> bool { + self.sapling + } + + #[cfg(feature = "orchard")] + fn should_retain_orchard(&self, _: &ReceivedNote) -> bool { + self.orchard + } +} + /// Spendable shielded outputs controlled by the wallet. pub struct SpendableNotes { sapling: Vec>, @@ -502,6 +519,15 @@ pub struct SpendableNotes { } impl SpendableNotes { + /// Construct a new empty [`SpendableNotes`]. + pub fn empty() -> Self { + Self::new( + vec![], + #[cfg(feature = "orchard")] + vec![], + ) + } + /// Construct a new [`SpendableNotes`] from its constituent parts. pub fn new( sapling: Vec>, @@ -525,6 +551,34 @@ impl SpendableNotes { self.orchard.as_ref() } + /// Computes the total value of Sapling notes. + pub fn sapling_value(&self) -> Result { + self.sapling + .iter() + .try_fold(NonNegativeAmount::ZERO, |acc, n| { + (acc + n.note_value()?).ok_or(BalanceError::Overflow) + }) + } + + /// Computes the total value of Sapling notes. + #[cfg(feature = "orchard")] + pub fn orchard_value(&self) -> Result { + self.orchard + .iter() + .try_fold(NonNegativeAmount::ZERO, |acc, n| { + (acc + n.note_value()?).ok_or(BalanceError::Overflow) + }) + } + + /// Computes the total value of spendable inputs + pub fn total_value(&self) -> Result { + #[cfg(not(feature = "orchard"))] + return self.sapling_value(); + + #[cfg(feature = "orchard")] + return (self.sapling_value()? + self.orchard_value()?).ok_or(BalanceError::Overflow); + } + /// Consumes this [`SpendableNotes`] value and produces a vector of /// [`ReceivedNote`] values. pub fn into_vec( @@ -589,7 +643,7 @@ pub trait InputSource { sources: &[ShieldedProtocol], anchor_height: BlockHeight, exclude: &[Self::NoteRef], - ) -> Result>, Self::Error>; + ) -> Result, Self::Error>; /// Fetches a spendable transparent output. /// @@ -1552,8 +1606,8 @@ pub mod testing { chain::{ChainState, CommitmentTreeRoot}, scanning::ScanRange, AccountBirthday, BlockMetadata, DecryptedTransaction, InputSource, NullifierQuery, - ScannedBlock, SentTransaction, WalletCommitmentTrees, WalletRead, WalletSummary, - WalletWrite, SAPLING_SHARD_HEIGHT, + ScannedBlock, SentTransaction, SpendableNotes, WalletCommitmentTrees, WalletRead, + WalletSummary, WalletWrite, SAPLING_SHARD_HEIGHT, }; #[cfg(feature = "transparent-inputs")] @@ -1609,8 +1663,8 @@ pub mod testing { _sources: &[ShieldedProtocol], _anchor_height: BlockHeight, _exclude: &[Self::NoteRef], - ) -> Result>, Self::Error> { - Ok(Vec::new()) + ) -> Result, Self::Error> { + Ok(SpendableNotes::empty()) } } diff --git a/zcash_client_backend/src/data_api/wallet/input_selection.rs b/zcash_client_backend/src/data_api/wallet/input_selection.rs index 5e3f16ae94..88f61142ab 100644 --- a/zcash_client_backend/src/data_api/wallet/input_selection.rs +++ b/zcash_client_backend/src/data_api/wallet/input_selection.rs @@ -21,10 +21,10 @@ use zcash_primitives::{ use crate::{ address::{Address, UnifiedAddress}, - data_api::InputSource, + data_api::{InputSource, SimpleNoteRetention, SpendableNotes}, fees::{sapling, ChangeError, ChangeStrategy, DustOutputPolicy}, proposal::{Proposal, ProposalError, ShieldedInputs}, - wallet::{Note, ReceivedNote, WalletTransparentOutput}, + wallet::WalletTransparentOutput, zip321::TransactionRequest, PoolType, ShieldedProtocol, }; @@ -386,7 +386,7 @@ where } } - let mut shielded_inputs: Vec> = vec![]; + let mut shielded_inputs = SpendableNotes::empty(); let mut prior_available = NonNegativeAmount::ZERO; let mut amount_required = NonNegativeAmount::ZERO; let mut exclude: Vec = vec![]; @@ -396,58 +396,39 @@ where loop { #[cfg(feature = "orchard")] let (sapling_input_total, orchard_input_total) = ( - shielded_inputs - .iter() - .filter(|i| matches!(i.note(), Note::Sapling(_))) - .map(|i| i.note().value()) - .sum::>() - .ok_or(BalanceError::Overflow)?, - shielded_inputs - .iter() - .filter(|i| matches!(i.note(), Note::Orchard(_))) - .map(|i| i.note().value()) - .sum::>() - .ok_or(BalanceError::Overflow)?, + shielded_inputs.sapling_value()?, + shielded_inputs.orchard_value()?, ); #[cfg(not(feature = "orchard"))] let orchard_input_total = NonNegativeAmount::ZERO; - let sapling_inputs = - if sapling_outputs.is_empty() && orchard_input_total >= amount_required { - // Avoid selecting Sapling inputs if we don't have Sapling outputs and the value is - // fully covered by Orchard inputs. - #[cfg(feature = "orchard")] - shielded_inputs.retain(|i| matches!(i.note(), Note::Orchard(_))); - vec![] - } else { - #[allow(clippy::unnecessary_filter_map)] - shielded_inputs - .iter() - .filter_map(|i| match i.note() { - Note::Sapling(n) => Some((*i.internal_note_id(), n.value())), - #[cfg(feature = "orchard")] - Note::Orchard(_) => None, - }) - .collect::>() - }; + let use_sapling = + !(sapling_outputs.is_empty() && orchard_input_total >= amount_required); + #[cfg(feature = "orchard")] + let use_orchard = + !(orchard_outputs.is_empty() && sapling_input_total >= amount_required); + + let sapling_inputs = if use_sapling { + shielded_inputs + .sapling() + .iter() + .map(|i| (*i.internal_note_id(), i.note().value())) + .collect() + } else { + vec![] + }; #[cfg(feature = "orchard")] - let orchard_inputs = - if orchard_outputs.is_empty() && sapling_input_total >= amount_required { - // Avoid selecting Orchard inputs if we don't have Orchard outputs and the value is - // fully covered by Sapling inputs. - shielded_inputs.retain(|i| matches!(i.note(), Note::Sapling(_))); - vec![] - } else { - shielded_inputs - .iter() - .filter_map(|i| match i.note() { - Note::Sapling(_) => None, - Note::Orchard(n) => Some((*i.internal_note_id(), n.value())), - }) - .collect::>() - }; + let orchard_inputs = if use_orchard { + shielded_inputs + .orchard() + .iter() + .map(|i| (*i.internal_note_id(), i.note().value())) + .collect() + } else { + vec![] + }; let balance = self.change_strategy.compute_balance( params, @@ -474,8 +455,12 @@ where transaction_request, payment_pools, vec![], - NonEmpty::from_vec(shielded_inputs) - .map(|notes| ShieldedInputs::from_parts(anchor_height, notes)), + NonEmpty::from_vec(shielded_inputs.into_vec(&SimpleNoteRetention { + sapling: use_sapling, + #[cfg(feature = "orchard")] + orchard: use_orchard, + })) + .map(|notes| ShieldedInputs::from_parts(anchor_height, notes)), balance, (*self.change_strategy.fee_rule()).clone(), target_height, @@ -514,12 +499,7 @@ where ) .map_err(InputSelectorError::DataSource)?; - let new_available = shielded_inputs - .iter() - .map(|n| n.note().value()) - .sum::>() - .ok_or(BalanceError::Overflow)?; - + let new_available = shielded_inputs.total_value()?; if new_available <= prior_available { return Err(InputSelectorError::InsufficientFunds { required: amount_required, diff --git a/zcash_client_backend/src/wallet.rs b/zcash_client_backend/src/wallet.rs index e3b1872476..e849d5d707 100644 --- a/zcash_client_backend/src/wallet.rs +++ b/zcash_client_backend/src/wallet.rs @@ -16,6 +16,7 @@ use zcash_primitives::{ }, zip32::Scope, }; +use zcash_protocol::value::BalanceError; use crate::{address::UnifiedAddress, fees::sapling as sapling_fees, PoolType, ShieldedProtocol}; @@ -447,6 +448,19 @@ impl ReceivedNote { } } +impl ReceivedNote { + pub fn note_value(&self) -> Result { + self.note.value().inner().try_into() + } +} + +#[cfg(feature = "orchard")] +impl ReceivedNote { + pub fn note_value(&self) -> Result { + self.note.value().inner().try_into() + } +} + impl sapling_fees::InputView for (NoteRef, sapling::value::NoteValue) { fn note_id(&self) -> &NoteRef { &self.0 diff --git a/zcash_client_sqlite/src/lib.rs b/zcash_client_sqlite/src/lib.rs index f4137d8be5..57ca55654b 100644 --- a/zcash_client_sqlite/src/lib.rs +++ b/zcash_client_sqlite/src/lib.rs @@ -61,8 +61,8 @@ use zcash_client_backend::{ chain::{BlockSource, ChainState, CommitmentTreeRoot}, scanning::{ScanPriority, ScanRange}, Account, AccountBirthday, AccountKind, BlockMetadata, DecryptedTransaction, InputSource, - NullifierQuery, ScannedBlock, SentTransaction, WalletCommitmentTrees, WalletRead, - WalletSummary, WalletWrite, SAPLING_SHARD_HEIGHT, + NullifierQuery, ScannedBlock, SentTransaction, SpendableNotes, WalletCommitmentTrees, + WalletRead, WalletSummary, WalletWrite, SAPLING_SHARD_HEIGHT, }, keys::{ AddressGenerationError, UnifiedAddressRequest, UnifiedFullViewingKey, UnifiedSpendingKey, @@ -212,7 +212,8 @@ impl, P: consensus::Parameters> InputSource for &self.params, txid, index, - ), + ) + .map(|opt| opt.map(|n| n.map_note(Note::Sapling))), ShieldedProtocol::Orchard => { #[cfg(feature = "orchard")] return wallet::orchard::get_spendable_orchard_note( @@ -220,7 +221,8 @@ impl, P: consensus::Parameters> InputSource for &self.params, txid, index, - ); + ) + .map(|opt| opt.map(|n| n.map_note(Note::Orchard))); #[cfg(not(feature = "orchard"))] return Err(SqliteClientError::UnsupportedPoolType(PoolType::Shielded( @@ -237,26 +239,26 @@ impl, P: consensus::Parameters> InputSource for _sources: &[ShieldedProtocol], anchor_height: BlockHeight, exclude: &[Self::NoteRef], - ) -> Result>, Self::Error> { - let received_iter = std::iter::empty(); - let received_iter = received_iter.chain(wallet::sapling::select_spendable_sapling_notes( - self.conn.borrow(), - &self.params, - account, - target_value, - anchor_height, - exclude, - )?); - #[cfg(feature = "orchard")] - let received_iter = received_iter.chain(wallet::orchard::select_spendable_orchard_notes( - self.conn.borrow(), - &self.params, - account, - target_value, - anchor_height, - exclude, - )?); - Ok(received_iter.collect()) + ) -> Result, Self::Error> { + Ok(SpendableNotes::new( + wallet::sapling::select_spendable_sapling_notes( + self.conn.borrow(), + &self.params, + account, + target_value, + anchor_height, + exclude, + )?, + #[cfg(feature = "orchard")] + wallet::orchard::select_spendable_orchard_notes( + self.conn.borrow(), + &self.params, + account, + target_value, + anchor_height, + exclude, + )?, + )) } #[cfg(feature = "transparent-inputs")] diff --git a/zcash_client_sqlite/src/testing/pool.rs b/zcash_client_sqlite/src/testing/pool.rs index fec28fed9e..750c9802a1 100644 --- a/zcash_client_sqlite/src/testing/pool.rs +++ b/zcash_client_sqlite/src/testing/pool.rs @@ -83,6 +83,7 @@ pub(crate) trait ShieldedPoolTester { type Sk; type Fvk: TestFvk; type MerkleTreeHash; + type Note; fn test_account_fvk(st: &TestState) -> Self::Fvk; fn usk_to_sk(usk: &UnifiedSpendingKey) -> &Self::Sk; @@ -109,7 +110,7 @@ pub(crate) trait ShieldedPoolTester { target_value: NonNegativeAmount, anchor_height: BlockHeight, exclude: &[ReceivedNoteId], - ) -> Result>, SqliteClientError>; + ) -> Result>, SqliteClientError>; fn decrypted_pool_outputs_count(d_tx: &DecryptedTransaction<'_, AccountId>) -> usize; diff --git a/zcash_client_sqlite/src/wallet/common.rs b/zcash_client_sqlite/src/wallet/common.rs index 1a47fc404a..663d9b717e 100644 --- a/zcash_client_sqlite/src/wallet/common.rs +++ b/zcash_client_sqlite/src/wallet/common.rs @@ -3,10 +3,7 @@ use rusqlite::{named_params, types::Value, Connection, Row}; use std::rc::Rc; -use zcash_client_backend::{ - wallet::{Note, ReceivedNote}, - ShieldedProtocol, -}; +use zcash_client_backend::{wallet::ReceivedNote, ShieldedProtocol}; use zcash_primitives::transaction::{components::amount::NonNegativeAmount, TxId}; use zcash_protocol::consensus::{self, BlockHeight}; @@ -54,7 +51,7 @@ fn unscanned_tip_exists( // (https://github.com/rust-lang/rust-clippy/issues/11308) means it fails to identify that the `result` temporary // is required in order to resolve the borrows involved in the `query_and_then` call. #[allow(clippy::let_and_return)] -pub(crate) fn get_spendable_note( +pub(crate) fn get_spendable_note( conn: &Connection, params: &P, txid: &TxId, @@ -99,7 +96,7 @@ where } #[allow(clippy::too_many_arguments)] -pub(crate) fn select_spendable_notes( +pub(crate) fn select_spendable_notes( conn: &Connection, params: &P, account: AccountId, diff --git a/zcash_client_sqlite/src/wallet/orchard.rs b/zcash_client_sqlite/src/wallet/orchard.rs index 7143869a28..3d0a575ad3 100644 --- a/zcash_client_sqlite/src/wallet/orchard.rs +++ b/zcash_client_sqlite/src/wallet/orchard.rs @@ -96,10 +96,7 @@ impl ReceivedOrchardOutput for DecryptedOutput { fn to_spendable_note( params: &P, row: &Row, -) -> Result< - Option>, - SqliteClientError, -> { +) -> Result>, SqliteClientError> { let note_id = ReceivedNoteId(ShieldedProtocol::Orchard, row.get(0)?); let txid = row.get::<_, [u8; 32]>(1).map(TxId::from_bytes)?; let action_index = row.get(2)?; @@ -175,7 +172,7 @@ fn to_spendable_note( note_id, txid, action_index, - zcash_client_backend::wallet::Note::Orchard(note), + note, spending_key_scope, note_commitment_tree_position, )) @@ -188,10 +185,7 @@ pub(crate) fn get_spendable_orchard_note( params: &P, txid: &TxId, index: u32, -) -> Result< - Option>, - SqliteClientError, -> { +) -> Result>, SqliteClientError> { super::common::get_spendable_note( conn, params, @@ -209,8 +203,7 @@ pub(crate) fn select_spendable_orchard_notes( target_value: Zatoshis, anchor_height: BlockHeight, exclude: &[ReceivedNoteId], -) -> Result>, SqliteClientError> -{ +) -> Result>, SqliteClientError> { super::common::select_spendable_notes( conn, params, @@ -389,6 +382,7 @@ pub(crate) mod tests { type Sk = SpendingKey; type Fvk = FullViewingKey; type MerkleTreeHash = MerkleHashOrchard; + type Note = orchard::note::Note; fn test_account_fvk(st: &TestState) -> Self::Fvk { st.test_account_orchard().unwrap() @@ -457,7 +451,8 @@ pub(crate) mod tests { target_value: zcash_protocol::value::Zatoshis, anchor_height: BlockHeight, exclude: &[crate::ReceivedNoteId], - ) -> Result>, SqliteClientError> { + ) -> Result>, SqliteClientError> + { select_spendable_orchard_notes( &st.wallet().conn, &st.wallet().params, diff --git a/zcash_client_sqlite/src/wallet/sapling.rs b/zcash_client_sqlite/src/wallet/sapling.rs index 38072f9132..14b8861008 100644 --- a/zcash_client_sqlite/src/wallet/sapling.rs +++ b/zcash_client_sqlite/src/wallet/sapling.rs @@ -7,7 +7,7 @@ use rusqlite::{named_params, params, Connection, Row}; use sapling::{self, Diversifier, Nullifier, Rseed}; use zcash_client_backend::{ data_api::NullifierQuery, - wallet::{Note, ReceivedNote, WalletSaplingOutput}, + wallet::{ReceivedNote, WalletSaplingOutput}, DecryptedOutput, ShieldedProtocol, TransferType, }; use zcash_keys::keys::UnifiedFullViewingKey; @@ -95,7 +95,7 @@ impl ReceivedSaplingOutput for DecryptedOutput { fn to_spendable_note( params: &P, row: &Row, -) -> Result>, SqliteClientError> { +) -> Result>, SqliteClientError> { let note_id = ReceivedNoteId(ShieldedProtocol::Sapling, row.get(0)?); let txid = row.get::<_, [u8; 32]>(1).map(TxId::from_bytes)?; let output_index = row.get(2)?; @@ -169,11 +169,11 @@ fn to_spendable_note( note_id, txid, output_index, - Note::Sapling(sapling::Note::from_parts( + sapling::Note::from_parts( recipient, sapling::value::NoteValue::from_raw(note_value), rseed, - )), + ), spending_key_scope, note_commitment_tree_position, )) @@ -190,7 +190,7 @@ pub(crate) fn get_spendable_sapling_note( params: &P, txid: &TxId, index: u32, -) -> Result>, SqliteClientError> { +) -> Result>, SqliteClientError> { super::common::get_spendable_note( conn, params, @@ -213,7 +213,7 @@ pub(crate) fn select_spendable_sapling_notes( target_value: NonNegativeAmount, anchor_height: BlockHeight, exclude: &[ReceivedNoteId], -) -> Result>, SqliteClientError> { +) -> Result>, SqliteClientError> { super::common::select_spendable_notes( conn, params, @@ -403,6 +403,7 @@ pub(crate) mod tests { type Sk = ExtendedSpendingKey; type Fvk = DiversifiableFullViewingKey; type MerkleTreeHash = sapling::Node; + type Note = sapling::Note; fn test_account_fvk(st: &TestState) -> Self::Fvk { st.test_account_sapling().unwrap() @@ -459,7 +460,7 @@ pub(crate) mod tests { target_value: NonNegativeAmount, anchor_height: BlockHeight, exclude: &[ReceivedNoteId], - ) -> Result>, SqliteClientError> { + ) -> Result>, SqliteClientError> { select_spendable_sapling_notes( &st.wallet().conn, &st.wallet().params,