Skip to content

Commit

Permalink
Merge pull request zcash#1271 from nuttycom/sqlite_wallet/cross_pool_…
Browse files Browse the repository at this point in the history
…selection_refactor

zcash_client_backend: Pure refactoring for note selection ergonomics.
  • Loading branch information
nuttycom authored Mar 14, 2024
2 parents 2e0a300 + 0bae47b commit be312f8
Show file tree
Hide file tree
Showing 9 changed files with 241 additions and 113 deletions.
9 changes: 7 additions & 2 deletions zcash_client_backend/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ and this library adheres to Rust's notion of
- `AccountKind`
- `BlockMetadata::orchard_tree_size`
- `DecryptedTransaction::{new, tx(), orchard_outputs()}`
- `NoteRetention`
- `ScannedBlock::orchard`
- `ScannedBlockCommitments::orchard`
- `SentTransaction::new`
- `SpendableNotes`
- `ORCHARD_SHARD_HEIGHT`
- `BlockMetadata::orchard_tree_size`
- `WalletSummary::next_orchard_subtree_index`
Expand All @@ -46,6 +48,9 @@ and this library adheres to Rust's notion of
- `WalletOrchardSpend`
- `WalletOrchardOutput`
- `WalletTx::{orchard_spends, orchard_outputs}`
- `ReceivedNote::map_note`
- `ReceivedNote<_, sapling::Note>::note_value`
- `ReceivedNote<_, orchard::note::Note>::note_value`

### Changed
- `zcash_client_backend::data_api`:
Expand All @@ -61,8 +66,8 @@ and this library adheres to Rust's notion of
- Added `get_orchard_nullifiers` method.
- Changes to the `InputSource` trait:
- `select_spendable_notes` now takes its `target_value` argument as a
`NonNegativeAmount`. Also, the values of the returned map are also
`NonNegativeAmount`s instead of `Amount`s.
`NonNegativeAmount`. Also, it now returns a `SpendableNotes` data
structure instead of a vector.
- Fields of `DecryptedTransaction` are now private. Use `DecryptedTransaction::new`
and the newly provided accessors instead.
- Fields of `SentTransaction` are now private. Use `SentTransaction::new`
Expand Down
128 changes: 123 additions & 5 deletions zcash_client_backend/src/data_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,124 @@ impl<AccountId: Eq + Hash> WalletSummary<AccountId> {
}
}

/// A predicate that can be used to choose whether or not a particular note is retained in note
/// selection.
pub trait NoteRetention<NoteRef> {
/// Returns whether the specified Sapling note should be retained.
fn should_retain_sapling(&self, note: &ReceivedNote<NoteRef, sapling::Note>) -> bool;
/// Returns whether the specified Orchard note should be retained.
#[cfg(feature = "orchard")]
fn should_retain_orchard(&self, note: &ReceivedNote<NoteRef, orchard::note::Note>) -> bool;
}

pub(crate) struct SimpleNoteRetention {
pub(crate) sapling: bool,
#[cfg(feature = "orchard")]
pub(crate) orchard: bool,
}

impl<NoteRef> NoteRetention<NoteRef> for SimpleNoteRetention {
fn should_retain_sapling(&self, _: &ReceivedNote<NoteRef, sapling::Note>) -> bool {
self.sapling
}

#[cfg(feature = "orchard")]
fn should_retain_orchard(&self, _: &ReceivedNote<NoteRef, orchard::note::Note>) -> bool {
self.orchard
}
}

/// Spendable shielded outputs controlled by the wallet.
pub struct SpendableNotes<NoteRef> {
sapling: Vec<ReceivedNote<NoteRef, sapling::Note>>,
#[cfg(feature = "orchard")]
orchard: Vec<ReceivedNote<NoteRef, orchard::note::Note>>,
}

impl<NoteRef> SpendableNotes<NoteRef> {
/// Construct a new empty [`SpendableNotes`].
pub fn empty() -> Self {
Self::new(
vec![],
#[cfg(feature = "orchard")]
vec![],
)
}

/// Construct a new [`SpendableNotes`] from its constituent parts.
pub fn new(
sapling: Vec<ReceivedNote<NoteRef, sapling::Note>>,
#[cfg(feature = "orchard")] orchard: Vec<ReceivedNote<NoteRef, orchard::note::Note>>,
) -> Self {
Self {
sapling,
#[cfg(feature = "orchard")]
orchard,
}
}

/// Returns the set of spendable Sapling notes.
pub fn sapling(&self) -> &[ReceivedNote<NoteRef, sapling::Note>] {
self.sapling.as_ref()
}

/// Returns the set of spendable Orchard notes.
#[cfg(feature = "orchard")]
pub fn orchard(&self) -> &[ReceivedNote<NoteRef, orchard::note::Note>] {
self.orchard.as_ref()
}

/// Computes the total value of Sapling notes.
pub fn sapling_value(&self) -> Result<NonNegativeAmount, BalanceError> {
self.sapling
.iter()
.try_fold(NonNegativeAmount::ZERO, |acc, n| {
(acc + n.note_value()?).ok_or(BalanceError::Overflow)
})
}

/// Computes the total value of Sapling notes.
#[cfg(feature = "orchard")]
pub fn orchard_value(&self) -> Result<NonNegativeAmount, BalanceError> {
self.orchard
.iter()
.try_fold(NonNegativeAmount::ZERO, |acc, n| {
(acc + n.note_value()?).ok_or(BalanceError::Overflow)
})
}

/// Computes the total value of spendable inputs
pub fn total_value(&self) -> Result<NonNegativeAmount, BalanceError> {
#[cfg(not(feature = "orchard"))]
return self.sapling_value();

#[cfg(feature = "orchard")]
return (self.sapling_value()? + self.orchard_value()?).ok_or(BalanceError::Overflow);
}

/// Consumes this [`SpendableNotes`] value and produces a vector of
/// [`ReceivedNote<NoteRef, Note>`] values.
pub fn into_vec(
self,
retention: &impl NoteRetention<NoteRef>,
) -> Vec<ReceivedNote<NoteRef, Note>> {
let iter = self.sapling.into_iter().filter_map(|n| {
retention
.should_retain_sapling(&n)
.then(|| n.map_note(Note::Sapling))
});

#[cfg(feature = "orchard")]
let iter = iter.chain(self.orchard.into_iter().filter_map(|n| {
retention
.should_retain_orchard(&n)
.then(|| n.map_note(Note::Orchard))
}));

iter.collect()
}
}

/// A trait representing the capability to query a data store for unspent transaction outputs
/// belonging to a wallet.
pub trait InputSource {
Expand Down Expand Up @@ -525,7 +643,7 @@ pub trait InputSource {
sources: &[ShieldedProtocol],
anchor_height: BlockHeight,
exclude: &[Self::NoteRef],
) -> Result<Vec<ReceivedNote<Self::NoteRef, Note>>, Self::Error>;
) -> Result<SpendableNotes<Self::NoteRef>, Self::Error>;

/// Fetches a spendable transparent output.
///
Expand Down Expand Up @@ -1488,8 +1606,8 @@ pub mod testing {
chain::{ChainState, CommitmentTreeRoot},
scanning::ScanRange,
AccountBirthday, BlockMetadata, DecryptedTransaction, InputSource, NullifierQuery,
ScannedBlock, SentTransaction, WalletCommitmentTrees, WalletRead, WalletSummary,
WalletWrite, SAPLING_SHARD_HEIGHT,
ScannedBlock, SentTransaction, SpendableNotes, WalletCommitmentTrees, WalletRead,
WalletSummary, WalletWrite, SAPLING_SHARD_HEIGHT,
};

#[cfg(feature = "transparent-inputs")]
Expand Down Expand Up @@ -1545,8 +1663,8 @@ pub mod testing {
_sources: &[ShieldedProtocol],
_anchor_height: BlockHeight,
_exclude: &[Self::NoteRef],
) -> Result<Vec<ReceivedNote<Self::NoteRef, Note>>, Self::Error> {
Ok(Vec::new())
) -> Result<SpendableNotes<Self::NoteRef>, Self::Error> {
Ok(SpendableNotes::empty())
}
}

Expand Down
92 changes: 36 additions & 56 deletions zcash_client_backend/src/data_api/wallet/input_selection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ use zcash_primitives::{

use crate::{
address::{Address, UnifiedAddress},
data_api::InputSource,
data_api::{InputSource, SimpleNoteRetention, SpendableNotes},
fees::{sapling, ChangeError, ChangeStrategy, DustOutputPolicy},
proposal::{Proposal, ProposalError, ShieldedInputs},
wallet::{Note, ReceivedNote, WalletTransparentOutput},
wallet::WalletTransparentOutput,
zip321::TransactionRequest,
PoolType, ShieldedProtocol,
};
Expand Down Expand Up @@ -386,7 +386,7 @@ where
}
}

let mut shielded_inputs: Vec<ReceivedNote<DbT::NoteRef, Note>> = vec![];
let mut shielded_inputs = SpendableNotes::empty();
let mut prior_available = NonNegativeAmount::ZERO;
let mut amount_required = NonNegativeAmount::ZERO;
let mut exclude: Vec<DbT::NoteRef> = vec![];
Expand All @@ -396,58 +396,39 @@ where
loop {
#[cfg(feature = "orchard")]
let (sapling_input_total, orchard_input_total) = (
shielded_inputs
.iter()
.filter(|i| matches!(i.note(), Note::Sapling(_)))
.map(|i| i.note().value())
.sum::<Option<NonNegativeAmount>>()
.ok_or(BalanceError::Overflow)?,
shielded_inputs
.iter()
.filter(|i| matches!(i.note(), Note::Orchard(_)))
.map(|i| i.note().value())
.sum::<Option<NonNegativeAmount>>()
.ok_or(BalanceError::Overflow)?,
shielded_inputs.sapling_value()?,
shielded_inputs.orchard_value()?,
);

#[cfg(not(feature = "orchard"))]
let orchard_input_total = NonNegativeAmount::ZERO;

let sapling_inputs =
if sapling_outputs.is_empty() && orchard_input_total >= amount_required {
// Avoid selecting Sapling inputs if we don't have Sapling outputs and the value is
// fully covered by Orchard inputs.
#[cfg(feature = "orchard")]
shielded_inputs.retain(|i| matches!(i.note(), Note::Orchard(_)));
vec![]
} else {
#[allow(clippy::unnecessary_filter_map)]
shielded_inputs
.iter()
.filter_map(|i| match i.note() {
Note::Sapling(n) => Some((*i.internal_note_id(), n.value())),
#[cfg(feature = "orchard")]
Note::Orchard(_) => None,
})
.collect::<Vec<_>>()
};
let use_sapling =
!(sapling_outputs.is_empty() && orchard_input_total >= amount_required);
#[cfg(feature = "orchard")]
let use_orchard =
!(orchard_outputs.is_empty() && sapling_input_total >= amount_required);

let sapling_inputs = if use_sapling {
shielded_inputs
.sapling()
.iter()
.map(|i| (*i.internal_note_id(), i.note().value()))
.collect()
} else {
vec![]
};

#[cfg(feature = "orchard")]
let orchard_inputs =
if orchard_outputs.is_empty() && sapling_input_total >= amount_required {
// Avoid selecting Orchard inputs if we don't have Orchard outputs and the value is
// fully covered by Sapling inputs.
shielded_inputs.retain(|i| matches!(i.note(), Note::Sapling(_)));
vec![]
} else {
shielded_inputs
.iter()
.filter_map(|i| match i.note() {
Note::Sapling(_) => None,
Note::Orchard(n) => Some((*i.internal_note_id(), n.value())),
})
.collect::<Vec<_>>()
};
let orchard_inputs = if use_orchard {
shielded_inputs
.orchard()
.iter()
.map(|i| (*i.internal_note_id(), i.note().value()))
.collect()
} else {
vec![]
};

let balance = self.change_strategy.compute_balance(
params,
Expand All @@ -474,8 +455,12 @@ where
transaction_request,
payment_pools,
vec![],
NonEmpty::from_vec(shielded_inputs)
.map(|notes| ShieldedInputs::from_parts(anchor_height, notes)),
NonEmpty::from_vec(shielded_inputs.into_vec(&SimpleNoteRetention {
sapling: use_sapling,
#[cfg(feature = "orchard")]
orchard: use_orchard,
}))
.map(|notes| ShieldedInputs::from_parts(anchor_height, notes)),
balance,
(*self.change_strategy.fee_rule()).clone(),
target_height,
Expand Down Expand Up @@ -514,12 +499,7 @@ where
)
.map_err(InputSelectorError::DataSource)?;

let new_available = shielded_inputs
.iter()
.map(|n| n.note().value())
.sum::<Option<NonNegativeAmount>>()
.ok_or(BalanceError::Overflow)?;

let new_available = shielded_inputs.total_value()?;
if new_available <= prior_available {
return Err(InputSelectorError::InsufficientFunds {
required: amount_required,
Expand Down
29 changes: 29 additions & 0 deletions zcash_client_backend/src/wallet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use zcash_primitives::{
},
zip32::Scope,
};
use zcash_protocol::value::BalanceError;

use crate::{address::UnifiedAddress, fees::sapling as sapling_fees, PoolType, ShieldedProtocol};

Expand Down Expand Up @@ -430,6 +431,34 @@ impl<NoteRef, NoteT> ReceivedNote<NoteRef, NoteT> {
pub fn note_commitment_tree_position(&self) -> Position {
self.note_commitment_tree_position
}

/// Map over the `note` field of this data structure.
///
/// Consume this value, applying the provided function to the value of its `note` field and
/// returning a new `ReceivedNote` with the result as its `note` field value.
pub fn map_note<N, F: Fn(NoteT) -> N>(self, f: F) -> ReceivedNote<NoteRef, N> {
ReceivedNote {
note_id: self.note_id,
txid: self.txid,
output_index: self.output_index,
note: f(self.note),
spending_key_scope: self.spending_key_scope,
note_commitment_tree_position: self.note_commitment_tree_position,
}
}
}

impl<NoteRef> ReceivedNote<NoteRef, sapling::Note> {
pub fn note_value(&self) -> Result<NonNegativeAmount, BalanceError> {
self.note.value().inner().try_into()
}
}

#[cfg(feature = "orchard")]
impl<NoteRef> ReceivedNote<NoteRef, orchard::note::Note> {
pub fn note_value(&self) -> Result<NonNegativeAmount, BalanceError> {
self.note.value().inner().try_into()
}
}

impl<NoteRef> sapling_fees::InputView<NoteRef> for (NoteRef, sapling::value::NoteValue) {
Expand Down
Loading

0 comments on commit be312f8

Please sign in to comment.