Skip to content

Commit

Permalink
Store incoming peers with channels in PeersDb
Browse files Browse the repository at this point in the history
Once we have a channel with a peer that connected to us, we store their
details in our DB. We don't store the address they're connecting from,
because we don't know if we will be able to connect to them using this
address, but we store their features.
  • Loading branch information
t-bast committed Jan 10, 2025
1 parent d0a0589 commit 2092ead
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 31 deletions.
56 changes: 44 additions & 12 deletions eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ class Peer(val nodeParams: NodeParams,
} else {
None
}
goto(DISCONNECTED) using DisconnectedData(channels, activeChannels = Set.empty, PeerStorage(peerStorageData, written = true)) // when we restart, we will attempt to reconnect right away, but then we'll wait
// When we restart, we will attempt to reconnect right away, but then we'll wait.
// We don't fetch our peer's features from the DB: if the connection succeeds, we will get them from their init message, which saves a DB call.
goto(DISCONNECTED) using DisconnectedData(channels, activeChannels = Set.empty, PeerStorage(peerStorageData, written = true), remoteFeatures_opt = None)
}

when(DISCONNECTED) {
Expand Down Expand Up @@ -150,7 +152,14 @@ class Peer(val nodeParams: NodeParams,
if (!d.peerStorage.written && !isTimerActive(WritePeerStorageTimerKey)) {
startSingleTimer(WritePeerStorageTimerKey, WritePeerStorage, nodeParams.peerStorageConfig.writeDelay)
}
stay() using d.copy(activeChannels = d.activeChannels + e.channelId)
val remoteFeatures_opt = d.remoteFeatures_opt match {
case Some(remoteFeatures) if !remoteFeatures.written =>
// We have a channel, so we can write to the DB without any DoS risk.
nodeParams.db.peers.addOrUpdatePeer(remoteNodeId, NodeInfo(remoteFeatures.features, None))
Some(remoteFeatures.copy(written = true))
case _ => d.remoteFeatures_opt
}
stay() using d.copy(activeChannels = d.activeChannels + e.channelId, remoteFeatures_opt = remoteFeatures_opt)

case Event(e: LocalChannelDown, d: DisconnectedData) =>
stay() using d.copy(activeChannels = d.activeChannels - e.channelId)
Expand Down Expand Up @@ -447,7 +456,11 @@ class Peer(val nodeParams: NodeParams,
if (!d.peerStorage.written && !isTimerActive(WritePeerStorageTimerKey)) {
startSingleTimer(WritePeerStorageTimerKey, WritePeerStorage, nodeParams.peerStorageConfig.writeDelay)
}
stay() using d.copy(activeChannels = d.activeChannels + e.channelId)
if (!d.remoteFeaturesWritten) {
// We have a channel, so we can write to the DB without any DoS risk.
nodeParams.db.peers.addOrUpdatePeer(remoteNodeId, NodeInfo(d.remoteFeatures, None))
}
stay() using d.copy(activeChannels = d.activeChannels + e.channelId, remoteFeaturesWritten = true)

case Event(e: LocalChannelDown, d: ConnectedData) =>
stay() using d.copy(activeChannels = d.activeChannels - e.channelId)
Expand Down Expand Up @@ -492,7 +505,8 @@ class Peer(val nodeParams: NodeParams,
stopPeer(d.peerStorage)
} else {
d.channels.values.toSet[ActorRef].foreach(_ ! INPUT_DISCONNECTED) // we deduplicate with toSet because there might be two entries per channel (tmp id and final id)
goto(DISCONNECTED) using DisconnectedData(d.channels.collect { case (k: FinalChannelId, v) => (k, v) }, d.activeChannels, d.peerStorage)
val lastRemoteFeatures = LastRemoteFeatures(d.remoteFeatures, d.remoteFeaturesWritten)
goto(DISCONNECTED) using DisconnectedData(d.channels.collect { case (k: FinalChannelId, v) => (k, v) }, d.activeChannels, d.peerStorage, Some(lastRemoteFeatures))
}

case Event(Terminated(actor), d: ConnectedData) if d.channels.values.toSet.contains(actor) =>
Expand Down Expand Up @@ -587,12 +601,22 @@ class Peer(val nodeParams: NodeParams,

case Event(r: GetPeerInfo, d) =>
val replyTo = r.replyTo.getOrElse(sender().toTyped)
val peerInfo = d match {
case c: ConnectedData => PeerInfo(self, remoteNodeId, stateName, Some(c.remoteFeatures), Some(c.address), c.channels.values.toSet)
case _ => PeerInfo(self, remoteNodeId, stateName, None, None, d.channels.values.toSet)
d match {
case c: ConnectedData =>
replyTo ! PeerInfo(self, remoteNodeId, stateName, Some(c.remoteFeatures), Some(c.address), c.channels.values.toSet)
stay()
case d: DisconnectedData =>
// If we haven't reconnected since our last restart, we fetch the latest remote features from our DB.
val remoteFeatures_opt = d.remoteFeatures_opt match {
case Some(remoteFeatures) => Some(remoteFeatures)
case None => nodeParams.db.peers.getPeer(remoteNodeId).map(nodeInfo => LastRemoteFeatures(nodeInfo.features, written = true))
}
replyTo ! PeerInfo(self, remoteNodeId, stateName, remoteFeatures_opt.map(_.features), None, d.channels.values.toSet)
stay() using d.copy(remoteFeatures_opt = remoteFeatures_opt)
case _ =>
replyTo ! PeerInfo(self, remoteNodeId, stateName, None, None, d.channels.values.toSet)
stay()
}
replyTo ! peerInfo
stay()

case Event(r: GetPeerChannels, d) =>
if (d.channels.isEmpty) {
Expand Down Expand Up @@ -804,7 +828,13 @@ class Peer(val nodeParams: NodeParams,
// We store the node address and features upon successful outgoing connection, so we can reconnect later.
// The previous address is overwritten: we don't need it since the current one works.
nodeParams.db.peers.addOrUpdatePeer(remoteNodeId, NodeInfo(connectionReady.remoteInit.features, Some(connectionReady.address)))
} else if (channels.nonEmpty) {
// If this is an incoming connection, we only store the peer details in our DB if we have channels with them.
// Otherwise nodes could DoS by simply connecting to us to force us to store data in our DB.
// We don't update the remote address, we don't know if we would successfully connect using the current one.
nodeParams.db.peers.addOrUpdatePeer(remoteNodeId, NodeInfo(connectionReady.remoteInit.features, None))
}
val remoteFeaturesWritten = connectionReady.outgoing || channels.nonEmpty

// If we have some data stored from our peer, we send it to them before doing anything else.
peerStorage.data.foreach(connectionReady.peerConnection ! PeerStorageRetrieval(_))
Expand All @@ -826,7 +856,7 @@ class Peer(val nodeParams: NodeParams,
connectionReady.peerConnection ! CurrentFeeCredit(nodeParams.chainHash, feeCredit.getOrElse(0 msat))
}

goto(CONNECTED) using ConnectedData(connectionReady.address, connectionReady.peerConnection, connectionReady.localInit, connectionReady.remoteInit, channels, activeChannels, feerates, None, peerStorage)
goto(CONNECTED) using ConnectedData(connectionReady.address, connectionReady.peerConnection, connectionReady.localInit, connectionReady.remoteInit, channels, activeChannels, feerates, None, peerStorage, remoteFeaturesWritten)
}

/**
Expand Down Expand Up @@ -967,6 +997,8 @@ object Peer {

case class PeerStorage(data: Option[ByteVector], written: Boolean)

case class LastRemoteFeatures(features: Features[InitFeature], written: Boolean)

sealed trait Data {
def channels: Map[_ <: ChannelId, ActorRef] // will be overridden by Map[FinalChannelId, ActorRef] or Map[ChannelId, ActorRef]
def activeChannels: Set[ByteVector32] // channels that are available to process payments
Expand All @@ -977,8 +1009,8 @@ object Peer {
override def activeChannels: Set[ByteVector32] = Set.empty
override def peerStorage: PeerStorage = PeerStorage(None, written = true)
}
case class DisconnectedData(channels: Map[FinalChannelId, ActorRef], activeChannels: Set[ByteVector32], peerStorage: PeerStorage) extends Data
case class ConnectedData(address: NodeAddress, peerConnection: ActorRef, localInit: protocol.Init, remoteInit: protocol.Init, channels: Map[ChannelId, ActorRef], activeChannels: Set[ByteVector32], currentFeerates: RecommendedFeerates, previousFeerates_opt: Option[RecommendedFeerates], peerStorage: PeerStorage) extends Data {
case class DisconnectedData(channels: Map[FinalChannelId, ActorRef], activeChannels: Set[ByteVector32], peerStorage: PeerStorage, remoteFeatures_opt: Option[LastRemoteFeatures]) extends Data
case class ConnectedData(address: NodeAddress, peerConnection: ActorRef, localInit: protocol.Init, remoteInit: protocol.Init, channels: Map[ChannelId, ActorRef], activeChannels: Set[ByteVector32], currentFeerates: RecommendedFeerates, previousFeerates_opt: Option[RecommendedFeerates], peerStorage: PeerStorage, remoteFeaturesWritten: Boolean) extends Data {
val connectionInfo: ConnectionInfo = ConnectionInfo(address, peerConnection, localInit, remoteInit)
def localFeatures: Features[InitFeature] = localInit.features
def remoteFeatures: Features[InitFeature] = remoteInit.features
Expand Down
60 changes: 58 additions & 2 deletions eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,7 @@ class PeerSpec extends FixtureSpec {
channel.expectMsg(open)
}

test("peer storage") { f =>
test("store remote peer storage once we have channels") { f =>
import f._

// We connect with a previous backup.
Expand All @@ -768,7 +768,6 @@ class PeerSpec extends FixtureSpec {
peerConnection1.send(peer, PeerStorageStore(hex"0123456789"))

// We disconnect and reconnect, sending the last backup we received.
peer ! Peer.Disconnect(f.remoteNodeId)
val peerConnection2 = TestProbe()
connect(remoteNodeId, peer, peerConnection2, switchboard, channels = Set(ChannelCodecsSpec.normal), initializePeer = false, peerStorage = Some(hex"0123456789"))
peerConnection2.send(peer, PeerStorageStore(hex"1111"))
Expand All @@ -788,6 +787,63 @@ class PeerSpec extends FixtureSpec {
assert(nodeParams.db.peers.getStorage(remoteNodeId).contains(hex"1111"))
}

test("store remote features when channel confirms") { f =>
import f._

// When we make an outgoing connection, we store the peer details in our DB.
assert(nodeParams.db.peers.getPeer(remoteNodeId).isEmpty)
connect(remoteNodeId, peer, peerConnection, switchboard)
val Some(nodeInfo1) = nodeParams.db.peers.getPeer(remoteNodeId)
assert(nodeInfo1.features == TestConstants.Bob.nodeParams.features.initFeatures())
assert(nodeInfo1.address_opt.contains(fakeIPAddress))

// We disconnect and our peer connects to us: we don't have any channel, so we don't update the DB entry.
val peerConnection2 = TestProbe()
val address2 = Tor3("of7husrflx7sforh3fw6yqlpwstee3wg5imvvmkp4bz6rbjxtg5nljad", 9735)
val remoteFeatures2 = Features(Features.ChannelType -> FeatureSupport.Mandatory).initFeatures()
switchboard.send(peer, PeerConnection.ConnectionReady(peerConnection2.ref, remoteNodeId, address2, outgoing = false, protocol.Init(Features.empty), protocol.Init(remoteFeatures2)))
val probe = TestProbe()
probe.send(peer, Peer.GetPeerInfo(Some(probe.ref.toTyped)))
assert(probe.expectMsgType[Peer.PeerInfo].address.contains(address2))
assert(nodeParams.db.peers.getPeer(remoteNodeId).contains(nodeInfo1))

// A channel is created, so we update the remote features in our DB.
// We don't update the address because this was an incoming connection.
peer ! ChannelReadyForPayments(ActorRef.noSender, remoteNodeId, randomBytes32(), 0)
probe.send(peer, Peer.GetPeerInfo(Some(probe.ref.toTyped)))
assert(probe.expectMsgType[Peer.PeerInfo].features.contains(remoteFeatures2))
assert(nodeParams.db.peers.getPeer(remoteNodeId).contains(nodeInfo1.copy(features = remoteFeatures2)))
}

test("store remote features when channel confirms while disconnected") { f =>
import f._

// When we receive an incoming connection, we don't store the peer details in our DB.
assert(nodeParams.db.peers.getPeer(remoteNodeId).isEmpty)
switchboard.send(peer, Peer.Init(Set.empty, Map.empty))
val localInit = protocol.Init(peer.underlyingActor.nodeParams.features.initFeatures())
val remoteInit = protocol.Init(TestConstants.Bob.nodeParams.features.initFeatures())
switchboard.send(peer, PeerConnection.ConnectionReady(peerConnection.ref, remoteNodeId, fakeIPAddress, outgoing = false, localInit, remoteInit))
val probe = TestProbe()
probe.send(peer, Peer.GetPeerInfo(Some(probe.ref.toTyped)))
assert(probe.expectMsgType[Peer.PeerInfo].state == Peer.CONNECTED)
assert(nodeParams.db.peers.getPeer(remoteNodeId).isEmpty)

// Our peer wants to open a channel to us, but we disconnect before we have a confirmed channel.
peer ! SpawnChannelNonInitiator(Left(createOpenChannelMessage()), ChannelConfig.standard, ChannelTypes.Standard(), None, localParams, peerConnection.ref)
peer ! Peer.ConnectionDown(peerConnection.ref)
probe.send(peer, Peer.GetPeerInfo(Some(probe.ref.toTyped)))
assert(probe.expectMsgType[Peer.PeerInfo].state == Peer.DISCONNECTED)
assert(nodeParams.db.peers.getPeer(remoteNodeId).isEmpty)

// The channel confirms, so we store the remote features in our DB.
// We don't store the remote address because this was an incoming connection.
peer ! ChannelReadyForPayments(ActorRef.noSender, remoteNodeId, randomBytes32(), 0)
probe.send(peer, Peer.GetPeerInfo(Some(probe.ref.toTyped)))
assert(probe.expectMsgType[Peer.PeerInfo].state == Peer.DISCONNECTED)
assert(nodeParams.db.peers.getPeer(remoteNodeId).contains(NodeInfo(remoteInit.features, None)))
}

}

object PeerSpec {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ class ReconnectionTaskSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike
private val recommendedFeerates = RecommendedFeerates(Block.RegtestGenesisBlock.hash, TestConstants.feeratePerKw, TestConstants.anchorOutputsFeeratePerKw)

private val PeerNothingData = Peer.Nothing
private val PeerDisconnectedData = Peer.DisconnectedData(channels, activeChannels = Set.empty, PeerStorage(None, written = true))
private val PeerConnectedData = Peer.ConnectedData(fakeIPAddress, system.deadLetters, null, null, channels.map { case (k: ChannelId, v) => (k, v) }, activeChannels = Set.empty, recommendedFeerates, None, PeerStorage(None, written = true))
private val PeerDisconnectedData = Peer.DisconnectedData(channels, activeChannels = Set.empty, PeerStorage(None, written = true), remoteFeatures_opt = None)
private val PeerConnectedData = Peer.ConnectedData(fakeIPAddress, system.deadLetters, null, null, channels.map { case (k: ChannelId, v) => (k, v) }, activeChannels = Set.empty, recommendedFeerates, None, PeerStorage(None, written = true), remoteFeaturesWritten = true)

case class FixtureParam(nodeParams: NodeParams, remoteNodeId: PublicKey, reconnectionTask: TestFSMRef[ReconnectionTask.State, ReconnectionTask.Data, ReconnectionTask], monitor: TestProbe)

Expand Down Expand Up @@ -82,7 +82,7 @@ class ReconnectionTaskSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike
import f._

val peer = TestProbe()
peer.send(reconnectionTask, Peer.Transition(PeerNothingData, Peer.DisconnectedData(Map.empty, activeChannels = Set.empty, PeerStorage(None, written = true))))
peer.send(reconnectionTask, Peer.Transition(PeerNothingData, Peer.DisconnectedData(Map.empty, activeChannels = Set.empty, PeerStorage(None, written = true), None)))
monitor.expectNoMessage()
}

Expand Down Expand Up @@ -205,7 +205,6 @@ class ReconnectionTaskSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike
peer.send(reconnectionTask, Peer.Transition(PeerDisconnectedData, PeerConnectedData))
// we cancel the reconnection and go to idle state
val TransitionWithData(ReconnectionTask.WAITING, ReconnectionTask.IDLE, _, _) = monitor.expectMsgType[TransitionWithData]

}

test("reconnect using the address from node_announcement") { f =>
Expand All @@ -232,39 +231,35 @@ class ReconnectionTaskSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike
val tor = NodeAddress.fromParts("iq7zhmhck54vcax2vlrdcavq2m32wao7ekh6jyeglmnuuvv3js57r4id.onion", 9735).get

// NB: we don't test randomization here, but it makes tests unnecessary more complex for little value

{
// tor not supported: always return clearnet addresses
nodeParams.socksProxy_opt returns None
assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet)) == Some(clearnet))
assert(ReconnectionTask.selectNodeAddress(nodeParams, List(tor)) == None)
assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet, tor)) == Some(clearnet))
assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet)).contains(clearnet))
assert(ReconnectionTask.selectNodeAddress(nodeParams, List(tor)).isEmpty)
assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet, tor)).contains(clearnet))
}

{
// tor supported but not enabled for clearnet addresses: return clearnet addresses when available
val socksParams = mock[Socks5ProxyParams]
socksParams.useForTor returns true
socksParams.useForIPv4 returns false
socksParams.useForIPv6 returns false
nodeParams.socksProxy_opt returns Some(socksParams)
assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet)) == Some(clearnet))
assert(ReconnectionTask.selectNodeAddress(nodeParams, List(tor)) == Some(tor))
assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet, tor)) == Some(clearnet))
assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet)).contains(clearnet))
assert(ReconnectionTask.selectNodeAddress(nodeParams, List(tor)).contains(tor))
assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet, tor)).contains(clearnet))
}

{
// tor supported and enabled for clearnet addresses: return tor addresses when available
val socksParams = mock[Socks5ProxyParams]
socksParams.useForTor returns true
socksParams.useForIPv4 returns true
socksParams.useForIPv6 returns true
nodeParams.socksProxy_opt returns Some(socksParams)
assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet)) == Some(clearnet))
assert(ReconnectionTask.selectNodeAddress(nodeParams, List(tor)) == Some(tor))
assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet, tor)) == Some(tor))
assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet)).contains(clearnet))
assert(ReconnectionTask.selectNodeAddress(nodeParams, List(tor)).contains(tor))
assert(ReconnectionTask.selectNodeAddress(nodeParams, List(clearnet, tor)).contains(tor))
}

}

}

0 comments on commit 2092ead

Please sign in to comment.