diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala b/eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala index 6f5f22323f..e8c177ff34 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala @@ -44,7 +44,14 @@ object OnionMessages { timeout: FiniteDuration, maxAttempts: Int) - case class IntermediateNode(nodeId: PublicKey, padding: Option[ByteVector] = None, customTlvs: Set[GenericTlv] = Set.empty) + case class IntermediateNode(nodeId: PublicKey, outgoingChannel_opt: Option[ShortChannelId] = None, padding: Option[ByteVector] = None, customTlvs: Set[GenericTlv] = Set.empty) { + def toTlvStream(nextNodeId: PublicKey, nextBlinding_opt: Option[PublicKey] = None): TlvStream[RouteBlindingEncryptedDataTlv] = + TlvStream(Set[Option[RouteBlindingEncryptedDataTlv]]( + padding.map(Padding), + outgoingChannel_opt.map(OutgoingChannelId).orElse(Some(OutgoingNodeId(nextNodeId))), + nextBlinding_opt.map(NextBlinding) + ).flatten, customTlvs) + } // @formatter:off sealed trait Destination @@ -64,12 +71,12 @@ object OnionMessages { } // @formatter:on - private def buildIntermediatePayloads(intermediateNodes: Seq[IntermediateNode], nextTlvs: Set[RouteBlindingEncryptedDataTlv]): Seq[ByteVector] = { + private def buildIntermediatePayloads(intermediateNodes: Seq[IntermediateNode], lastNodeId: PublicKey, lastBlinding_opt: Option[PublicKey] = None): Seq[ByteVector] = { if (intermediateNodes.isEmpty) { Nil } else { - (intermediateNodes.tail.map(node => Set(OutgoingNodeId(node.nodeId))) :+ nextTlvs) - .zip(intermediateNodes).map { case (tlvs, hop) => TlvStream(hop.padding.map(Padding).toSet[RouteBlindingEncryptedDataTlv] ++ tlvs, hop.customTlvs) } + (intermediateNodes.tail.zip(intermediateNodes.dropRight(1)).map { case (nextNode, hop) => hop.toTlvStream(nextNode.nodeId) } :+ + intermediateNodes.last.toTlvStream(lastNodeId, lastBlinding_opt)) .map(tlvs => RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(tlvs).require.bytes) } } @@ -77,7 +84,7 @@ object OnionMessages { def buildRoute(blindingSecret: PrivateKey, intermediateNodes: Seq[IntermediateNode], recipient: Recipient): Sphinx.RouteBlinding.BlindedRoute = { - val intermediatePayloads = buildIntermediatePayloads(intermediateNodes, Set(OutgoingNodeId(recipient.nodeId))) + val intermediatePayloads = buildIntermediatePayloads(intermediateNodes, recipient.nodeId) val tlvs: Set[RouteBlindingEncryptedDataTlv] = Set(recipient.padding.map(Padding), recipient.pathId.map(PathId)).flatten val lastPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream(tlvs, recipient.customTlvs)).require.bytes Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.nodeId) :+ recipient.nodeId, intermediatePayloads :+ lastPayload).route @@ -101,7 +108,7 @@ object OnionMessages { } case BlindedPath(route) if intermediateNodes.isEmpty => Some(route) case BlindedPath(route) => - val intermediatePayloads = buildIntermediatePayloads(intermediateNodes, Set(OutgoingNodeId(route.introductionNodeId), NextBlinding(route.blindingKey))) + val intermediatePayloads = buildIntermediatePayloads(intermediateNodes, route.introductionNodeId, Some(route.blindingKey)) val routePrefix = Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.nodeId), intermediatePayloads).route Some(Sphinx.RouteBlinding.BlindedRoute(routePrefix.introductionNodeId, routePrefix.blindingKey, routePrefix.blindedNodes ++ route.blindedNodes)) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala index a11c907caa..cc4d2bec28 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala @@ -34,11 +34,12 @@ import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.offer.OfferManager import fr.acinq.eclair.payment.receive.MultiPartHandler.{DummyBlindedHop, ReceivingRoute} import fr.acinq.eclair.payment.send.PaymentInitiator.{SendPaymentToNode, SendSpontaneousPayment} -import fr.acinq.eclair.payment.send.{OfferPayment, PaymentLifecycle} +import fr.acinq.eclair.payment.send.{ClearRecipient, OfferPayment, PaymentLifecycle} +import fr.acinq.eclair.router.Router import fr.acinq.eclair.testutils.FixtureSpec import fr.acinq.eclair.wire.protocol.OfferTypes.{Offer, OfferPaths} import fr.acinq.eclair.wire.protocol.{IncorrectOrUnknownPaymentDetails, InvalidOnionBlinding} -import fr.acinq.eclair.{CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, randomBytes32, randomKey} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, ShortChannelId, randomBytes32, randomKey} import org.scalatest.concurrent.IntegrationPatience import org.scalatest.{Tag, TestData} import scodec.bits.HexStringSyntax @@ -138,8 +139,9 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val recipientKey = randomKey() val pathId = randomBytes32() val offerPaths = routes.map(route => { - route.nodes.dropRight(1).map(IntermediateNode(_)) - buildRoute(randomKey(), route.nodes.dropRight(1).map(IntermediateNode(_)), Recipient(route.nodes.last, Some(pathId))) + val ourNode = route.nodes.last + val intermediateNode = route.nodes.dropRight(1).map(IntermediateNode(_)) ++ route.dummyHops.map(_ => IntermediateNode(ourNode)) + buildRoute(randomKey(), intermediateNode, Recipient(route.nodes.last, Some(pathId))) }) val offer = Offer(None, "test", recipientKey.publicKey, Features.empty, recipient.nodeParams.chainHash, additionalTlvs = Set(OfferPaths(offerPaths))) val handler = recipient.system.spawnAnonymous(offerHandler(amount, routes)) @@ -350,4 +352,38 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { assert(failure.t == PaymentLifecycle.UpdateMalformedException) } + test("send payment a->b->c compact offer") { f => + import f._ + + val amount = 25_000_000 msat + + val probe = TestProbe() + val routeParams = carol.nodeParams.routerConf.pathFindingExperimentConf.getRandomConf().getDefaultRouteParams + probe.send(carol.router, Router.RouteRequest(bob.nodeId, ClearRecipient(carol.nodeId, Features.empty, amount, CltvExpiry(1000000000), randomBytes32()), routeParams)) + val route = probe.expectMsgType[Router.RouteResponse].routes.head + + val recipientKey = randomKey() + val pathId = randomBytes32() + + val intermediateNodeIds = route.hops.map(hop => IntermediateNode(hop.nodeId)) :+ IntermediateNode(carol.nodeId) + val blindedRoute = buildRoute(randomKey(), intermediateNodeIds, Recipient(carol.nodeId, Some(pathId))) + val offer = Offer(None, "test", recipientKey.publicKey, Features.empty, carol.nodeParams.chainHash, additionalTlvs = Set(OfferPaths(Seq(blindedRoute)))) + + val intermediateChannelIds = route.hops.map(hop => IntermediateNode(hop.nodeId, Some(hop.shortChannelId))) :+ IntermediateNode(carol.nodeId, Some(ShortChannelId.toSelf)) + val compactBlindedRoute = buildRoute(randomKey(), intermediateChannelIds, Recipient(carol.nodeId, Some(pathId))) + val compactOffer = Offer(None, "test", recipientKey.publicKey, Features.empty, carol.nodeParams.chainHash, additionalTlvs = Set(OfferPaths(Seq(compactBlindedRoute)))) + + assert(compactOffer.toString.length < offer.toString.length) + + val receivingRoute = ReceivingRoute(Seq(bob.nodeId, carol.nodeId), maxFinalExpiryDelta) + + val handler = carol.system.spawnAnonymous(offerHandler(amount, Seq(receivingRoute))) + carol.offerManager ! OfferManager.RegisterOffer(compactOffer, recipientKey, Some(pathId), handler) + val offerPayment = alice.system.spawnAnonymous(OfferPayment(alice.nodeParams, alice.postman, alice.paymentInitiator)) + val sendPaymentConfig = OfferPayment.SendPaymentConfig(None, connectDirectly = false, maxAttempts = 1, alice.routeParams, blocking = true) + offerPayment ! OfferPayment.PayOffer(probe.ref, compactOffer, amount, 1, sendPaymentConfig) + + val payment = verifyPaymentSuccess(compactOffer, amount, probe.expectMsgType[PaymentEvent]) + assert(payment.parts.length == 1) + } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala index 1205472bea..7f2101ab31 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala @@ -27,7 +27,7 @@ import fr.acinq.eclair.wire.protocol.OnionMessagePayloadTlv.EncryptedData import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv._ import fr.acinq.eclair.wire.protocol.{GenericTlv, OnionMessage, OnionMessagePayloadTlv, OnionRoutingCodecs, RouteBlindingEncryptedDataCodecs, RouteBlindingEncryptedDataTlv, TlvStream} -import fr.acinq.eclair.{UInt64, randomBytes, randomKey} +import fr.acinq.eclair.{ShortChannelId, UInt64, randomBytes, randomKey} import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.scalatest.funsuite.AnyFunSuite @@ -286,7 +286,7 @@ class OnionMessagesSpec extends AnyFunSuite { Recipient(nodeKey.publicKey, Some(ByteVector.fromValidHex((json \ "path_id").extract[String])), (json \ "padding").extract[Option[String]].map(ByteVector.fromValidHex(_)), getCustomTlvs(json)) def makeIntermediateNode(nodeKey: PrivateKey, json: JValue): IntermediateNode = - IntermediateNode(nodeKey.publicKey, (json \ "padding").extract[Option[String]].map(ByteVector.fromValidHex(_)), getCustomTlvs(json)) + IntermediateNode(nodeKey.publicKey, None, (json \ "padding").extract[Option[String]].map(ByteVector.fromValidHex(_)), getCustomTlvs(json)) val blindingSecretBob = PrivateKey(ByteVector32.fromValidHex(((testVector \ "generate" \ "hops")(1) \ "blinding_secret").extract[String])) val pathId = ByteVector.fromValidHex(((testVector \ "generate" \ "hops")(3) \ "tlvs" \ "path_id").extract[String]) @@ -347,4 +347,33 @@ class OnionMessagesSpec extends AnyFunSuite { case x => fail(x.toString) } } + + test("route with channel ids") { + val nodeKey = randomKey() + val alice = randomKey() + val alice2bob = ShortChannelId(1) + val bob = randomKey() + val bob2carol = ShortChannelId(2) + val carol = randomKey() + val sessionKey = randomKey() + val blindingSecret = randomKey() + val pathId = randomBytes(64) + val Right((_, messageForAlice)) = buildMessage(nodeKey, sessionKey, blindingSecret, IntermediateNode(alice.publicKey, outgoingChannel_opt = Some(alice2bob)) :: IntermediateNode(bob.publicKey, outgoingChannel_opt = Some(bob2carol)) :: Nil, Recipient(carol.publicKey, Some(pathId)), TlvStream.empty) + + // Checking that the onion is relayed properly + process(alice, messageForAlice) match { + case SendMessage(Left(outgoingChannelId), onionForBob) => + assert(outgoingChannelId == alice2bob) + process(bob, onionForBob) match { + case SendMessage(Left(outgoingChannelId), onionForCarol) => + assert(outgoingChannelId == bob2carol) + process(carol, onionForCarol) match { + case ReceiveMessage(finalPayload) => assert(finalPayload.pathId_opt.contains(pathId)) + case x => fail(x.toString) + } + case x => fail(x.toString) + } + case x => fail(x.toString) + } + } }