Skip to content

Commit

Permalink
[Draft] Persist send requestcontext
Browse files Browse the repository at this point in the history
  • Loading branch information
DanGould committed Dec 11, 2023
1 parent dc5240d commit 7fab64f
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 0 deletions.
57 changes: 57 additions & 0 deletions payjoin/src/input_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,63 @@ pub(crate) enum InputType {
Taproot,
}

#[cfg(feature = "v2")]
impl serde::Serialize for InputType {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
use InputType::*;

match self {
P2Pk => serializer.serialize_str("P2PK"),
P2Pkh => serializer.serialize_str("P2PKH"),
P2Sh => serializer.serialize_str("P2SH"),
SegWitV0 { ty, nested } =>
serializer.serialize_str(&format!("SegWitV0: type={}, nested={}", ty, nested)),
Taproot => serializer.serialize_str("Taproot"),
}
}
}

#[cfg(feature = "v2")]
impl serde::Deserialize for InputType {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
use InputType::*;

let s = String::deserialize(deserializer)?;
let mut split = s.split(':');
let ty = split.next().ok_or_else(|| serde::de::Error::custom("missing type"))?;
let ty = match ty {
"P2PK" => P2Pk,
"P2PKH" => P2Pkh,
"P2SH" => P2Sh,
"SegWitV0" => {
let ty = split
.next()
.ok_or_else(|| serde::de::Error::custom("missing SegWitV0 type"))?;
let ty = match ty {
"pubkey" => SegWitV0Type::Pubkey,
"script" => SegWitV0Type::Script,
_ => return Err(serde::de::Error::custom("invalid SegWitV0 type")),
};
let nested = split
.next()
.ok_or_else(|| serde::de::Error::custom("missing SegWitV0 nested"))?;
let nested = match nested {
"true" => true,
"false" => false,
_ => return Err(serde::de::Error::custom("invalid SegWitV0 nested")),
};
SegWitV0 { ty, nested }
}
"Taproot" => Taproot,
_ => return Err(serde::de::Error::custom("invalid type")),
};
if split.next().is_some() {
return Err(serde::de::Error::custom("unexpected trailing data"));
}
Ok(ty)
}
}

impl InputType {
pub(crate) fn from_spent_input(
txout: &TxOut,
Expand Down
122 changes: 122 additions & 0 deletions payjoin/src/send/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ use bitcoin::psbt::Psbt;
use bitcoin::{FeeRate, Script, ScriptBuf, Sequence, TxOut, Weight};
pub use error::{CreateRequestError, ValidationError};
pub(crate) use error::{InternalCreateRequestError, InternalValidationError};
use serde::ser::SerializeStruct;
use url::Url;

use crate::input_type::InputType;
Expand Down Expand Up @@ -467,12 +468,133 @@ pub struct ContextV1 {
payee: ScriptBuf,
}

#[cfg(feature = "v2")]
impl serde::Serialize for ContextV1 {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut state = serializer.serialize_struct("ContextV1", 7)?; // Adjust the number of fields accordingly
state.serialize_field("original_psbt", &self.original_psbt)?;
state.serialize_field("disable_output_substitution", &self.disable_output_substitution)?;
match &self.fee_contribution {
Some(fee_contribution) => {
// Serialize the tuple as a struct for clarity
let mut fee_contribution_state = serializer.serialize_struct("FeeContribution", 2)?;
fee_contribution_state.serialize_field("amount", &fee_contribution.0.to_sat())?;
fee_contribution_state.serialize_field("index", &fee_contribution.1)?;
fee_contribution_state.end()?;
}
None => {
// Serialize None as a null value or omit it, based on your requirements
state.serialize_field("fee_contribution", &None::<&str>)?;
}
}
state.serialize_field("min_fee_rate", &self.min_fee_rate)?;
state.serialize_field("input_type", &self.input_type)?;
state.serialize_field("sequence", &self.sequence)?;
state.serialize_field("payee", &self.payee)?;
state.end()
}
}

#[cfg(feature = "v2")]
impl<'de> serde::Deserialize<'de> for ContextV1 {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
const FIELDS: &'static [&'static str] = &[
"original_psbt",
"disable_output_substitution",
"fee_contribution",
"min_fee_rate",
"input_type",
"sequence",
"payee",
];

struct ContextV1Visitor;

impl<'de> Visitor<'de> for ContextV1Visitor {
type Value = ContextV1;

fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("struct ContextV1")
}

fn visit_map<V>(self, mut map: V) -> Result<ContextV1, V::Error>
where
V: MapAccess<'de>,
{
let mut original_psbt = None;
let mut disable_output_substitution = None;
let mut fee_contribution = None;
let mut min_fee_rate = None;
let mut input_type = None;
let mut sequence = None;
let mut payee = None;

while let Some(key) = map.next_key()? {
match key {
"original_psbt" => { original_psbt = Some(map.next_value()?); },
"disable_output_substitution" => { disable_output_substitution = Some(map.next_value()?); },
"fee_contribution" => {
let fee_contribution_option: Option<serde_json::Value> = map.next_value()?;
fee_contribution = match fee_contribution_option {
Some(value) => {
let amount = value.get("amount").and_then(|v| v.as_u64()).ok_or_else(|| de::Error::missing_field("amount"))?;
let index = value.get("index").and_then(|v| v.as_u64()).ok_or_else(|| de::Error::missing_field("index"))?;
Some((Amount::from_sat(amount as u64), index as usize))
},
None => None,
};
},
"min_fee_rate" => { min_fee_rate = Some(map.next_value()?); },
"input_type" => { input_type = Some(map.next_value()?); },
"sequence" => { sequence = Some(map.next_value()?); },
"payee" => { payee = Some(map.next_value()?); },
_ => return Err(de::Error::unknown_field(key, FIELDS)),
}
}

Ok(ContextV1 {
original_psbt: original_psbt.ok_or_else(|| de::Error::missing_field("original_psbt"))?,
disable_output_substitution: disable_output_substitution.ok_or_else(|| de::Error::missing_field("disable_output_substitution"))?,
fee_contribution,
min_fee_rate: min_fee_rate.ok_or_else(|| de::Error::missing_field("min_fee_rate"))?,
input_type: input_type.ok_or_else(|| de::Error::missing_field("input_type"))?,
sequence: sequence.ok_or_else(|| de::Error::missing_field("sequence"))?,
payee: payee.ok_or_else(|| de::Error::missing_field("payee"))?,
})
}
}

const FIELDS: &'static [&'static str] = &["original_psbt", "disable_output_substitution", "fee_contribution", /* ... other fields ... */];
deserializer.deserialize_struct("ContextV1", FIELDS, ContextV1Visitor)
}
}

#[cfg(feature = "v2")]
pub struct ContextV2 {
context_v1: ContextV1,
e: bitcoin::secp256k1::SecretKey,
}

#[cfg(feature = "v2")]
impl serde::Serialize for ContextV2 {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut state = serializer.serialize_struct("Enrolled", 4)?;
state.serialize_field("context_v1", &self.context_v1)?;
state.serialize_field("e", &self.e.secret_bytes())?;
state.end()
}
}


macro_rules! check_eq {
($proposed:expr, $original:expr, $error:ident) => {
match ($proposed, $original) {
Expand Down

0 comments on commit 7fab64f

Please sign in to comment.