diff --git a/psbt/examples/sign.rs b/psbt/examples/sign.rs index 14b375d..966ce68 100644 --- a/psbt/examples/sign.rs +++ b/psbt/examples/sign.rs @@ -22,24 +22,25 @@ fn main() { let file = std::fs::read(file).unwrap(); let xpriv = Xpriv::from_str(&xpriv).unwrap(); - let details = match validate::<_, _, VerboseError<_>, 10>(file.as_slice(), SECP256K1, xpriv) { - Ok(v) => v, - Err(e) => { - match e { - Error::Parse(e) => match e { - nom::Err::Incomplete(_) => println!("unexpected end of file"), - nom::Err::Error(e) | nom::Err::Failure(e) => { - for (i, e) in e.errors.iter().enumerate() { - println!("Error {i}: {e:?}"); + let details = + match validate::<_, _, VerboseError<_>, 10>(file.as_slice(), SECP256K1, xpriv | _ | ()) { + Ok(v) => v, + Err(e) => { + match e { + Error::Parse(e) => match e { + nom::Err::Incomplete(_) => println!("unexpected end of file"), + nom::Err::Error(e) | nom::Err::Failure(e) => { + for (i, e) in e.errors.iter().enumerate() { + println!("Error {i}: {e:?}"); + } } - } - }, - Error::Validation(e) => println!("{e}"), - } + }, + Error::Validation(e) => println!("{e}"), + } - std::process::exit(1); - } - }; + std::process::exit(1); + } + }; println!("Transaction details:"); if details.is_self_send() { diff --git a/psbt/src/validation.rs b/psbt/src/validation.rs index dff42e9..5b33412 100644 --- a/psbt/src/validation.rs +++ b/psbt/src/validation.rs @@ -42,10 +42,17 @@ impl TransactionDetails { } } -pub fn validate( +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ValidationEvent { + /// Validation progress percentage update. + Progress(u64), +} + +pub fn validate( i: Input, secp: &secp256k1::Secp256k1, master_key: Xpriv, + mut handle_event: F, ) -> Result> where Input: for<'a> nom::Compare<&'a [u8]> @@ -58,6 +65,7 @@ where + nom::Slice> + nom::Slice>, C: secp256k1::Signing + secp256k1::Verification, + F: FnMut(ValidationEvent), E: core::fmt::Debug + nom::error::ContextError + nom::error::ParseError @@ -65,6 +73,8 @@ where + nom::error::FromExternalError + nom::error::FromExternalError, { + handle_event(ValidationEvent::Progress(0)); + log::debug!("validating PSBT"); let (i, _) = tag::<_, Input, E>(b"psbt\xff")(i)?; @@ -81,6 +91,9 @@ where let wallet_fingerprint = master_key.fingerprint(secp); log::debug!("wallet fingerprint {:?}", wallet_fingerprint); + let total_items = input_count + output_count; + let mut processed_items = 0; + log::debug!("validating inputs"); let mut input = i.clone(); for _ in 0..input_count { @@ -95,6 +108,11 @@ where Err(Err::Error(e)) => return Err(Err::Error(E::append(i, ErrorKind::Count, e)).into()), Err(e) => return Err(e.into()), } + + processed_items += 1; + handle_event(ValidationEvent::Progress( + (processed_items * 100) / total_items, + )); } log::debug!("validating outputs"); @@ -170,6 +188,11 @@ where Err(Err::Error(e)) => return Err(Err::Error(E::append(i, ErrorKind::Count, e)).into()), Err(e) => return Err(e.into()), }; + + processed_items += 1; + handle_event(ValidationEvent::Progress( + (processed_items * 100) / total_items, + )); } log::debug!("total with total_change: {total_with_change} sats");