Skip to content

Commit

Permalink
zcash_client_backend: Improve API ergonomics for input selection.
Browse files Browse the repository at this point in the history
  • Loading branch information
nuttycom committed Mar 14, 2024
1 parent 22f3418 commit 0bae47b
Show file tree
Hide file tree
Showing 9 changed files with 159 additions and 113 deletions.
6 changes: 4 additions & 2 deletions zcash_client_backend/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ and this library adheres to Rust's notion of
- `WalletOrchardOutput`
- `WalletTx::{orchard_spends, orchard_outputs}`
- `ReceivedNote::map_note`
- `ReceivedNote<_, sapling::Note>::note_value`
- `ReceivedNote<_, orchard::note::Note>::note_value`

### Changed
- `zcash_client_backend::data_api`:
Expand All @@ -64,8 +66,8 @@ and this library adheres to Rust's notion of
- Added `get_orchard_nullifiers` method.
- Changes to the `InputSource` trait:
- `select_spendable_notes` now takes its `target_value` argument as a
`NonNegativeAmount`. Also, the values of the returned map are also
`NonNegativeAmount`s instead of `Amount`s.
`NonNegativeAmount`. Also, it now returns a `SpendableNotes` data
structure instead of a vector.
- Fields of `DecryptedTransaction` are now private. Use `DecryptedTransaction::new`
and the newly provided accessors instead.
- Fields of `SentTransaction` are now private. Use `SentTransaction::new`
Expand Down
64 changes: 59 additions & 5 deletions zcash_client_backend/src/data_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,23 @@ pub trait NoteRetention<NoteRef> {
fn should_retain_orchard(&self, note: &ReceivedNote<NoteRef, orchard::note::Note>) -> bool;
}

pub(crate) struct SimpleNoteRetention {
pub(crate) sapling: bool,
#[cfg(feature = "orchard")]
pub(crate) orchard: bool,
}

impl<NoteRef> NoteRetention<NoteRef> for SimpleNoteRetention {
fn should_retain_sapling(&self, _: &ReceivedNote<NoteRef, sapling::Note>) -> bool {
self.sapling
}

#[cfg(feature = "orchard")]
fn should_retain_orchard(&self, _: &ReceivedNote<NoteRef, orchard::note::Note>) -> bool {
self.orchard
}
}

/// Spendable shielded outputs controlled by the wallet.
pub struct SpendableNotes<NoteRef> {
sapling: Vec<ReceivedNote<NoteRef, sapling::Note>>,
Expand All @@ -502,6 +519,15 @@ pub struct SpendableNotes<NoteRef> {
}

impl<NoteRef> SpendableNotes<NoteRef> {
/// Construct a new empty [`SpendableNotes`].
pub fn empty() -> Self {
Self::new(
vec![],
#[cfg(feature = "orchard")]
vec![],
)
}

/// Construct a new [`SpendableNotes`] from its constituent parts.
pub fn new(
sapling: Vec<ReceivedNote<NoteRef, sapling::Note>>,
Expand All @@ -525,6 +551,34 @@ impl<NoteRef> SpendableNotes<NoteRef> {
self.orchard.as_ref()
}

/// Computes the total value of Sapling notes.
pub fn sapling_value(&self) -> Result<NonNegativeAmount, BalanceError> {
self.sapling
.iter()
.try_fold(NonNegativeAmount::ZERO, |acc, n| {
(acc + n.note_value()?).ok_or(BalanceError::Overflow)
})
}

/// Computes the total value of Sapling notes.
#[cfg(feature = "orchard")]
pub fn orchard_value(&self) -> Result<NonNegativeAmount, BalanceError> {
self.orchard
.iter()
.try_fold(NonNegativeAmount::ZERO, |acc, n| {
(acc + n.note_value()?).ok_or(BalanceError::Overflow)
})
}

/// Computes the total value of spendable inputs
pub fn total_value(&self) -> Result<NonNegativeAmount, BalanceError> {
#[cfg(not(feature = "orchard"))]
return self.sapling_value();

#[cfg(feature = "orchard")]
return (self.sapling_value()? + self.orchard_value()?).ok_or(BalanceError::Overflow);
}

/// Consumes this [`SpendableNotes`] value and produces a vector of
/// [`ReceivedNote<NoteRef, Note>`] values.
pub fn into_vec(
Expand Down Expand Up @@ -589,7 +643,7 @@ pub trait InputSource {
sources: &[ShieldedProtocol],
anchor_height: BlockHeight,
exclude: &[Self::NoteRef],
) -> Result<Vec<ReceivedNote<Self::NoteRef, Note>>, Self::Error>;
) -> Result<SpendableNotes<Self::NoteRef>, Self::Error>;

/// Fetches a spendable transparent output.
///
Expand Down Expand Up @@ -1552,8 +1606,8 @@ pub mod testing {
chain::{ChainState, CommitmentTreeRoot},
scanning::ScanRange,
AccountBirthday, BlockMetadata, DecryptedTransaction, InputSource, NullifierQuery,
ScannedBlock, SentTransaction, WalletCommitmentTrees, WalletRead, WalletSummary,
WalletWrite, SAPLING_SHARD_HEIGHT,
ScannedBlock, SentTransaction, SpendableNotes, WalletCommitmentTrees, WalletRead,
WalletSummary, WalletWrite, SAPLING_SHARD_HEIGHT,
};

#[cfg(feature = "transparent-inputs")]
Expand Down Expand Up @@ -1609,8 +1663,8 @@ pub mod testing {
_sources: &[ShieldedProtocol],
_anchor_height: BlockHeight,
_exclude: &[Self::NoteRef],
) -> Result<Vec<ReceivedNote<Self::NoteRef, Note>>, Self::Error> {
Ok(Vec::new())
) -> Result<SpendableNotes<Self::NoteRef>, Self::Error> {
Ok(SpendableNotes::empty())
}
}

Expand Down
92 changes: 36 additions & 56 deletions zcash_client_backend/src/data_api/wallet/input_selection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ use zcash_primitives::{

use crate::{
address::{Address, UnifiedAddress},
data_api::InputSource,
data_api::{InputSource, SimpleNoteRetention, SpendableNotes},
fees::{sapling, ChangeError, ChangeStrategy, DustOutputPolicy},
proposal::{Proposal, ProposalError, ShieldedInputs},
wallet::{Note, ReceivedNote, WalletTransparentOutput},
wallet::WalletTransparentOutput,
zip321::TransactionRequest,
PoolType, ShieldedProtocol,
};
Expand Down Expand Up @@ -386,7 +386,7 @@ where
}
}

let mut shielded_inputs: Vec<ReceivedNote<DbT::NoteRef, Note>> = vec![];
let mut shielded_inputs = SpendableNotes::empty();
let mut prior_available = NonNegativeAmount::ZERO;
let mut amount_required = NonNegativeAmount::ZERO;
let mut exclude: Vec<DbT::NoteRef> = vec![];
Expand All @@ -396,58 +396,39 @@ where
loop {
#[cfg(feature = "orchard")]
let (sapling_input_total, orchard_input_total) = (
shielded_inputs
.iter()
.filter(|i| matches!(i.note(), Note::Sapling(_)))
.map(|i| i.note().value())
.sum::<Option<NonNegativeAmount>>()
.ok_or(BalanceError::Overflow)?,
shielded_inputs
.iter()
.filter(|i| matches!(i.note(), Note::Orchard(_)))
.map(|i| i.note().value())
.sum::<Option<NonNegativeAmount>>()
.ok_or(BalanceError::Overflow)?,
shielded_inputs.sapling_value()?,
shielded_inputs.orchard_value()?,
);

#[cfg(not(feature = "orchard"))]
let orchard_input_total = NonNegativeAmount::ZERO;

let sapling_inputs =
if sapling_outputs.is_empty() && orchard_input_total >= amount_required {
// Avoid selecting Sapling inputs if we don't have Sapling outputs and the value is
// fully covered by Orchard inputs.
#[cfg(feature = "orchard")]
shielded_inputs.retain(|i| matches!(i.note(), Note::Orchard(_)));
vec![]
} else {
#[allow(clippy::unnecessary_filter_map)]
shielded_inputs
.iter()
.filter_map(|i| match i.note() {
Note::Sapling(n) => Some((*i.internal_note_id(), n.value())),
#[cfg(feature = "orchard")]
Note::Orchard(_) => None,
})
.collect::<Vec<_>>()
};
let use_sapling =
!(sapling_outputs.is_empty() && orchard_input_total >= amount_required);
#[cfg(feature = "orchard")]
let use_orchard =
!(orchard_outputs.is_empty() && sapling_input_total >= amount_required);

let sapling_inputs = if use_sapling {
shielded_inputs
.sapling()
.iter()
.map(|i| (*i.internal_note_id(), i.note().value()))
.collect()
} else {
vec![]
};

#[cfg(feature = "orchard")]
let orchard_inputs =
if orchard_outputs.is_empty() && sapling_input_total >= amount_required {
// Avoid selecting Orchard inputs if we don't have Orchard outputs and the value is
// fully covered by Sapling inputs.
shielded_inputs.retain(|i| matches!(i.note(), Note::Sapling(_)));
vec![]
} else {
shielded_inputs
.iter()
.filter_map(|i| match i.note() {
Note::Sapling(_) => None,
Note::Orchard(n) => Some((*i.internal_note_id(), n.value())),
})
.collect::<Vec<_>>()
};
let orchard_inputs = if use_orchard {
shielded_inputs
.orchard()
.iter()
.map(|i| (*i.internal_note_id(), i.note().value()))
.collect()
} else {
vec![]
};

let balance = self.change_strategy.compute_balance(
params,
Expand All @@ -474,8 +455,12 @@ where
transaction_request,
payment_pools,
vec![],
NonEmpty::from_vec(shielded_inputs)
.map(|notes| ShieldedInputs::from_parts(anchor_height, notes)),
NonEmpty::from_vec(shielded_inputs.into_vec(&SimpleNoteRetention {
sapling: use_sapling,
#[cfg(feature = "orchard")]
orchard: use_orchard,
}))
.map(|notes| ShieldedInputs::from_parts(anchor_height, notes)),
balance,
(*self.change_strategy.fee_rule()).clone(),
target_height,
Expand Down Expand Up @@ -514,12 +499,7 @@ where
)
.map_err(InputSelectorError::DataSource)?;

let new_available = shielded_inputs
.iter()
.map(|n| n.note().value())
.sum::<Option<NonNegativeAmount>>()
.ok_or(BalanceError::Overflow)?;

let new_available = shielded_inputs.total_value()?;
if new_available <= prior_available {
return Err(InputSelectorError::InsufficientFunds {
required: amount_required,
Expand Down
14 changes: 14 additions & 0 deletions zcash_client_backend/src/wallet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use zcash_primitives::{
},
zip32::Scope,
};
use zcash_protocol::value::BalanceError;

use crate::{address::UnifiedAddress, fees::sapling as sapling_fees, PoolType, ShieldedProtocol};

Expand Down Expand Up @@ -447,6 +448,19 @@ impl<NoteRef, NoteT> ReceivedNote<NoteRef, NoteT> {
}
}

impl<NoteRef> ReceivedNote<NoteRef, sapling::Note> {
pub fn note_value(&self) -> Result<NonNegativeAmount, BalanceError> {
self.note.value().inner().try_into()
}
}

#[cfg(feature = "orchard")]
impl<NoteRef> ReceivedNote<NoteRef, orchard::note::Note> {
pub fn note_value(&self) -> Result<NonNegativeAmount, BalanceError> {
self.note.value().inner().try_into()
}
}

impl<NoteRef> sapling_fees::InputView<NoteRef> for (NoteRef, sapling::value::NoteValue) {
fn note_id(&self) -> &NoteRef {
&self.0
Expand Down
50 changes: 26 additions & 24 deletions zcash_client_sqlite/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ use zcash_client_backend::{
chain::{BlockSource, ChainState, CommitmentTreeRoot},
scanning::{ScanPriority, ScanRange},
Account, AccountBirthday, AccountKind, BlockMetadata, DecryptedTransaction, InputSource,
NullifierQuery, ScannedBlock, SentTransaction, WalletCommitmentTrees, WalletRead,
WalletSummary, WalletWrite, SAPLING_SHARD_HEIGHT,
NullifierQuery, ScannedBlock, SentTransaction, SpendableNotes, WalletCommitmentTrees,
WalletRead, WalletSummary, WalletWrite, SAPLING_SHARD_HEIGHT,
},
keys::{
AddressGenerationError, UnifiedAddressRequest, UnifiedFullViewingKey, UnifiedSpendingKey,
Expand Down Expand Up @@ -212,15 +212,17 @@ impl<C: Borrow<rusqlite::Connection>, P: consensus::Parameters> InputSource for
&self.params,
txid,
index,
),
)
.map(|opt| opt.map(|n| n.map_note(Note::Sapling))),
ShieldedProtocol::Orchard => {
#[cfg(feature = "orchard")]
return wallet::orchard::get_spendable_orchard_note(
self.conn.borrow(),
&self.params,
txid,
index,
);
)
.map(|opt| opt.map(|n| n.map_note(Note::Orchard)));

#[cfg(not(feature = "orchard"))]
return Err(SqliteClientError::UnsupportedPoolType(PoolType::Shielded(
Expand All @@ -237,26 +239,26 @@ impl<C: Borrow<rusqlite::Connection>, P: consensus::Parameters> InputSource for
_sources: &[ShieldedProtocol],
anchor_height: BlockHeight,
exclude: &[Self::NoteRef],
) -> Result<Vec<ReceivedNote<Self::NoteRef, Note>>, Self::Error> {
let received_iter = std::iter::empty();
let received_iter = received_iter.chain(wallet::sapling::select_spendable_sapling_notes(
self.conn.borrow(),
&self.params,
account,
target_value,
anchor_height,
exclude,
)?);
#[cfg(feature = "orchard")]
let received_iter = received_iter.chain(wallet::orchard::select_spendable_orchard_notes(
self.conn.borrow(),
&self.params,
account,
target_value,
anchor_height,
exclude,
)?);
Ok(received_iter.collect())
) -> Result<SpendableNotes<Self::NoteRef>, Self::Error> {
Ok(SpendableNotes::new(
wallet::sapling::select_spendable_sapling_notes(
self.conn.borrow(),
&self.params,
account,
target_value,
anchor_height,
exclude,
)?,
#[cfg(feature = "orchard")]
wallet::orchard::select_spendable_orchard_notes(
self.conn.borrow(),
&self.params,
account,
target_value,
anchor_height,
exclude,
)?,
))
}

#[cfg(feature = "transparent-inputs")]
Expand Down
3 changes: 2 additions & 1 deletion zcash_client_sqlite/src/testing/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ pub(crate) trait ShieldedPoolTester {
type Sk;
type Fvk: TestFvk;
type MerkleTreeHash;
type Note;

fn test_account_fvk<Cache>(st: &TestState<Cache>) -> Self::Fvk;
fn usk_to_sk(usk: &UnifiedSpendingKey) -> &Self::Sk;
Expand All @@ -109,7 +110,7 @@ pub(crate) trait ShieldedPoolTester {
target_value: NonNegativeAmount,
anchor_height: BlockHeight,
exclude: &[ReceivedNoteId],
) -> Result<Vec<ReceivedNote<ReceivedNoteId, Note>>, SqliteClientError>;
) -> Result<Vec<ReceivedNote<ReceivedNoteId, Self::Note>>, SqliteClientError>;

fn decrypted_pool_outputs_count(d_tx: &DecryptedTransaction<'_, AccountId>) -> usize;

Expand Down
Loading

0 comments on commit 0bae47b

Please sign in to comment.