Skip to content

Commit

Permalink
Correctly fail blinded payments after restart (#2704)
Browse files Browse the repository at this point in the history
When restarting, we weren't checking whether it was using blinded paths.
If we were an intermediate node in the blinded path, we were incorrectly
returning a normal failure: it should be ok, since the introduction node
is supposed to translate those failures, but it's safer to assume that
they don't.
  • Loading branch information
t-bast authored Jul 3, 2023
1 parent 9db0063 commit 4c98e1c
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import fr.acinq.eclair.db._
import fr.acinq.eclair.payment.Monitoring.Tags
import fr.acinq.eclair.payment.{ChannelPaymentRelayed, IncomingPaymentPacket, PaymentFailed, PaymentSent}
import fr.acinq.eclair.transactions.DirectedHtlc.outgoing
import fr.acinq.eclair.wire.protocol.{FailureMessage, TemporaryNodeFailure, UpdateAddHtlc}
import fr.acinq.eclair.wire.protocol.{FailureMessage, InvalidOnionBlinding, TemporaryNodeFailure, UpdateAddHtlc}
import fr.acinq.eclair.{CustomCommitmentsPlugin, Feature, Features, Logs, MilliSatoshiLong, NodeParams, TimestampMilli}

import scala.concurrent.Promise
Expand Down Expand Up @@ -124,7 +124,16 @@ class PostRestartHtlcCleaner(nodeParams: NodeParams, register: ActorRef, initial
Metrics.Resolved.withTag(Tags.Success, value = false).withTag(Metrics.Relayed, value = false).increment()
if (e.currentState != CLOSING && e.currentState != CLOSED) {
log.info(s"failing not relayed htlc=$htlc")
channel ! CMD_FAIL_HTLC(htlc.id, Right(TemporaryNodeFailure()), commit = true)
val cmd = htlc.blinding_opt match {
case Some(_) =>
// The incoming HTLC contains a blinding point: we must be an intermediate node in a blinded path,
// and we thus need to return an update_fail_malformed_htlc.
val failure = InvalidOnionBlinding(ByteVector32.Zeroes)
CMD_FAIL_MALFORMED_HTLC(htlc.id, failure.onionHash, failure.code, commit = true)
case None =>
CMD_FAIL_HTLC(htlc.id, Right(TemporaryNodeFailure()), commit = true)
}
channel ! cmd
} else {
log.info(s"would fail but upstream channel is closed for htlc=$htlc")
}
Expand Down Expand Up @@ -244,7 +253,16 @@ class PostRestartHtlcCleaner(nodeParams: NodeParams, register: ActorRef, initial
case Origin.ChannelRelayedCold(originChannelId, originHtlcId, _, _) =>
log.warning(s"payment failed for paymentHash=${failedHtlc.paymentHash}: failing 1 HTLC upstream")
Metrics.Resolved.withTag(Tags.Success, value = false).withTag(Metrics.Relayed, value = true).increment()
val cmd = ChannelRelay.translateRelayFailure(originHtlcId, fail)
val cmd = failedHtlc.blinding_opt match {
case Some(_) =>
// If we are inside a blinded path, we cannot know whether we're the introduction node or not since
// we don't have access to the incoming onion: to avoid leaking information, we act as if we were an
// intermediate node and send invalid_onion_blinding in an update_fail_malformed_htlc message.
val failure = InvalidOnionBlinding(ByteVector32.Zeroes)
CMD_FAIL_MALFORMED_HTLC(originHtlcId, failure.onionHash, failure.code, commit = true)
case None =>
ChannelRelay.translateRelayFailure(originHtlcId, fail)
}
PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, originChannelId, cmd)
case Origin.TrampolineRelayedCold(origins) =>
log.warning(s"payment failed for paymentHash=${failedHtlc.paymentHash}: failing ${origins.length} HTLCs upstream")
Expand Down Expand Up @@ -336,7 +354,7 @@ object PostRestartHtlcCleaner {
case _ => None
})

def decryptedIncomingHtlcs(paymentsDb: IncomingPaymentsDb): PartialFunction[Either[FailureMessage, IncomingPaymentPacket], IncomingHtlc] = {
private def decryptedIncomingHtlcs(paymentsDb: IncomingPaymentsDb): PartialFunction[Either[FailureMessage, IncomingPaymentPacket], IncomingHtlc] = {
// When we're not the final recipient, we'll only consider HTLCs that aren't relayed downstream, so no need to look for a preimage.
case Right(p: IncomingPaymentPacket.ChannelRelayPacket) => IncomingHtlc(p.add, None)
case Right(p: IncomingPaymentPacket.NodeRelayPacket) => IncomingHtlc(p.add, None)
Expand All @@ -361,7 +379,7 @@ object PostRestartHtlcCleaner {
private def isPendingUpstream(channelId: ByteVector32, htlcId: Long, htlcsIn: Seq[IncomingHtlc]): Boolean =
htlcsIn.exists(htlc => htlc.add.channelId == channelId && htlc.add.id == htlcId)

def groupByOrigin(htlcsOut: Seq[(Origin, ByteVector32, Long)], htlcsIn: Seq[IncomingHtlc]): Map[Origin, Set[(ByteVector32, Long)]] =
private def groupByOrigin(htlcsOut: Seq[(Origin, ByteVector32, Long)], htlcsIn: Seq[IncomingHtlc]): Map[Origin, Set[(ByteVector32, Long)]] =
htlcsOut
.groupBy { case (origin, _, _) => origin }
.view
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit
buildHtlcIn(1, channelId_ab_1, randomBytes32()), // not relayed
buildHtlcOut(2, channelId_ab_1, randomBytes32()),
buildHtlcOut(3, channelId_ab_1, randomBytes32()),
buildHtlcIn(4, channelId_ab_1, randomBytes32()), // not relayed
buildHtlcIn(4, channelId_ab_1, randomBytes32(), blinded = true), // not relayed
buildHtlcIn(5, channelId_ab_1, relayedPaymentHash)
)
val htlc_ab_2 = Seq(
Expand All @@ -121,20 +121,26 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit

// channel 1 goes to NORMAL state:
system.eventStream.publish(ChannelStateChanged(channel.ref, channels.head.commitments.channelId, system.deadLetters, a, OFFLINE, NORMAL, Some(channels.head.commitments)))
val fails_ab_1 = channel.expectMsgType[CMD_FAIL_HTLC] :: channel.expectMsgType[CMD_FAIL_HTLC] :: Nil
assert(fails_ab_1.toSet == Set(CMD_FAIL_HTLC(1, Right(TemporaryNodeFailure()), commit = true), CMD_FAIL_HTLC(4, Right(TemporaryNodeFailure()), commit = true)))
channel.expectMsgAllOf(
CMD_FAIL_HTLC(1, Right(TemporaryNodeFailure()), commit = true),
CMD_FAIL_MALFORMED_HTLC(4, ByteVector32.Zeroes, FailureMessageCodecs.BADONION | FailureMessageCodecs.PERM | 24, commit = true)
)
channel.expectNoMessage(100 millis)

// channel 2 goes to NORMAL state:
system.eventStream.publish(ChannelStateChanged(channel.ref, channels(1).commitments.channelId, system.deadLetters, a, OFFLINE, NORMAL, Some(channels(1).commitments)))
val fails_ab_2 = channel.expectMsgType[CMD_FAIL_HTLC] :: channel.expectMsgType[CMD_FAIL_HTLC] :: Nil
assert(fails_ab_2.toSet == Set(CMD_FAIL_HTLC(0, Right(TemporaryNodeFailure()), commit = true), CMD_FAIL_HTLC(4, Right(TemporaryNodeFailure()), commit = true)))
channel.expectMsgAllOf(
CMD_FAIL_HTLC(0, Right(TemporaryNodeFailure()), commit = true),
CMD_FAIL_HTLC(4, Right(TemporaryNodeFailure()), commit = true)
)
channel.expectNoMessage(100 millis)

// let's assume that channel 1 was disconnected before having signed the fails, and gets connected again:
system.eventStream.publish(ChannelStateChanged(channel.ref, channels.head.channelId, system.deadLetters, a, OFFLINE, NORMAL, Some(channels.head.commitments)))
val fails_ab_1_bis = channel.expectMsgType[CMD_FAIL_HTLC] :: channel.expectMsgType[CMD_FAIL_HTLC] :: Nil
assert(fails_ab_1_bis.toSet == Set(CMD_FAIL_HTLC(1, Right(TemporaryNodeFailure()), commit = true), CMD_FAIL_HTLC(4, Right(TemporaryNodeFailure()), commit = true)))
channel.expectMsgAllOf(
CMD_FAIL_HTLC(1, Right(TemporaryNodeFailure()), commit = true),
CMD_FAIL_MALFORMED_HTLC(4, ByteVector32.Zeroes, FailureMessageCodecs.BADONION | FailureMessageCodecs.PERM | 24, commit = true)
)
channel.expectNoMessage(100 millis)

// let's now assume that channel 1 gets reconnected, and it had the time to fail the htlcs:
Expand Down Expand Up @@ -502,6 +508,25 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit
eventListener.expectNoMessage(100 millis)
}

test("handle a blinded channel relay htlc-fail") { f =>
import f._

val htlc_ab = buildHtlcIn(0, channelId_ab_1, paymentHash1, blinded = true)
val origin = Origin.ChannelRelayedCold(htlc_ab.add.channelId, htlc_ab.add.id, htlc_ab.add.amountMsat, htlc_ab.add.amountMsat - 100.msat)
val htlc_bc = buildHtlcOut(6, channelId_bc_1, paymentHash1, blinded = true)
val data_ab = ChannelCodecsSpec.makeChannelDataNormal(Seq(htlc_ab), Map.empty)
val data_bc = ChannelCodecsSpec.makeChannelDataNormal(Seq(htlc_bc), Map(6L -> origin))
val channels = List(data_ab, data_bc)

val (relayer, _) = f.createRelayer(nodeParams)
relayer ! PostRestartHtlcCleaner.Init(channels)
register.expectNoMessage(100 millis)

sender.send(relayer, buildForwardFail(htlc_bc.add, origin))
val cmd = register.expectMsgType[Register.Forward[CMD_FAIL_MALFORMED_HTLC]]
assert(cmd.message == CMD_FAIL_MALFORMED_HTLC(htlc_ab.add.id, ByteVector32.Zeroes, FailureMessageCodecs.BADONION | FailureMessageCodecs.PERM | 24, commit = true))
}

test("handle a channel relay htlc-fulfill") { f =>
import f._

Expand Down Expand Up @@ -670,14 +695,19 @@ object PostRestartHtlcCleanerSpec {
val (preimage1, preimage2, preimage3) = (randomBytes32(), randomBytes32(), randomBytes32())
val (paymentHash1, paymentHash2, paymentHash3) = (Crypto.sha256(preimage1), Crypto.sha256(preimage2), Crypto.sha256(preimage3))

def buildHtlc(htlcId: Long, channelId: ByteVector32, paymentHash: ByteVector32): UpdateAddHtlc = {
val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, hops, None), SpontaneousRecipient(e, finalAmount, finalExpiry, randomBytes32()))
UpdateAddHtlc(channelId, htlcId, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, None)
def buildHtlc(htlcId: Long, channelId: ByteVector32, paymentHash: ByteVector32, blinded: Boolean = false): UpdateAddHtlc = {
val (route, recipient) = if (blinded) {
singleBlindedHop()
} else {
(Route(finalAmount, hops, None), SpontaneousRecipient(e, finalAmount, finalExpiry, randomBytes32()))
}
val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, route, recipient)
UpdateAddHtlc(channelId, htlcId, payment.cmd.amount, paymentHash, payment.cmd.cltvExpiry, payment.cmd.onion, payment.cmd.nextBlindingKey_opt)
}

def buildHtlcIn(htlcId: Long, channelId: ByteVector32, paymentHash: ByteVector32): DirectedHtlc = IncomingHtlc(buildHtlc(htlcId, channelId, paymentHash))
def buildHtlcIn(htlcId: Long, channelId: ByteVector32, paymentHash: ByteVector32, blinded: Boolean = false): DirectedHtlc = IncomingHtlc(buildHtlc(htlcId, channelId, paymentHash, blinded))

def buildHtlcOut(htlcId: Long, channelId: ByteVector32, paymentHash: ByteVector32): DirectedHtlc = OutgoingHtlc(buildHtlc(htlcId, channelId, paymentHash))
def buildHtlcOut(htlcId: Long, channelId: ByteVector32, paymentHash: ByteVector32, blinded: Boolean = false): DirectedHtlc = OutgoingHtlc(buildHtlc(htlcId, channelId, paymentHash, blinded))

def buildFinalHtlc(htlcId: Long, channelId: ByteVector32, paymentHash: ByteVector32): DirectedHtlc = {
val Right(payment) = buildOutgoingPayment(ActorRef.noSender, priv_a.privateKey, Upstream.Local(UUID.randomUUID()), paymentHash, Route(finalAmount, Seq(channelHopFromUpdate(a, b, channelUpdate_ab)), None), SpontaneousRecipient(b, finalAmount, finalExpiry, randomBytes32()))
Expand Down

0 comments on commit 4c98e1c

Please sign in to comment.