diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala index 7974a86dc3..d4fbd7f1d1 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala @@ -92,7 +92,9 @@ class Peer(val nodeParams: NodeParams, } else { None } - goto(DISCONNECTED) using DisconnectedData(channels, activeChannels = Set.empty, PeerStorage(peerStorageData, written = true)) // when we restart, we will attempt to reconnect right away, but then we'll wait + // When we restart, we will attempt to reconnect right away, but then we'll wait. + // We don't fetch our peer's features from the DB: if the connection succeeds, we will get them from their init message, which saves a DB call. + goto(DISCONNECTED) using DisconnectedData(channels, activeChannels = Set.empty, PeerStorage(peerStorageData, written = true), remoteFeatures_opt = None) } when(DISCONNECTED) { @@ -150,7 +152,14 @@ class Peer(val nodeParams: NodeParams, if (!d.peerStorage.written && !isTimerActive(WritePeerStorageTimerKey)) { startSingleTimer(WritePeerStorageTimerKey, WritePeerStorage, nodeParams.peerStorageConfig.writeDelay) } - stay() using d.copy(activeChannels = d.activeChannels + e.channelId) + val remoteFeatures_opt = d.remoteFeatures_opt match { + case Some(remoteFeatures) if !remoteFeatures.written => + // We have a channel, so we can write to the DB without any DoS risk. + nodeParams.db.peers.addOrUpdatePeer(remoteNodeId, NodeInfo(remoteFeatures.features, None)) + Some(remoteFeatures.copy(written = true)) + case _ => d.remoteFeatures_opt + } + stay() using d.copy(activeChannels = d.activeChannels + e.channelId, remoteFeatures_opt = remoteFeatures_opt) case Event(e: LocalChannelDown, d: DisconnectedData) => stay() using d.copy(activeChannels = d.activeChannels - e.channelId) @@ -447,7 +456,11 @@ class Peer(val nodeParams: NodeParams, if (!d.peerStorage.written && !isTimerActive(WritePeerStorageTimerKey)) { startSingleTimer(WritePeerStorageTimerKey, WritePeerStorage, nodeParams.peerStorageConfig.writeDelay) } - stay() using d.copy(activeChannels = d.activeChannels + e.channelId) + if (!d.remoteFeaturesWritten) { + // We have a channel, so we can write to the DB without any DoS risk. + nodeParams.db.peers.addOrUpdatePeer(remoteNodeId, NodeInfo(d.remoteFeatures, None)) + } + stay() using d.copy(activeChannels = d.activeChannels + e.channelId, remoteFeaturesWritten = true) case Event(e: LocalChannelDown, d: ConnectedData) => stay() using d.copy(activeChannels = d.activeChannels - e.channelId) @@ -492,7 +505,8 @@ class Peer(val nodeParams: NodeParams, stopPeer(d.peerStorage) } else { d.channels.values.toSet[ActorRef].foreach(_ ! INPUT_DISCONNECTED) // we deduplicate with toSet because there might be two entries per channel (tmp id and final id) - goto(DISCONNECTED) using DisconnectedData(d.channels.collect { case (k: FinalChannelId, v) => (k, v) }, d.activeChannels, d.peerStorage) + val lastRemoteFeatures = LastRemoteFeatures(d.remoteFeatures, d.remoteFeaturesWritten) + goto(DISCONNECTED) using DisconnectedData(d.channels.collect { case (k: FinalChannelId, v) => (k, v) }, d.activeChannels, d.peerStorage, Some(lastRemoteFeatures)) } case Event(Terminated(actor), d: ConnectedData) if d.channels.values.toSet.contains(actor) => @@ -587,12 +601,22 @@ class Peer(val nodeParams: NodeParams, case Event(r: GetPeerInfo, d) => val replyTo = r.replyTo.getOrElse(sender().toTyped) - val peerInfo = d match { - case c: ConnectedData => PeerInfo(self, remoteNodeId, stateName, Some(c.remoteFeatures), Some(c.address), c.channels.values.toSet) - case _ => PeerInfo(self, remoteNodeId, stateName, None, None, d.channels.values.toSet) + d match { + case c: ConnectedData => + replyTo ! PeerInfo(self, remoteNodeId, stateName, Some(c.remoteFeatures), Some(c.address), c.channels.values.toSet) + stay() + case d: DisconnectedData => + // If we haven't reconnected since our last restart, we fetch the latest remote features from our DB. + val remoteFeatures_opt = d.remoteFeatures_opt match { + case Some(remoteFeatures) => Some(remoteFeatures) + case None => nodeParams.db.peers.getPeer(remoteNodeId).map(nodeInfo => LastRemoteFeatures(nodeInfo.features, written = true)) + } + replyTo ! PeerInfo(self, remoteNodeId, stateName, remoteFeatures_opt.map(_.features), None, d.channels.values.toSet) + stay() using d.copy(remoteFeatures_opt = remoteFeatures_opt) + case _ => + replyTo ! PeerInfo(self, remoteNodeId, stateName, None, None, d.channels.values.toSet) + stay() } - replyTo ! peerInfo - stay() case Event(r: GetPeerChannels, d) => if (d.channels.isEmpty) { @@ -804,7 +828,13 @@ class Peer(val nodeParams: NodeParams, // We store the node address and features upon successful outgoing connection, so we can reconnect later. // The previous address is overwritten: we don't need it since the current one works. nodeParams.db.peers.addOrUpdatePeer(remoteNodeId, NodeInfo(connectionReady.remoteInit.features, Some(connectionReady.address))) + } else if (channels.nonEmpty) { + // If this is an incoming connection, we only store the peer details in our DB if we have channels with them. + // Otherwise nodes could DoS by simply connecting to us to force us to store data in our DB. + // We don't update the remote address, we don't know if we would successfully connect using the current one. + nodeParams.db.peers.addOrUpdatePeer(remoteNodeId, NodeInfo(connectionReady.remoteInit.features, None)) } + val remoteFeaturesWritten = connectionReady.outgoing || channels.nonEmpty // If we have some data stored from our peer, we send it to them before doing anything else. peerStorage.data.foreach(connectionReady.peerConnection ! PeerStorageRetrieval(_)) @@ -826,7 +856,7 @@ class Peer(val nodeParams: NodeParams, connectionReady.peerConnection ! CurrentFeeCredit(nodeParams.chainHash, feeCredit.getOrElse(0 msat)) } - goto(CONNECTED) using ConnectedData(connectionReady.address, connectionReady.peerConnection, connectionReady.localInit, connectionReady.remoteInit, channels, activeChannels, feerates, None, peerStorage) + goto(CONNECTED) using ConnectedData(connectionReady.address, connectionReady.peerConnection, connectionReady.localInit, connectionReady.remoteInit, channels, activeChannels, feerates, None, peerStorage, remoteFeaturesWritten) } /** @@ -967,6 +997,8 @@ object Peer { case class PeerStorage(data: Option[ByteVector], written: Boolean) + case class LastRemoteFeatures(features: Features[InitFeature], written: Boolean) + sealed trait Data { def channels: Map[_ <: ChannelId, ActorRef] // will be overridden by Map[FinalChannelId, ActorRef] or Map[ChannelId, ActorRef] def activeChannels: Set[ByteVector32] // channels that are available to process payments @@ -977,8 +1009,8 @@ object Peer { override def activeChannels: Set[ByteVector32] = Set.empty override def peerStorage: PeerStorage = PeerStorage(None, written = true) } - case class DisconnectedData(channels: Map[FinalChannelId, ActorRef], activeChannels: Set[ByteVector32], peerStorage: PeerStorage) extends Data - case class ConnectedData(address: NodeAddress, peerConnection: ActorRef, localInit: protocol.Init, remoteInit: protocol.Init, channels: Map[ChannelId, ActorRef], activeChannels: Set[ByteVector32], currentFeerates: RecommendedFeerates, previousFeerates_opt: Option[RecommendedFeerates], peerStorage: PeerStorage) extends Data { + case class DisconnectedData(channels: Map[FinalChannelId, ActorRef], activeChannels: Set[ByteVector32], peerStorage: PeerStorage, remoteFeatures_opt: Option[LastRemoteFeatures]) extends Data + case class ConnectedData(address: NodeAddress, peerConnection: ActorRef, localInit: protocol.Init, remoteInit: protocol.Init, channels: Map[ChannelId, ActorRef], activeChannels: Set[ByteVector32], currentFeerates: RecommendedFeerates, previousFeerates_opt: Option[RecommendedFeerates], peerStorage: PeerStorage, remoteFeaturesWritten: Boolean) extends Data { val connectionInfo: ConnectionInfo = ConnectionInfo(address, peerConnection, localInit, remoteInit) def localFeatures: Features[InitFeature] = localInit.features def remoteFeatures: Features[InitFeature] = remoteInit.features diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala index ac64eb146a..1d9eca3131 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala @@ -755,7 +755,7 @@ class PeerSpec extends FixtureSpec { channel.expectMsg(open) } - test("peer storage") { f => + test("store remote peer storage once we have channels") { f => import f._ // We connect with a previous backup. @@ -768,7 +768,6 @@ class PeerSpec extends FixtureSpec { peerConnection1.send(peer, PeerStorageStore(hex"0123456789")) // We disconnect and reconnect, sending the last backup we received. - peer ! Peer.Disconnect(f.remoteNodeId) val peerConnection2 = TestProbe() connect(remoteNodeId, peer, peerConnection2, switchboard, channels = Set(ChannelCodecsSpec.normal), initializePeer = false, peerStorage = Some(hex"0123456789")) peerConnection2.send(peer, PeerStorageStore(hex"1111")) @@ -788,6 +787,63 @@ class PeerSpec extends FixtureSpec { assert(nodeParams.db.peers.getStorage(remoteNodeId).contains(hex"1111")) } + test("store remote features when channel confirms") { f => + import f._ + + // When we make an outgoing connection, we store the peer details in our DB. + assert(nodeParams.db.peers.getPeer(remoteNodeId).isEmpty) + connect(remoteNodeId, peer, peerConnection, switchboard) + val Some(nodeInfo1) = nodeParams.db.peers.getPeer(remoteNodeId) + assert(nodeInfo1.features == TestConstants.Bob.nodeParams.features.initFeatures()) + assert(nodeInfo1.address_opt.contains(fakeIPAddress)) + + // We disconnect and our peer connects to us: we don't have any channel, so we don't update the DB entry. + val peerConnection2 = TestProbe() + val address2 = Tor3("of7husrflx7sforh3fw6yqlpwstee3wg5imvvmkp4bz6rbjxtg5nljad", 9735) + val remoteFeatures2 = Features(Features.ChannelType -> FeatureSupport.Mandatory).initFeatures() + switchboard.send(peer, PeerConnection.ConnectionReady(peerConnection2.ref, remoteNodeId, address2, outgoing = false, protocol.Init(Features.empty), protocol.Init(remoteFeatures2))) + val probe = TestProbe() + probe.send(peer, Peer.GetPeerInfo(Some(probe.ref.toTyped))) + assert(probe.expectMsgType[Peer.PeerInfo].address.contains(address2)) + assert(nodeParams.db.peers.getPeer(remoteNodeId).contains(nodeInfo1)) + + // A channel is created, so we update the remote features in our DB. + // We don't update the address because this was an incoming connection. + peer ! ChannelReadyForPayments(ActorRef.noSender, remoteNodeId, randomBytes32(), 0) + probe.send(peer, Peer.GetPeerInfo(Some(probe.ref.toTyped))) + assert(probe.expectMsgType[Peer.PeerInfo].features.contains(remoteFeatures2)) + assert(nodeParams.db.peers.getPeer(remoteNodeId).contains(nodeInfo1.copy(features = remoteFeatures2))) + } + + test("store remote features when channel confirms while disconnected") { f => + import f._ + + // When we receive an incoming connection, we don't store the peer details in our DB. + assert(nodeParams.db.peers.getPeer(remoteNodeId).isEmpty) + switchboard.send(peer, Peer.Init(Set.empty, Map.empty)) + val localInit = protocol.Init(peer.underlyingActor.nodeParams.features.initFeatures()) + val remoteInit = protocol.Init(TestConstants.Bob.nodeParams.features.initFeatures()) + switchboard.send(peer, PeerConnection.ConnectionReady(peerConnection.ref, remoteNodeId, fakeIPAddress, outgoing = false, localInit, remoteInit)) + val probe = TestProbe() + probe.send(peer, Peer.GetPeerInfo(Some(probe.ref.toTyped))) + assert(probe.expectMsgType[Peer.PeerInfo].state == Peer.CONNECTED) + assert(nodeParams.db.peers.getPeer(remoteNodeId).isEmpty) + + // Our peer wants to open a channel to us, but we disconnect before we have a confirmed channel. + peer ! SpawnChannelNonInitiator(Left(createOpenChannelMessage()), ChannelConfig.standard, ChannelTypes.Standard(), None, localParams, peerConnection.ref) + peer ! Peer.ConnectionDown(peerConnection.ref) + probe.send(peer, Peer.GetPeerInfo(Some(probe.ref.toTyped))) + assert(probe.expectMsgType[Peer.PeerInfo].state == Peer.DISCONNECTED) + assert(nodeParams.db.peers.getPeer(remoteNodeId).isEmpty) + + // The channel confirms, so we store the remote features in our DB. + // We don't store the remote address because this was an incoming connection. + peer ! ChannelReadyForPayments(ActorRef.noSender, remoteNodeId, randomBytes32(), 0) + probe.send(peer, Peer.GetPeerInfo(Some(probe.ref.toTyped))) + assert(probe.expectMsgType[Peer.PeerInfo].state == Peer.DISCONNECTED) + assert(nodeParams.db.peers.getPeer(remoteNodeId).contains(NodeInfo(remoteInit.features, None))) + } + } object PeerSpec { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/ReconnectionTaskSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/ReconnectionTaskSpec.scala index d4ba5401c2..06ca2b13e9 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/ReconnectionTaskSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/ReconnectionTaskSpec.scala @@ -38,8 +38,8 @@ class ReconnectionTaskSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike private val recommendedFeerates = RecommendedFeerates(Block.RegtestGenesisBlock.hash, TestConstants.feeratePerKw, TestConstants.anchorOutputsFeeratePerKw) private val PeerNothingData = Peer.Nothing - private val PeerDisconnectedData = Peer.DisconnectedData(channels, activeChannels = Set.empty, PeerStorage(None, written = true)) - private val PeerConnectedData = Peer.ConnectedData(fakeIPAddress, system.deadLetters, null, null, channels.map { case (k: ChannelId, v) => (k, v) }, activeChannels = Set.empty, recommendedFeerates, None, PeerStorage(None, written = true)) + private val PeerDisconnectedData = Peer.DisconnectedData(channels, activeChannels = Set.empty, PeerStorage(None, written = true), remoteFeatures_opt = None) + private val PeerConnectedData = Peer.ConnectedData(fakeIPAddress, system.deadLetters, null, null, channels.map { case (k: ChannelId, v) => (k, v) }, activeChannels = Set.empty, recommendedFeerates, None, PeerStorage(None, written = true), remoteFeaturesWritten = true) case class FixtureParam(nodeParams: NodeParams, remoteNodeId: PublicKey, reconnectionTask: TestFSMRef[ReconnectionTask.State, ReconnectionTask.Data, ReconnectionTask], monitor: TestProbe) @@ -82,7 +82,7 @@ class ReconnectionTaskSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike import f._ val peer = TestProbe() - peer.send(reconnectionTask, Peer.Transition(PeerNothingData, Peer.DisconnectedData(Map.empty, activeChannels = Set.empty, PeerStorage(None, written = true)))) + peer.send(reconnectionTask, Peer.Transition(PeerNothingData, Peer.DisconnectedData(Map.empty, activeChannels = Set.empty, PeerStorage(None, written = true), None))) monitor.expectNoMessage() } @@ -205,7 +205,6 @@ class ReconnectionTaskSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike peer.send(reconnectionTask, Peer.Transition(PeerDisconnectedData, PeerConnectedData)) // we cancel the reconnection and go to idle state val TransitionWithData(ReconnectionTask.WAITING, ReconnectionTask.IDLE, _, _) = monitor.expectMsgType[TransitionWithData] - } test("reconnect using the address from node_announcement") { f => @@ -232,15 +231,13 @@ class ReconnectionTaskSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike val tor = NodeAddress.fromParts("iq7zhmhck54vcax2vlrdcavq2m32wao7ekh6jyeglmnuuvv3js57r4id.onion", 9735).get // NB: we don't test randomization here, but it makes tests unnecessary more complex for little value - { // tor not supported: always return clearnet addresses nodeParams.socksProxy_opt returns None - assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet)) == Some(clearnet)) - assert(ReconnectionTask.selectNodeAddress(nodeParams, List(tor)) == None) - assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet, tor)) == Some(clearnet)) + assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet)).contains(clearnet)) + assert(ReconnectionTask.selectNodeAddress(nodeParams, List(tor)).isEmpty) + assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet, tor)).contains(clearnet)) } - { // tor supported but not enabled for clearnet addresses: return clearnet addresses when available val socksParams = mock[Socks5ProxyParams] @@ -248,11 +245,10 @@ class ReconnectionTaskSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike socksParams.useForIPv4 returns false socksParams.useForIPv6 returns false nodeParams.socksProxy_opt returns Some(socksParams) - assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet)) == Some(clearnet)) - assert(ReconnectionTask.selectNodeAddress(nodeParams, List(tor)) == Some(tor)) - assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet, tor)) == Some(clearnet)) + assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet)).contains(clearnet)) + assert(ReconnectionTask.selectNodeAddress(nodeParams, List(tor)).contains(tor)) + assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet, tor)).contains(clearnet)) } - { // tor supported and enabled for clearnet addresses: return tor addresses when available val socksParams = mock[Socks5ProxyParams] @@ -260,11 +256,10 @@ class ReconnectionTaskSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike socksParams.useForIPv4 returns true socksParams.useForIPv6 returns true nodeParams.socksProxy_opt returns Some(socksParams) - assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet)) == Some(clearnet)) - assert(ReconnectionTask.selectNodeAddress(nodeParams, List(tor)) == Some(tor)) - assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet, tor)) == Some(tor)) + assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet)).contains(clearnet)) + assert(ReconnectionTask.selectNodeAddress(nodeParams, List(tor)).contains(tor)) + assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet, tor)).contains(tor)) } - } }