From 63a83f811fe6dcd073ed1a75ae14f8e2b3d4ae54 Mon Sep 17 00:00:00 2001 From: Jesse de Wit Date: Fri, 13 Dec 2024 16:56:41 +0100 Subject: [PATCH] ensure not pay and refund can't happen together --- itest/tests/helpers.py | 14 +- itest/tests/test_cltv_limit.py | 8 +- itest/tests/test_coop_refund.py | 152 ++++++++++++++++-- swapd/src/cln/client.rs | 54 +++++-- swapd/src/lightning/client.rs | 4 + swapd/src/lnd/client.rs | 34 ++++ .../migrations/20241206_initial.up.sql | 1 + swapd/src/postgresql/swap_repository.rs | 66 +++++++- swapd/src/public_server.rs | 60 ++++++- swapd/src/swap/swap_repository.rs | 43 +++++ 10 files changed, 400 insertions(+), 36 deletions(-) diff --git a/itest/tests/helpers.py b/itest/tests/helpers.py index eb1cceb..5c9f5b3 100644 --- a/itest/tests/helpers.py +++ b/itest/tests/helpers.py @@ -26,6 +26,7 @@ "setup_user_and_swapper", "create_swap_no_invoice_extended", "create_swap_no_invoice", + "create_swap_extended", "create_swap", "whatthefee", "postgres_factory", @@ -70,15 +71,22 @@ def create_swap_no_invoice_extended(user: ClnNode, swapper: SwapdServer): def create_swap_no_invoice(user: ClnNode, swapper: SwapdServer): - address, preimage, h, _, _ = create_swap_no_invoice_extended(user, swapper) + address, preimage, h, _, _, _ = create_swap_no_invoice_extended(user, swapper) return address, preimage, h -def create_swap(user: ClnNode, swapper: SwapdServer, amount=100_000_000): - address, preimage, h = create_swap_no_invoice(user, swapper) +def create_swap_extended(user: ClnNode, swapper: SwapdServer, amount=100_000_000): + address, preimage, h, refund_privkey, claim_pubkey, lock_height = ( + create_swap_no_invoice_extended(user, swapper) + ) payment_request = user.create_invoice( amount, description="test", preimage=preimage, ) + return address, payment_request, h, refund_privkey, claim_pubkey, lock_height + + +def create_swap(user: ClnNode, swapper: SwapdServer, amount=100_000_000): + address, payment_request, h, _, _, _ = create_swap_extended(user, swapper, amount) return address, payment_request, h diff --git a/itest/tests/test_cltv_limit.py b/itest/tests/test_cltv_limit.py index 2d0daf4..f22a32b 100644 --- a/itest/tests/test_cltv_limit.py +++ b/itest/tests/test_cltv_limit.py @@ -6,7 +6,7 @@ def test_below_cltv_limit( node_factory, swapd_factory, lock_time, min_claim_blocks, min_viable_cltv ): user, swapper = setup_user_and_swapper(node_factory, swapd_factory) - address, preimage, payment_hash, _ = create_swap_no_invoice(user, swapper) + address, preimage, payment_hash = create_swap_no_invoice(user, swapper) user.bitcoin.rpc.sendtoaddress(address, 100_000 / 10**8) user.bitcoin.generate_block(1) @@ -30,7 +30,7 @@ def test_on_cltv_limit( node_factory, swapd_factory, lock_time, min_claim_blocks, min_viable_cltv ): user, swapper = setup_user_and_swapper(node_factory, swapd_factory) - address, preimage, payment_hash, _ = create_swap_no_invoice(user, swapper) + address, preimage, payment_hash = create_swap_no_invoice(user, swapper) user.bitcoin.rpc.sendtoaddress(address, 100_000 / 10**8) user.bitcoin.generate_block(1) @@ -53,7 +53,7 @@ def test_below_cltv_limit_with_router( user, router, swapper = setup_user_router_swapper( node_factory, swapd_factory, swapd_opts={"min-viable-cltv": 0} ) - address, preimage, payment_hash, _ = create_swap_no_invoice(user, swapper) + address, preimage, payment_hash = create_swap_no_invoice(user, swapper) user.bitcoin.rpc.sendtoaddress(address, 100_000 / 10**8) user.bitcoin.generate_block(1) @@ -81,7 +81,7 @@ def test_on_cltv_limit_with_router( user, router, swapper = setup_user_router_swapper( node_factory, swapd_factory, swapd_opts={"min-viable-cltv": 0} ) - address, preimage, payment_hash, _ = create_swap_no_invoice(user, swapper) + address, preimage, payment_hash = create_swap_no_invoice(user, swapper) user.bitcoin.rpc.sendtoaddress(address, 100_000 / 10**8) user.bitcoin.generate_block(1) diff --git a/itest/tests/test_coop_refund.py b/itest/tests/test_coop_refund.py index 6db4b6d..33355a0 100644 --- a/itest/tests/test_coop_refund.py +++ b/itest/tests/test_coop_refund.py @@ -1,4 +1,5 @@ from helpers import * +import grpc import musig2 import os from bitcoinutils.transactions import Locktime @@ -11,26 +12,27 @@ from decimal import Decimal -def test_cooperative_refund(node_factory, swapd_factory): - setup("regtest") - user, swapper = setup_user_and_swapper(node_factory, swapd_factory) - address, payment_request, h, refund_privkey, claim_pubkey, lock_height = ( - create_swap_no_invoice_extended(user, swapper) - ) - to_spend_txid = user.bitcoin.rpc.sendtoaddress(address, 100_000 / 10**8) - user.bitcoin.generate_block(1) +def coop_refund( + user, + swapper, + address, + h, + refund_privkey, + claim_pubkey, + lock_height, + to_spend_txid, + refund_amount, +): to_spend_tx = user.bitcoin.rpc.getrawtransaction(to_spend_txid, True) to_spend_output_index = 0 if to_spend_tx["vout"][1]["value"] == Decimal("0.00100000"): to_spend_output_index = 1 - wait_for(lambda: len(swapper.internal_rpc.get_swap(address).outputs) > 0) - refund_address = P2trAddress(user.new_address()) extra_in = os.urandom(32) - tx_in = TxInput(to_spend_txid, to_spend_output_index) - tx_out = TxOutput(99_000, refund_address.to_script_pub_key()) + tx_in = TxInput(to_spend_txid, to_spend_output_index, sequence="00000001") + tx_out = TxOutput(refund_amount, refund_address.to_script_pub_key()) tx = Transaction([tx_in], [tx_out], has_segwit=True, witnesses=[TxWitnessInput([])]) tx_digest = tx.get_transaction_taproot_digest( 0, [P2trAddress(address).to_script_pub_key()], [100_000] @@ -81,8 +83,132 @@ def test_cooperative_refund(node_factory, swapd_factory): sig_agg = musig2.partial_sig_agg([their_partial_sig, our_partial_sig], session) tx.witnesses[0].stack.append(sig_agg.hex()) + return tx.to_hex() + + +def test_cooperative_refund_success(node_factory, swapd_factory): + setup("regtest") + user, swapper = setup_user_and_swapper(node_factory, swapd_factory) + address, _, h, refund_privkey, claim_pubkey, lock_height = ( + create_swap_no_invoice_extended(user, swapper) + ) + to_spend_txid = user.bitcoin.rpc.sendtoaddress(address, 100_000 / 10**8) + user.bitcoin.generate_block(1) + + wait_for(lambda: len(swapper.internal_rpc.get_swap(address).outputs) > 0) + + tx = coop_refund( + user, + swapper, + address, + h, + refund_privkey, + claim_pubkey, + lock_height, + to_spend_txid, + 99_000, + ) expected_utxos = len(user.list_utxos()) + 1 - user.bitcoin.rpc.sendrawtransaction(tx.to_hex()) + user.bitcoin.rpc.sendrawtransaction(tx) user.bitcoin.generate_block(1) wait_for(lambda: len(user.list_utxos()) == expected_utxos) + + +def test_cooperative_refund_rbf_success(node_factory, swapd_factory): + setup("regtest") + user, swapper = setup_user_and_swapper(node_factory, swapd_factory) + address, _, h, refund_privkey, claim_pubkey, lock_height = ( + create_swap_no_invoice_extended(user, swapper) + ) + to_spend_txid = user.bitcoin.rpc.sendtoaddress(address, 100_000 / 10**8) + user.bitcoin.generate_block(1) + + wait_for(lambda: len(swapper.internal_rpc.get_swap(address).outputs) > 0) + + expected_utxos = len(user.list_utxos()) + 1 + tx = coop_refund( + user, + swapper, + address, + h, + refund_privkey, + claim_pubkey, + lock_height, + to_spend_txid, + 99_000, + ) + user.bitcoin.rpc.sendrawtransaction(tx) + tx = coop_refund( + user, + swapper, + address, + h, + refund_privkey, + claim_pubkey, + lock_height, + to_spend_txid, + 98_000, + ) + user.bitcoin.rpc.sendrawtransaction(tx) + user.bitcoin.generate_block(1) + wait_for(lambda: len(user.list_utxos()) == expected_utxos) + + +def test_cooperative_refund_then_pay_failure(node_factory, swapd_factory): + setup("regtest") + user, swapper = setup_user_and_swapper(node_factory, swapd_factory) + address, payment_request, h, refund_privkey, claim_pubkey, lock_height = ( + create_swap_extended(user, swapper) + ) + to_spend_txid = user.bitcoin.rpc.sendtoaddress(address, 100_000 / 10**8) + user.bitcoin.generate_block(1) + + wait_for(lambda: len(swapper.internal_rpc.get_swap(address).outputs) > 0) + + coop_refund( + user, + swapper, + address, + h, + refund_privkey, + claim_pubkey, + lock_height, + to_spend_txid, + 99_000, + ) + try: + swapper.rpc.pay_swap(payment_request) + assert False + except grpc._channel._InactiveRpcError as e: + assert e.details() == "swap is locked" + + +def test_pay_then_cooperative_refund_failure(node_factory, swapd_factory): + setup("regtest") + user, swapper = setup_user_and_swapper(node_factory, swapd_factory) + address, payment_request, h, refund_privkey, claim_pubkey, lock_height = ( + create_swap_extended(user, swapper) + ) + to_spend_txid = user.bitcoin.rpc.sendtoaddress(address, 100_000 / 10**8) + user.bitcoin.generate_block(1) + + wait_for(lambda: len(swapper.internal_rpc.get_swap(address).outputs) > 0) + + swapper.rpc.pay_swap(payment_request) + + try: + coop_refund( + user, + swapper, + address, + h, + refund_privkey, + claim_pubkey, + lock_height, + to_spend_txid, + 99_000, + ) + assert False + except grpc._channel._InactiveRpcError as e: + assert e.details() == "swap is locked" diff --git a/swapd/src/cln/client.rs b/swapd/src/cln/client.rs index 3bde9dd..b07f7e6 100644 --- a/swapd/src/cln/client.rs +++ b/swapd/src/cln/client.rs @@ -2,13 +2,13 @@ use bitcoin::{ hashes::{sha256, Hash}, Network, }; -use futures::{stream::FuturesUnordered, StreamExt}; +use futures::{future::join_all, stream::FuturesUnordered, StreamExt}; use regex::Regex; use thiserror::Error; use tokio::join; use tonic::{ transport::{Certificate, Channel, ClientTlsConfig, Identity, Uri}, - Request, Status, + Status, }; use tracing::{debug, error, instrument, warn}; @@ -72,10 +72,10 @@ impl LightningClient for Client { ) -> Result, LightningError> { let mut client = self.get_client().await?; let resp = client - .list_pays(Request::new(ListpaysRequest { + .list_pays(ListpaysRequest { payment_hash: Some(hash.as_byte_array().to_vec()), ..Default::default() - })) + }) .await?; let result = match resp @@ -97,12 +97,42 @@ impl LightningClient for Client { Ok(result) } + #[instrument(level = "trace", skip(self))] + async fn has_pending_or_complete_payment( + &self, + hash: &sha256::Hash, + ) -> Result { + let mut client = self.get_client().await?; + let mut clone = client.clone(); + let fut1 = clone.list_send_pays(ListsendpaysRequest { + payment_hash: Some(hash.as_byte_array().to_vec()), + status: Some(ListsendpaysStatus::Pending.into()), + ..Default::default() + }); + + let fut2 = client.list_send_pays(ListsendpaysRequest { + payment_hash: Some(hash.as_byte_array().to_vec()), + status: Some(ListsendpaysStatus::Complete.into()), + ..Default::default() + }); + + let resps = join_all([fut1, fut2]).await; + for resp in resps { + let resp = resp?; + if !resp.into_inner().payments.is_empty() { + return Ok(true); + } + } + + Ok(false) + } + #[instrument(level = "trace", skip(self))] async fn pay(&self, request: PaymentRequest) -> Result { let mut client = self.get_client().await?; let pay_resp = match client - .pay(Request::new(PayRequest { + .pay(PayRequest { label: Some(request.label), bolt11: request.bolt11, maxfee: Some(Amount { @@ -111,7 +141,7 @@ impl LightningClient for Client { retry_for: Some(request.timeout_seconds as u32), maxdelay: Some(request.cltv_limit), ..Default::default() - })) + }) .await { Ok(resp) => resp, @@ -167,22 +197,22 @@ async fn wait_payment<'a>( payment_hash: sha256::Hash, ) -> Result, LightningError> { let mut client2 = client.clone(); - let completed_payments_fut = client.list_send_pays(Request::new(ListsendpaysRequest { + let completed_payments_fut = client.list_send_pays(ListsendpaysRequest { payment_hash: Some(payment_hash.as_byte_array().to_vec()), bolt11: None, index: None, limit: None, start: None, status: Some(ListsendpaysStatus::Complete.into()), - })); - let pending_payments_fut = client2.list_send_pays(Request::new(ListsendpaysRequest { + }); + let pending_payments_fut = client2.list_send_pays(ListsendpaysRequest { payment_hash: Some(payment_hash.as_byte_array().to_vec()), bolt11: None, index: None, limit: None, start: None, status: Some(ListsendpaysStatus::Pending.into()), - })); + }); let (completed_payments, pending_payments) = join!(completed_payments_fut, pending_payments_fut); let (completed_payments, pending_payments) = (completed_payments?, pending_payments?); @@ -206,12 +236,12 @@ async fn wait_payment<'a>( let mut client = client.clone(); tasks.push(async move { client - .wait_send_pay(Request::new(WaitsendpayRequest { + .wait_send_pay(WaitsendpayRequest { groupid: Some(payment.groupid), partid: payment.partid, payment_hash: payment_hash.as_byte_array().to_vec(), timeout: None, - })) + }) .await }); } diff --git a/swapd/src/lightning/client.rs b/swapd/src/lightning/client.rs index 1125189..217cb4e 100644 --- a/swapd/src/lightning/client.rs +++ b/swapd/src/lightning/client.rs @@ -36,5 +36,9 @@ pub trait LightningClient { &self, hash: sha256::Hash, ) -> Result, LightningError>; + async fn has_pending_or_complete_payment( + &self, + hash: &sha256::Hash, + ) -> Result; async fn pay(&self, request: PaymentRequest) -> Result; } diff --git a/swapd/src/lnd/client.rs b/swapd/src/lnd/client.rs index 378d791..6e4647c 100644 --- a/swapd/src/lnd/client.rs +++ b/swapd/src/lnd/client.rs @@ -189,6 +189,40 @@ where Ok(Some(PreimageResult { preimage, label })) } + #[instrument(level = "trace", skip(self))] + async fn has_pending_or_complete_payment( + &self, + hash: &sha256::Hash, + ) -> Result { + let mut router_client = self.get_router_client().await?; + let res = router_client + .track_payment_v2(TrackPaymentRequest { + payment_hash: hash.as_byte_array().to_vec(), + no_inflight_updates: false, + }) + .await; + let mut stream = match res { + Ok(res) => res.into_inner(), + Err(e) => { + return match e.code() { + tonic::Code::NotFound => Ok(false), + _ => Err(LightningError::General(e)), + } + } + }; + let payment = match stream.message().await? { + Some(message) => message, + None => return Err(LightningError::ConnectionFailed), + }; + Ok(match payment.status() { + PaymentStatus::Unknown => true, + PaymentStatus::InFlight => true, + PaymentStatus::Succeeded => true, + PaymentStatus::Failed => false, + PaymentStatus::Initiated => true, + }) + } + #[instrument(level = "trace", skip(self))] async fn pay( &self, diff --git a/swapd/src/postgresql/migrations/20241206_initial.up.sql b/swapd/src/postgresql/migrations/20241206_initial.up.sql index 8346e5a..b458112 100644 --- a/swapd/src/postgresql/migrations/20241206_initial.up.sql +++ b/swapd/src/postgresql/migrations/20241206_initial.up.sql @@ -10,6 +10,7 @@ CREATE TABLE swaps ( claim_script BYTEA NOT NULL, creation_time BIGINT NOT NULL, lock_height BIGINT NOT NULL, + locked INTEGER NULL, payment_hash BYTEA NOT NULL PRIMARY KEY, preimage BYTEA NULL, refund_pubkey BYTEA NOT NULL, diff --git a/swapd/src/postgresql/swap_repository.rs b/swapd/src/postgresql/swap_repository.rs index 88435d4..1d563d2 100644 --- a/swapd/src/postgresql/swap_repository.rs +++ b/swapd/src/postgresql/swap_repository.rs @@ -17,9 +17,9 @@ use tracing::instrument; use crate::{ lightning::PaymentResult, swap::{ - AddPaymentResultError, GetPaidUtxosError, GetSwapsError, PaidOutpoint, PaymentAttempt, - Swap, SwapPersistenceError, SwapPrivateData, SwapPublicData, SwapState, - SwapStatePaidOutpoints, + AddPaymentResultError, GetPaidUtxosError, GetSwapsError, LockSwapError, LockType, + PaidOutpoint, PaymentAttempt, Swap, SwapPersistenceError, SwapPrivateData, SwapPublicData, + SwapState, SwapStatePaidOutpoints, }, }; @@ -356,6 +356,60 @@ impl crate::swap::SwapRepository for SwapRepository { Ok(result) } + + async fn lock_swap( + &self, + swap: &Swap, + lock_type: LockType, + ) -> Result, LockSwapError> { + let mut tx = self.pool.begin().await?; + let extra_where = match lock_type { + LockType::Pay => "", + LockType::Refund => " OR locked = $2", + }; + let query = format!(" + UPDATE swaps x + SET locked = $2 + FROM (SELECT payment_hash, locked FROM swaps WHERE payment_hash = $1 AND locked IS NULL{} FOR UPDATE) y + WHERE x.payment_hash = y.payment_hash + RETURNING y.locked AS old_locked", extra_where); + + let lock_type: i32 = lock_type.into(); + let maybe_row = sqlx::query(&query) + .bind(swap.public.hash.as_byte_array().to_vec()) + .bind(lock_type) + .fetch_optional(&mut *tx) + .await?; + + let row = match maybe_row { + Some(row) => row, + None => return Err(LockSwapError::AlreadyLocked), + }; + + let old_lock_type: Option = row.try_get(0)?; + let old_lock_type: Option = match old_lock_type { + Some(old_lock_type) => Some(old_lock_type.try_into()?), + None => None, + }; + tx.commit().await?; + Ok(old_lock_type) + } + + async fn unlock_swap(&self, swap: &Swap) -> Result<(), LockSwapError> { + let query = String::from( + " + UPDATE swaps + SET locked = NULL + WHERE payment_hash = $1", + ); + + sqlx::query(&query) + .bind(swap.public.hash.as_byte_array().to_vec()) + .execute(&*self.pool) + .await?; + + Ok(()) + } } fn swap_state_fields(prefix: &str) -> String { @@ -441,6 +495,12 @@ impl From for GetPaidUtxosError { } } +impl From for LockSwapError { + fn from(value: sqlx::Error) -> Self { + LockSwapError::General(Box::new(value)) + } +} + impl From for GetPaidUtxosError { fn from(value: bitcoin::hashes::hex::HexToArrayError) -> Self { GetPaidUtxosError::General(Box::new(value)) diff --git a/swapd/src/public_server.rs b/swapd/src/public_server.rs index 67ff3d0..8f64167 100644 --- a/swapd/src/public_server.rs +++ b/swapd/src/public_server.rs @@ -23,7 +23,7 @@ use crate::{ }, chain_filter::ChainFilterService, lightning::{LightningClient, LightningError, PaymentRequest, PaymentResult}, - swap::{ClaimableUtxo, PaymentAttempt}, + swap::{ClaimableUtxo, LockSwapError, LockType, PaymentAttempt}, }; use crate::swap::{ @@ -358,6 +358,20 @@ where })? .as_nanos(); let label = format!("{}-{}", hash, unix_ns_now); + match self + .swap_repository + .lock_swap(&swap_state.swap, LockType::Pay) + .await + { + Ok(_) => {} + Err(LockSwapError::AlreadyLocked) => { + return Err(Status::failed_precondition("swap is locked")) + } + Err(e) => { + error!("failed to lock swap for payment: {:?}", e); + return Err(Status::internal("internal error")); + } + }; self.swap_repository .add_payment_attempt(&PaymentAttempt { amount_msat, @@ -410,6 +424,7 @@ where } }; + let _ = self.swap_repository.unlock_swap(&swap_state.swap).await; let response = match pay_result { PaymentResult::Success { preimage: _ } => { info!( @@ -508,6 +523,49 @@ where error!("failed to sign refund transaction: {:?}", e); Status::internal("internal error") })?; + + // Ensure this swap is not used for paying out at the moment, also + // prevent payouts from happening in the future. Note this will never be + // unlocked, because the user may steal funds if it's ever paid out. The + // user _can_ create a new refund later, however. + let was_locked = match self + .swap_repository + .lock_swap(&swap.swap, LockType::Refund) + .await + { + Ok(old_lock_type) => old_lock_type.is_some(), + Err(LockSwapError::AlreadyLocked) => { + return Err(Status::failed_precondition("swap is locked")) + } + Err(e) => { + error!("failed to lock swap for refund: {:?}", e); + return Err(Status::internal("internal error")); + } + }; + + match self + .lightning_client + .has_pending_or_complete_payment(&swap.swap.public.hash) + .await + { + Ok(false) => {} + Ok(true) => { + if !was_locked { + let _ = self.swap_repository.unlock_swap(&swap.swap).await; + } + + return Err(Status::failed_precondition("swap is locked")); + } + Err(e) => { + error!("failed to check for pending or complete payment: {:?}", e); + if !was_locked { + let _ = self.swap_repository.unlock_swap(&swap.swap).await; + } + + return Err(Status::internal("internal error")); + } + } + Ok(Response::new(RefundSwapResponse { partial_signature: partial_signature.serialize().to_vec(), pub_nonce: our_pub_nonce.serialize().to_vec(), diff --git a/swapd/src/swap/swap_repository.rs b/swapd/src/swap/swap_repository.rs index 3fc5974..377c3d3 100644 --- a/swapd/src/swap/swap_repository.rs +++ b/swapd/src/swap/swap_repository.rs @@ -20,6 +20,43 @@ pub enum AddPaymentResultError { General(Box), } +#[derive(Debug, PartialEq, Eq)] +pub enum LockType { + Pay, + Refund, +} + +impl TryFrom for LockType { + type Error = LockSwapError; + + fn try_from(value: i32) -> Result { + Ok(match value { + 0 => LockType::Pay, + 1 => LockType::Refund, + _ => return Err(LockSwapError::InvalidLockType), + }) + } +} + +impl From for i32 { + fn from(value: LockType) -> i32 { + match value { + LockType::Pay => 0, + LockType::Refund => 1, + } + } +} + +#[derive(Debug, Error)] +pub enum LockSwapError { + #[error("already locked")] + AlreadyLocked, + #[error("invalid lock type")] + InvalidLockType, + #[error("{0}")] + General(Box), +} + #[derive(Debug, Error)] pub enum GetPaidUtxosError { #[error("{0}")] @@ -86,4 +123,10 @@ pub trait SwapRepository { &self, addresses: &[Address], ) -> Result, GetSwapsError>; + async fn lock_swap( + &self, + swap: &Swap, + lock_type: LockType, + ) -> Result, LockSwapError>; + async fn unlock_swap(&self, swap: &Swap) -> Result<(), LockSwapError>; }