Skip to content

Commit

Permalink
ensure not pay and refund can't happen together
Browse files Browse the repository at this point in the history
  • Loading branch information
JssDWt committed Dec 13, 2024
1 parent ad73975 commit 63a83f8
Show file tree
Hide file tree
Showing 10 changed files with 400 additions and 36 deletions.
14 changes: 11 additions & 3 deletions itest/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"setup_user_and_swapper",
"create_swap_no_invoice_extended",
"create_swap_no_invoice",
"create_swap_extended",
"create_swap",
"whatthefee",
"postgres_factory",
Expand Down Expand Up @@ -70,15 +71,22 @@ def create_swap_no_invoice_extended(user: ClnNode, swapper: SwapdServer):


def create_swap_no_invoice(user: ClnNode, swapper: SwapdServer):
address, preimage, h, _, _ = create_swap_no_invoice_extended(user, swapper)
address, preimage, h, _, _, _ = create_swap_no_invoice_extended(user, swapper)
return address, preimage, h


def create_swap(user: ClnNode, swapper: SwapdServer, amount=100_000_000):
address, preimage, h = create_swap_no_invoice(user, swapper)
def create_swap_extended(user: ClnNode, swapper: SwapdServer, amount=100_000_000):
address, preimage, h, refund_privkey, claim_pubkey, lock_height = (
create_swap_no_invoice_extended(user, swapper)
)
payment_request = user.create_invoice(
amount,
description="test",
preimage=preimage,
)
return address, payment_request, h, refund_privkey, claim_pubkey, lock_height


def create_swap(user: ClnNode, swapper: SwapdServer, amount=100_000_000):
address, payment_request, h, _, _, _ = create_swap_extended(user, swapper, amount)
return address, payment_request, h
8 changes: 4 additions & 4 deletions itest/tests/test_cltv_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ def test_below_cltv_limit(
node_factory, swapd_factory, lock_time, min_claim_blocks, min_viable_cltv
):
user, swapper = setup_user_and_swapper(node_factory, swapd_factory)
address, preimage, payment_hash, _ = create_swap_no_invoice(user, swapper)
address, preimage, payment_hash = create_swap_no_invoice(user, swapper)
user.bitcoin.rpc.sendtoaddress(address, 100_000 / 10**8)
user.bitcoin.generate_block(1)

Expand All @@ -30,7 +30,7 @@ def test_on_cltv_limit(
node_factory, swapd_factory, lock_time, min_claim_blocks, min_viable_cltv
):
user, swapper = setup_user_and_swapper(node_factory, swapd_factory)
address, preimage, payment_hash, _ = create_swap_no_invoice(user, swapper)
address, preimage, payment_hash = create_swap_no_invoice(user, swapper)
user.bitcoin.rpc.sendtoaddress(address, 100_000 / 10**8)
user.bitcoin.generate_block(1)

Expand All @@ -53,7 +53,7 @@ def test_below_cltv_limit_with_router(
user, router, swapper = setup_user_router_swapper(
node_factory, swapd_factory, swapd_opts={"min-viable-cltv": 0}
)
address, preimage, payment_hash, _ = create_swap_no_invoice(user, swapper)
address, preimage, payment_hash = create_swap_no_invoice(user, swapper)
user.bitcoin.rpc.sendtoaddress(address, 100_000 / 10**8)
user.bitcoin.generate_block(1)

Expand Down Expand Up @@ -81,7 +81,7 @@ def test_on_cltv_limit_with_router(
user, router, swapper = setup_user_router_swapper(
node_factory, swapd_factory, swapd_opts={"min-viable-cltv": 0}
)
address, preimage, payment_hash, _ = create_swap_no_invoice(user, swapper)
address, preimage, payment_hash = create_swap_no_invoice(user, swapper)
user.bitcoin.rpc.sendtoaddress(address, 100_000 / 10**8)
user.bitcoin.generate_block(1)

Expand Down
152 changes: 139 additions & 13 deletions itest/tests/test_coop_refund.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from helpers import *
import grpc
import musig2
import os
from bitcoinutils.transactions import Locktime
Expand All @@ -11,26 +12,27 @@
from decimal import Decimal


def test_cooperative_refund(node_factory, swapd_factory):
setup("regtest")
user, swapper = setup_user_and_swapper(node_factory, swapd_factory)
address, payment_request, h, refund_privkey, claim_pubkey, lock_height = (
create_swap_no_invoice_extended(user, swapper)
)
to_spend_txid = user.bitcoin.rpc.sendtoaddress(address, 100_000 / 10**8)
user.bitcoin.generate_block(1)
def coop_refund(
user,
swapper,
address,
h,
refund_privkey,
claim_pubkey,
lock_height,
to_spend_txid,
refund_amount,
):
to_spend_tx = user.bitcoin.rpc.getrawtransaction(to_spend_txid, True)
to_spend_output_index = 0
if to_spend_tx["vout"][1]["value"] == Decimal("0.00100000"):
to_spend_output_index = 1

wait_for(lambda: len(swapper.internal_rpc.get_swap(address).outputs) > 0)

refund_address = P2trAddress(user.new_address())
extra_in = os.urandom(32)

tx_in = TxInput(to_spend_txid, to_spend_output_index)
tx_out = TxOutput(99_000, refund_address.to_script_pub_key())
tx_in = TxInput(to_spend_txid, to_spend_output_index, sequence="00000001")
tx_out = TxOutput(refund_amount, refund_address.to_script_pub_key())
tx = Transaction([tx_in], [tx_out], has_segwit=True, witnesses=[TxWitnessInput([])])
tx_digest = tx.get_transaction_taproot_digest(
0, [P2trAddress(address).to_script_pub_key()], [100_000]
Expand Down Expand Up @@ -81,8 +83,132 @@ def test_cooperative_refund(node_factory, swapd_factory):

sig_agg = musig2.partial_sig_agg([their_partial_sig, our_partial_sig], session)
tx.witnesses[0].stack.append(sig_agg.hex())
return tx.to_hex()


def test_cooperative_refund_success(node_factory, swapd_factory):
setup("regtest")
user, swapper = setup_user_and_swapper(node_factory, swapd_factory)
address, _, h, refund_privkey, claim_pubkey, lock_height = (
create_swap_no_invoice_extended(user, swapper)
)
to_spend_txid = user.bitcoin.rpc.sendtoaddress(address, 100_000 / 10**8)
user.bitcoin.generate_block(1)

wait_for(lambda: len(swapper.internal_rpc.get_swap(address).outputs) > 0)

tx = coop_refund(
user,
swapper,
address,
h,
refund_privkey,
claim_pubkey,
lock_height,
to_spend_txid,
99_000,
)

expected_utxos = len(user.list_utxos()) + 1
user.bitcoin.rpc.sendrawtransaction(tx.to_hex())
user.bitcoin.rpc.sendrawtransaction(tx)
user.bitcoin.generate_block(1)
wait_for(lambda: len(user.list_utxos()) == expected_utxos)


def test_cooperative_refund_rbf_success(node_factory, swapd_factory):
setup("regtest")
user, swapper = setup_user_and_swapper(node_factory, swapd_factory)
address, _, h, refund_privkey, claim_pubkey, lock_height = (
create_swap_no_invoice_extended(user, swapper)
)
to_spend_txid = user.bitcoin.rpc.sendtoaddress(address, 100_000 / 10**8)
user.bitcoin.generate_block(1)

wait_for(lambda: len(swapper.internal_rpc.get_swap(address).outputs) > 0)

expected_utxos = len(user.list_utxos()) + 1
tx = coop_refund(
user,
swapper,
address,
h,
refund_privkey,
claim_pubkey,
lock_height,
to_spend_txid,
99_000,
)
user.bitcoin.rpc.sendrawtransaction(tx)
tx = coop_refund(
user,
swapper,
address,
h,
refund_privkey,
claim_pubkey,
lock_height,
to_spend_txid,
98_000,
)
user.bitcoin.rpc.sendrawtransaction(tx)
user.bitcoin.generate_block(1)
wait_for(lambda: len(user.list_utxos()) == expected_utxos)


def test_cooperative_refund_then_pay_failure(node_factory, swapd_factory):
setup("regtest")
user, swapper = setup_user_and_swapper(node_factory, swapd_factory)
address, payment_request, h, refund_privkey, claim_pubkey, lock_height = (
create_swap_extended(user, swapper)
)
to_spend_txid = user.bitcoin.rpc.sendtoaddress(address, 100_000 / 10**8)
user.bitcoin.generate_block(1)

wait_for(lambda: len(swapper.internal_rpc.get_swap(address).outputs) > 0)

coop_refund(
user,
swapper,
address,
h,
refund_privkey,
claim_pubkey,
lock_height,
to_spend_txid,
99_000,
)
try:
swapper.rpc.pay_swap(payment_request)
assert False
except grpc._channel._InactiveRpcError as e:
assert e.details() == "swap is locked"


def test_pay_then_cooperative_refund_failure(node_factory, swapd_factory):
setup("regtest")
user, swapper = setup_user_and_swapper(node_factory, swapd_factory)
address, payment_request, h, refund_privkey, claim_pubkey, lock_height = (
create_swap_extended(user, swapper)
)
to_spend_txid = user.bitcoin.rpc.sendtoaddress(address, 100_000 / 10**8)
user.bitcoin.generate_block(1)

wait_for(lambda: len(swapper.internal_rpc.get_swap(address).outputs) > 0)

swapper.rpc.pay_swap(payment_request)

try:
coop_refund(
user,
swapper,
address,
h,
refund_privkey,
claim_pubkey,
lock_height,
to_spend_txid,
99_000,
)
assert False
except grpc._channel._InactiveRpcError as e:
assert e.details() == "swap is locked"
54 changes: 42 additions & 12 deletions swapd/src/cln/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ use bitcoin::{
hashes::{sha256, Hash},
Network,
};
use futures::{stream::FuturesUnordered, StreamExt};
use futures::{future::join_all, stream::FuturesUnordered, StreamExt};
use regex::Regex;
use thiserror::Error;
use tokio::join;
use tonic::{
transport::{Certificate, Channel, ClientTlsConfig, Identity, Uri},
Request, Status,
Status,
};
use tracing::{debug, error, instrument, warn};

Expand Down Expand Up @@ -72,10 +72,10 @@ impl LightningClient for Client {
) -> Result<Option<PreimageResult>, LightningError> {
let mut client = self.get_client().await?;
let resp = client
.list_pays(Request::new(ListpaysRequest {
.list_pays(ListpaysRequest {
payment_hash: Some(hash.as_byte_array().to_vec()),
..Default::default()
}))
})
.await?;

let result = match resp
Expand All @@ -97,12 +97,42 @@ impl LightningClient for Client {
Ok(result)
}

#[instrument(level = "trace", skip(self))]
async fn has_pending_or_complete_payment(
&self,
hash: &sha256::Hash,
) -> Result<bool, LightningError> {
let mut client = self.get_client().await?;
let mut clone = client.clone();
let fut1 = clone.list_send_pays(ListsendpaysRequest {
payment_hash: Some(hash.as_byte_array().to_vec()),
status: Some(ListsendpaysStatus::Pending.into()),
..Default::default()
});

let fut2 = client.list_send_pays(ListsendpaysRequest {
payment_hash: Some(hash.as_byte_array().to_vec()),
status: Some(ListsendpaysStatus::Complete.into()),
..Default::default()
});

let resps = join_all([fut1, fut2]).await;
for resp in resps {
let resp = resp?;
if !resp.into_inner().payments.is_empty() {
return Ok(true);
}
}

Ok(false)
}

#[instrument(level = "trace", skip(self))]
async fn pay(&self, request: PaymentRequest) -> Result<PaymentResult, LightningError> {
let mut client = self.get_client().await?;

let pay_resp = match client
.pay(Request::new(PayRequest {
.pay(PayRequest {
label: Some(request.label),
bolt11: request.bolt11,
maxfee: Some(Amount {
Expand All @@ -111,7 +141,7 @@ impl LightningClient for Client {
retry_for: Some(request.timeout_seconds as u32),
maxdelay: Some(request.cltv_limit),
..Default::default()
}))
})
.await
{
Ok(resp) => resp,
Expand Down Expand Up @@ -167,22 +197,22 @@ async fn wait_payment<'a>(
payment_hash: sha256::Hash,
) -> Result<Option<PaymentResult>, LightningError> {
let mut client2 = client.clone();
let completed_payments_fut = client.list_send_pays(Request::new(ListsendpaysRequest {
let completed_payments_fut = client.list_send_pays(ListsendpaysRequest {
payment_hash: Some(payment_hash.as_byte_array().to_vec()),
bolt11: None,
index: None,
limit: None,
start: None,
status: Some(ListsendpaysStatus::Complete.into()),
}));
let pending_payments_fut = client2.list_send_pays(Request::new(ListsendpaysRequest {
});
let pending_payments_fut = client2.list_send_pays(ListsendpaysRequest {
payment_hash: Some(payment_hash.as_byte_array().to_vec()),
bolt11: None,
index: None,
limit: None,
start: None,
status: Some(ListsendpaysStatus::Pending.into()),
}));
});
let (completed_payments, pending_payments) =
join!(completed_payments_fut, pending_payments_fut);
let (completed_payments, pending_payments) = (completed_payments?, pending_payments?);
Expand All @@ -206,12 +236,12 @@ async fn wait_payment<'a>(
let mut client = client.clone();
tasks.push(async move {
client
.wait_send_pay(Request::new(WaitsendpayRequest {
.wait_send_pay(WaitsendpayRequest {
groupid: Some(payment.groupid),
partid: payment.partid,
payment_hash: payment_hash.as_byte_array().to_vec(),
timeout: None,
}))
})
.await
});
}
Expand Down
4 changes: 4 additions & 0 deletions swapd/src/lightning/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,9 @@ pub trait LightningClient {
&self,
hash: sha256::Hash,
) -> Result<Option<PreimageResult>, LightningError>;
async fn has_pending_or_complete_payment(
&self,
hash: &sha256::Hash,
) -> Result<bool, LightningError>;
async fn pay(&self, request: PaymentRequest) -> Result<PaymentResult, LightningError>;
}
Loading

0 comments on commit 63a83f8

Please sign in to comment.