Skip to content

Commit

Permalink
Merge pull request #1268 from nuttycom/sqlite_wallet/cross_pool_note_…
Browse files Browse the repository at this point in the history
…selection

zcash_client_backend: Fix note selection & add more multi-pool tests.
  • Loading branch information
str4d authored Mar 14, 2024
2 parents 1c72b0b + a81e7ff commit 2e0a300
Show file tree
Hide file tree
Showing 7 changed files with 244 additions and 54 deletions.
2 changes: 2 additions & 0 deletions zcash_client_backend/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ and this library adheres to Rust's notion of
### Removed
- `zcash_client_backend::PoolType::is_receiver`: use
`zcash_keys::Address::has_receiver` instead.
- `zcash_client_backend::wallet::ReceivedNote::traverse_opt` removed as
unnecessary.

### Fixed
- This release fixes an error in amount parsing in `zip321` that previously
Expand Down
91 changes: 64 additions & 27 deletions zcash_client_backend/src/data_api/wallet/input_selection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,13 +273,13 @@ impl sapling::OutputView for SaplingPayment {
#[cfg(feature = "orchard")]
pub(crate) struct OrchardPayment(NonNegativeAmount);

// TODO: introduce this method when it is needed for testing.
// #[cfg(test)]
// impl OrchardPayment {
// pub(crate) fn new(amount: NonNegativeAmount) -> Self {
// OrchardPayment(amount)
// }
// }
#[cfg(test)]
#[cfg(feature = "orchard")]
impl OrchardPayment {
pub(crate) fn new(amount: NonNegativeAmount) -> Self {
OrchardPayment(amount)
}
}

#[cfg(feature = "orchard")]
impl orchard_fees::OutputView for OrchardPayment {
Expand Down Expand Up @@ -394,38 +394,75 @@ where
// of funds selected is strictly increasing. The loop will either return a successful
// result or the wallet will eventually run out of funds to select.
loop {
#[cfg(feature = "orchard")]
let (sapling_input_total, orchard_input_total) = (
shielded_inputs
.iter()
.filter(|i| matches!(i.note(), Note::Sapling(_)))
.map(|i| i.note().value())
.sum::<Option<NonNegativeAmount>>()
.ok_or(BalanceError::Overflow)?,
shielded_inputs
.iter()
.filter(|i| matches!(i.note(), Note::Orchard(_)))
.map(|i| i.note().value())
.sum::<Option<NonNegativeAmount>>()
.ok_or(BalanceError::Overflow)?,
);

#[cfg(not(feature = "orchard"))]
let orchard_input_total = NonNegativeAmount::ZERO;

let sapling_inputs =
if sapling_outputs.is_empty() && orchard_input_total >= amount_required {
// Avoid selecting Sapling inputs if we don't have Sapling outputs and the value is
// fully covered by Orchard inputs.
#[cfg(feature = "orchard")]
shielded_inputs.retain(|i| matches!(i.note(), Note::Orchard(_)));
vec![]
} else {
#[allow(clippy::unnecessary_filter_map)]
shielded_inputs
.iter()
.filter_map(|i| match i.note() {
Note::Sapling(n) => Some((*i.internal_note_id(), n.value())),
#[cfg(feature = "orchard")]
Note::Orchard(_) => None,
})
.collect::<Vec<_>>()
};

#[cfg(feature = "orchard")]
let orchard_inputs =
if orchard_outputs.is_empty() && sapling_input_total >= amount_required {
// Avoid selecting Orchard inputs if we don't have Orchard outputs and the value is
// fully covered by Sapling inputs.
shielded_inputs.retain(|i| matches!(i.note(), Note::Sapling(_)));
vec![]
} else {
shielded_inputs
.iter()
.filter_map(|i| match i.note() {
Note::Sapling(_) => None,
Note::Orchard(n) => Some((*i.internal_note_id(), n.value())),
})
.collect::<Vec<_>>()
};

let balance = self.change_strategy.compute_balance(
params,
target_height,
&Vec::<WalletTransparentOutput>::new(),
&transparent_outputs,
&(
::sapling::builder::BundleType::DEFAULT,
&shielded_inputs
.iter()
.cloned()
.filter_map(|i| {
i.traverse_opt(|wn| match wn {
Note::Sapling(n) => Some(n),
#[cfg(feature = "orchard")]
_ => None,
})
})
.collect::<Vec<_>>()[..],
&sapling_inputs[..],
&sapling_outputs[..],
),
#[cfg(feature = "orchard")]
&(
::orchard::builder::BundleType::DEFAULT,
&shielded_inputs
.iter()
.filter_map(|i| {
i.clone().traverse_opt(|wn| match wn {
Note::Orchard(n) => Some(n),
_ => None,
})
})
.collect::<Vec<_>>()[..],
&orchard_inputs[..],
&orchard_outputs[..],
),
&self.dust_output_policy,
Expand Down
46 changes: 46 additions & 0 deletions zcash_client_backend/src/fees/zip317.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,9 @@ mod tests {
ShieldedProtocol,
};

#[cfg(feature = "orchard")]
use crate::data_api::wallet::input_selection::OrchardPayment;

#[test]
fn change_without_dust() {
let change_strategy = SingleOutputChangeStrategy::new(
Expand Down Expand Up @@ -280,6 +283,49 @@ mod tests {
);
}

#[test]
#[cfg(feature = "orchard")]
fn cross_pool_change_without_dust() {
let change_strategy = SingleOutputChangeStrategy::new(
Zip317FeeRule::standard(),
None,
ShieldedProtocol::Orchard,
);

// spend a single Sapling note that is sufficient to pay the fee
let result = change_strategy.compute_balance(
&Network::TestNetwork,
Network::TestNetwork
.activation_height(NetworkUpgrade::Nu5)
.unwrap(),
&Vec::<TestTransparentInput>::new(),
&Vec::<TxOut>::new(),
&(
sapling::builder::BundleType::DEFAULT,
&[TestSaplingInput {
note_id: 0,
value: NonNegativeAmount::const_from_u64(55000),
}][..],
&Vec::<Infallible>::new()[..],
),
&(
orchard::builder::BundleType::DEFAULT,
&Vec::<Infallible>::new()[..],
&[OrchardPayment::new(NonNegativeAmount::const_from_u64(
30000,
))][..],
),
&DustOutputPolicy::default(),
);

assert_matches!(
result,
Ok(balance) if
balance.proposed_change() == [ChangeValue::orchard(NonNegativeAmount::const_from_u64(5000), None)] &&
balance.fee_required() == NonNegativeAmount::const_from_u64(20000)
);
}

#[test]
fn change_with_transparent_payments() {
let change_strategy = SingleOutputChangeStrategy::new(
Expand Down
44 changes: 25 additions & 19 deletions zcash_client_backend/src/wallet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -430,26 +430,18 @@ impl<NoteRef, NoteT> ReceivedNote<NoteRef, NoteT> {
pub fn note_commitment_tree_position(&self) -> Position {
self.note_commitment_tree_position
}
}

/// Applies the given function to the `note` field of this ReceivedNote and returns
/// `None` if that function returns `None`, or otherwise a `Some` containing
/// a `ReceivedNote` with its `note` field swapped out for the result of the function.
///
/// The name `traverse` refers to the general operation that has the Haskell type
/// `Applicative f => (a -> f b) -> t a -> f (t b)`, that this method specializes
/// with `ReceivedNote<NoteRef, _>` for `t` and `Option<_>` for `f`.
pub fn traverse_opt<B>(
self,
f: impl FnOnce(NoteT) -> Option<B>,
) -> Option<ReceivedNote<NoteRef, B>> {
f(self.note).map(|n0| ReceivedNote {
note_id: self.note_id,
txid: self.txid,
output_index: self.output_index,
note: n0,
spending_key_scope: self.spending_key_scope,
note_commitment_tree_position: self.note_commitment_tree_position,
})
impl<NoteRef> sapling_fees::InputView<NoteRef> for (NoteRef, sapling::value::NoteValue) {
fn note_id(&self) -> &NoteRef {
&self.0
}

fn value(&self) -> NonNegativeAmount {
self.1
.inner()
.try_into()
.expect("Sapling note values are indirectly checked by consensus.")
}
}

Expand All @@ -467,6 +459,20 @@ impl<NoteRef> sapling_fees::InputView<NoteRef> for ReceivedNote<NoteRef, sapling
}
}

#[cfg(feature = "orchard")]
impl<NoteRef> orchard_fees::InputView<NoteRef> for (NoteRef, orchard::value::NoteValue) {
fn note_id(&self) -> &NoteRef {
&self.0
}

fn value(&self) -> NonNegativeAmount {
self.1
.inner()
.try_into()
.expect("Orchard note values are indirectly checked by consensus.")
}
}

#[cfg(feature = "orchard")]
impl<NoteRef> orchard_fees::InputView<NoteRef> for ReceivedNote<NoteRef, orchard::Note> {
fn note_id(&self) -> &NoteRef {
Expand Down
92 changes: 88 additions & 4 deletions zcash_client_sqlite/src/testing/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1406,7 +1406,7 @@ pub(crate) fn checkpoint_gaps<T: ShieldedPoolTester>() {
}

#[cfg(feature = "orchard")]
pub(crate) fn cross_pool_exchange<P0: ShieldedPoolTester, P1: ShieldedPoolTester>() {
pub(crate) fn pool_crossing_required<P0: ShieldedPoolTester, P1: ShieldedPoolTester>() {
let mut st = TestBuilder::new()
.with_block_cache()
.with_test_account(|params| AccountBirthday::from_activation(params, NetworkUpgrade::Nu5))
Expand All @@ -1419,12 +1419,11 @@ pub(crate) fn cross_pool_exchange<P0: ShieldedPoolTester, P1: ShieldedPoolTester
let p1_fvk = P1::test_account_fvk(&st);
let p1_to = P1::fvk_default_address(&p1_fvk);

let note_value = NonNegativeAmount::const_from_u64(300000);
let note_value = NonNegativeAmount::const_from_u64(350000);
st.generate_next_block(&p0_fvk, AddressType::DefaultExternal, note_value);
st.generate_next_block(&p1_fvk, AddressType::DefaultExternal, note_value);
st.scan_cached_blocks(birthday.height(), 2);

let initial_balance = (note_value * 2).unwrap();
let initial_balance = note_value;
assert_eq!(st.get_total_balance(account), initial_balance);
assert_eq!(st.get_spendable_balance(account, 1), initial_balance);

Expand Down Expand Up @@ -1489,6 +1488,91 @@ pub(crate) fn cross_pool_exchange<P0: ShieldedPoolTester, P1: ShieldedPoolTester
);
}

#[cfg(feature = "orchard")]
pub(crate) fn fully_funded_fully_private<P0: ShieldedPoolTester, P1: ShieldedPoolTester>() {
let mut st = TestBuilder::new()
.with_block_cache()
.with_test_account(|params| AccountBirthday::from_activation(params, NetworkUpgrade::Nu5))
.build();

let (account, usk, birthday) = st.test_account().unwrap();

let p0_fvk = P0::test_account_fvk(&st);

let p1_fvk = P1::test_account_fvk(&st);
let p1_to = P1::fvk_default_address(&p1_fvk);

let note_value = NonNegativeAmount::const_from_u64(350000);
st.generate_next_block(&p0_fvk, AddressType::DefaultExternal, note_value);
st.generate_next_block(&p1_fvk, AddressType::DefaultExternal, note_value);
st.scan_cached_blocks(birthday.height(), 2);

let initial_balance = (note_value * 2).unwrap();
assert_eq!(st.get_total_balance(account), initial_balance);
assert_eq!(st.get_spendable_balance(account, 1), initial_balance);

let transfer_amount = NonNegativeAmount::const_from_u64(200000);
let p0_to_p1 = zip321::TransactionRequest::new(vec![Payment {
recipient_address: p1_to,
amount: transfer_amount,
memo: None,
label: None,
message: None,
other_params: vec![],
}])
.unwrap();

let fee_rule = StandardFeeRule::Zip317;
let input_selector = GreedyInputSelector::new(
// We set the default change output pool to P0, because we want to verify later that
// change is actually sent to P1 (as the transaction is fully fundable from P1).
standard::SingleOutputChangeStrategy::new(fee_rule, None, P0::SHIELDED_PROTOCOL),
DustOutputPolicy::default(),
);
let proposal0 = st
.propose_transfer(
account,
&input_selector,
p0_to_p1,
NonZeroU32::new(1).unwrap(),
)
.unwrap();

let _min_target_height = proposal0.min_target_height();
assert_eq!(proposal0.steps().len(), 1);
let step0 = &proposal0.steps().head;

// We expect 2 logical actions, since either pool can pay the full balance required
// and note selection should choose the fully-private path.
let expected_fee = NonNegativeAmount::const_from_u64(10000);
assert_eq!(step0.balance().fee_required(), expected_fee);

let expected_change = (note_value - transfer_amount - expected_fee).unwrap();
let proposed_change = step0.balance().proposed_change();
assert_eq!(proposed_change.len(), 1);
let change_output = proposed_change.get(0).unwrap();
// Since there are sufficient funds in either pool, change is kept in the same pool as
// the source note (the target pool), and does not necessarily follow preference order.
assert_eq!(change_output.output_pool(), P1::SHIELDED_PROTOCOL);
assert_eq!(change_output.value(), expected_change);

let create_proposed_result =
st.create_proposed_transactions::<Infallible, _>(&usk, OvkPolicy::Sender, &proposal0);
assert_matches!(&create_proposed_result, Ok(txids) if txids.len() == 1);

let (h, _) = st.generate_next_block_including(create_proposed_result.unwrap()[0]);
st.scan_cached_blocks(h, 1);

assert_eq!(
st.get_total_balance(account),
(initial_balance - expected_fee).unwrap()
);
assert_eq!(
st.get_spendable_balance(account, 1),
(initial_balance - expected_fee).unwrap()
);
}

pub(crate) fn valid_chain_states<T: ShieldedPoolTester>() {
let mut st = TestBuilder::new()
.with_block_cache()
Expand Down
11 changes: 9 additions & 2 deletions zcash_client_sqlite/src/wallet/orchard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -604,9 +604,16 @@ pub(crate) mod tests {
}

#[test]
fn cross_pool_exchange() {
fn pool_crossing_required() {
use crate::wallet::sapling::tests::SaplingPoolTester;

testing::pool::cross_pool_exchange::<OrchardPoolTester, SaplingPoolTester>()
testing::pool::pool_crossing_required::<OrchardPoolTester, SaplingPoolTester>()
}

#[test]
fn fully_funded_fully_private() {
use crate::wallet::sapling::tests::SaplingPoolTester;

testing::pool::fully_funded_fully_private::<OrchardPoolTester, SaplingPoolTester>()
}
}
12 changes: 10 additions & 2 deletions zcash_client_sqlite/src/wallet/sapling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -607,9 +607,17 @@ pub(crate) mod tests {

#[test]
#[cfg(feature = "orchard")]
fn cross_pool_exchange() {
fn pool_crossing_required() {
use crate::wallet::orchard::tests::OrchardPoolTester;

testing::pool::cross_pool_exchange::<SaplingPoolTester, OrchardPoolTester>()
testing::pool::pool_crossing_required::<SaplingPoolTester, OrchardPoolTester>()
}

#[test]
#[cfg(feature = "orchard")]
fn fully_funded_fully_private() {
use crate::wallet::orchard::tests::OrchardPoolTester;

testing::pool::fully_funded_fully_private::<SaplingPoolTester, OrchardPoolTester>()
}
}

0 comments on commit 2e0a300

Please sign in to comment.