diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/AbstractRouter.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/AbstractRouter.kt index a72b93cc..66f7e4ab 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/AbstractRouter.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/AbstractRouter.kt @@ -185,7 +185,9 @@ abstract class AbstractRouter( processControl(msg.control, peer) } - if (protocol.supportsExtensions()) { + // TODO we need to handle the existence of extension messages more generically (https://github.com/libp2p/jvm-libp2p/issues/441) + + if (protocol.supportsExtensions() && (msg.hasTestExtension() || msg.hasPartial())) { processExtensions(msg, peer) } diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/Gossip.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/Gossip.kt index 56b16268..39100f10 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/Gossip.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/Gossip.kt @@ -12,6 +12,7 @@ import io.libp2p.pubsub.PubsubApiImpl import io.libp2p.pubsub.PubsubProtocol import io.libp2p.pubsub.gossip.builders.GossipRouterBuilder import io.netty.channel.ChannelHandler +import org.slf4j.LoggerFactory import java.util.concurrent.CompletableFuture class Gossip @JvmOverloads constructor( @@ -21,6 +22,8 @@ class Gossip @JvmOverloads constructor( ) : ProtocolBinding, ConnectionHandler, PubsubApi by api { + private val logger = LoggerFactory.getLogger(Gossip::class.java) + fun updateTopicScoreParams(scoreParams: Map) { router.score.updateTopicParams(scoreParams) } @@ -62,6 +65,7 @@ class Gossip @JvmOverloads constructor( } override fun initChannel(ch: P2PChannel, selectedProtocol: String): CompletableFuture { + logger.trace("Gossip initChannel - selected protocol: {}", selectedProtocol) router.addPeerWithDebugHandler(ch as Stream, debugGossipHandler) return CompletableFuture.completedFuture(Unit) } diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipExtensionsState.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipExtensionsState.kt new file mode 100644 index 00000000..24daf6ad --- /dev/null +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipExtensionsState.kt @@ -0,0 +1,38 @@ +package io.libp2p.pubsub.gossip + +import io.libp2p.core.PeerId +import pubsub.pb.Rpc + +class GossipExtensionsState { + + /* + Tracks the peers that we have already sent a control extensions message + */ + private val outgoingControlExtensionsMsgPeers: MutableSet = mutableSetOf() + + /* + Tracks peers that already sent us a control extensions message + */ + private val peerExtensionSupportMap: MutableMap = mutableMapOf() + + fun onPeerDisconnected(peer: PeerId) { + outgoingControlExtensionsMsgPeers.remove(peer) + peerExtensionSupportMap.remove(peer) + } + + fun onControlExtensionsMessage(ctrlExtensions: Rpc.ControlExtensions, receivedFrom: PeerId) { + peerExtensionSupportMap[receivedFrom] = ctrlExtensions + } + + fun registerControlExtensionMessageSentToPeers(peerId: PeerId) { + outgoingControlExtensionsMsgPeers.add(peerId) + } + + fun peerSupportedExtensions(peerId: PeerId) = peerExtensionSupportMap[peerId] + + fun hasReceivedControlExtensionsFrom(peer: PeerId) = + peerExtensionSupportMap.contains(peer) + + fun hasSentControlExtensionsTo(peer: PeerId) = + outgoingControlExtensionsMsgPeers.contains(peer) +} diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt index 535d6acf..1bb46343 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt @@ -132,7 +132,7 @@ open class GossipRouter( private val acceptRequestsWhitelist = mutableMapOf() override val pendingRpcParts = PendingRpcPartsMap { DefaultGossipRpcPartsQueue(params) } - private val peerExtensionSupportMap = mutableMapOf() + val gossipExtensionsState = GossipExtensionsState() private fun setBackOff(peer: PeerHandler, topic: Topic) = setBackOff(peer, topic, params.pruneBackoff.toMillis()) private fun setBackOff(peer: PeerHandler, topic: Topic, delay: Long) { @@ -159,6 +159,7 @@ open class GossipRouter( fanout.values.forEach { it.remove(peer) } acceptRequestsWhitelist -= peer pendingRpcParts.popQueue(peer) // discard them + gossipExtensionsState.onPeerDisconnected(peer.peerId) super.onPeerDisconnected(peer) } @@ -166,6 +167,7 @@ open class GossipRouter( super.onPeerActive(peer) eventBroadcaster.notifyConnected(peer.peerId, peer.getRemoteAddress()) heartbeatTask.hashCode() // force lazy initialization + sendControlExtensions(peer) } override fun notifyUnseenMessage(peer: PeerHandler, msg: PubsubMessage) { @@ -398,34 +400,56 @@ open class GossipRouter( ) { logger.trace("Received control extension {}", ctrlExtensions.toString()) - if (peerExtensionSupportMap[receivedFrom.peerId] != null) { - // TODO Should downscore peers that send control extension multiple times? (https://github.com/libp2p/jvm-libp2p/issues/437) + if (gossipExtensionsState.hasReceivedControlExtensionsFrom(receivedFrom.peerId)) { + // TODO Should disconnect peers that send control extension multiple times (https://github.com/libp2p/jvm-libp2p/issues/437) logger.trace( "Received another control extension message from peer {}", receivedFrom.peerId ) return } else { - peerExtensionSupportMap[receivedFrom.peerId] = ctrlExtensions + gossipExtensionsState.onControlExtensionsMessage(ctrlExtensions, receivedFrom.peerId) } } override fun processExtensions(msg: Rpc.RPC, receivedFrom: PeerHandler) { - val peerSupportedExtensions = peerExtensionSupportMap[receivedFrom.peerId] - if (peerSupportedExtensions == null) { + val peerSupportedExtensions = + gossipExtensionsState.peerSupportedExtensions(receivedFrom.peerId) + + // TODO Revisit this logic as part of adding feature flags (https://github.com/libp2p/jvm-libp2p/issues/441) + + when { + msg.hasTestExtension() && checkPeerExtensionSupport( + peerSupportedExtensions, + Rpc.ControlExtensions::hasTestExtension + ) -> + processTestExtensionMessage(msg.testExtension, receivedFrom) + + msg.hasPartial() && checkPeerExtensionSupport( + peerSupportedExtensions, + Rpc.ControlExtensions::hasPartialMessages + ) -> + processPartialMessageExtension(msg.partial, receivedFrom) + } + } + + private fun checkPeerExtensionSupport( + peerSavedPreferences: Rpc.ControlExtensions?, + checkSupportFunction: (Rpc.ControlExtensions) -> Boolean + ): Boolean { + if (peerSavedPreferences == null) { + return false + } + + if (!checkSupportFunction.invoke(peerSavedPreferences)) { logger.trace( - "Ignoring extension messages from peer {} - did it send an extension control message?", - receivedFrom.peerId + "Ignoring extension messages from peer {} - did it send an control extensions message?", + peerSavedPreferences ) - } else { - when { - peerSupportedExtensions.hasTestExtension() && msg.hasTestExtension() -> - processTestExtensionMessage(msg.testExtension, receivedFrom) - - peerSupportedExtensions.hasPartialMessages() && msg.hasPartial() -> - processPartialMessageExtension(msg.partial, receivedFrom) - } + return false } + + return true } private fun processTestExtensionMessage( @@ -578,6 +602,8 @@ open class GossipRouter( fanout -= topic lastPublished -= topic } + + activePeers.forEach { sendControlExtensions(it) } } override fun unsubscribe(topic: Topic) { @@ -778,6 +804,33 @@ open class GossipRouter( send(peer, iDontWant) } + private fun sendControlExtensions(peer: PeerHandler) { + if (!this.protocol.supportsExtensions()) { + logger.trace( + "Protocol does not support extensions. Won't send control extensions message." + ) + return + } + + if (gossipExtensionsState.hasSentControlExtensionsTo(peer.peerId)) { + logger.trace( + "Already sent control extensions msg to peer {}. Won't send another one.", + peer.peerId + ) + return + } + + logger.trace("Sending control extensions message to peer {}", peer.peerId) + + pendingRpcParts.getQueue(peer).addControlExtensions( + Rpc.ControlExtensions.newBuilder() + .setTestExtension(true) + .setPartialMessages(true) + .build() + ) + gossipExtensionsState.registerControlExtensionMessageSentToPeers(peer.peerId) + } + data class AcceptRequestsWhitelistEntry(val whitelistedTill: Long, val messagesAccepted: Int = 0) { fun incrementMessageCount() = AcceptRequestsWhitelistEntry(whitelistedTill, messagesAccepted + 1) } diff --git a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueue.kt b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueue.kt index e9033258..32e5c908 100644 --- a/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueue.kt +++ b/libp2p/src/main/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueue.kt @@ -26,6 +26,9 @@ interface GossipRpcPartsQueue : RpcPartsQueue { * Gossip 1.1 variant */ fun addPrune(topic: Topic, backoffSeconds: Long, backoffPeers: List) + + // TODO Need to check if we should handle when control extension and extension messages could be separated by split (https://github.com/libp2p/jvm-libp2p/issues/440) + fun addControlExtensions(ctrlMessage: Rpc.ControlExtensions) } /** @@ -81,6 +84,12 @@ open class DefaultGossipRpcPartsQueue( } } + protected data class ControlExtensionPart(val ctrlExtension: Rpc.ControlExtensions) : AbstractPart { + override fun appendToBuilder(builder: Rpc.RPC.Builder) { + builder.controlBuilder.setExtensions(ctrlExtension) + } + } + override fun addIHave(messageId: MessageId, topic: Topic) { addPart(IHavePart(messageId, topic)) } @@ -101,6 +110,10 @@ open class DefaultGossipRpcPartsQueue( addPart(PrunePart(topic, backoffSeconds, backoffPeers)) } + override fun addControlExtensions(ctrlMessage: Rpc.ControlExtensions) { + addPart(ControlExtensionPart(ctrlMessage)) + } + override fun takeMerged(): List { val ret = mutableListOf() var partIdx = 0 diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipExtensionsStateTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipExtensionsStateTest.kt new file mode 100644 index 00000000..315c8dec --- /dev/null +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipExtensionsStateTest.kt @@ -0,0 +1,358 @@ +package io.libp2p.pubsub.gossip + +import io.libp2p.core.PeerId +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.Test +import pubsub.pb.Rpc + +class GossipExtensionsStateTest { + + private lateinit var extensionsState: GossipExtensionsState + private lateinit var peer1: PeerId + private lateinit var peer2: PeerId + private lateinit var peer3: PeerId + + @BeforeEach + fun setup() { + extensionsState = GossipExtensionsState() + peer1 = PeerId.random() + peer2 = PeerId.random() + peer3 = PeerId.random() + } + + @Test + fun `onControlExtensionsMessage() stores peer extensions support`() { + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .setTestExtension(false) + .build() + + extensionsState.onControlExtensionsMessage(extension, peer1) + + val stored = extensionsState.peerSupportedExtensions(peer1) + assertThat(stored).isNotNull + assertThat(stored!!.partialMessages).isTrue() + assertThat(stored.testExtension).isFalse() + } + + @Test + fun `hasReceivedControlExtensionsFrom() returns true after receiving extensions`() { + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isFalse() + + extensionsState.onControlExtensionsMessage(extension, peer1) + + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isTrue() + } + + @Test + fun `hasReceivedControlExtensionsFrom() returns false for unknown peer`() { + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isFalse() + } + + @Test + fun `peerSupportedExtensions() returns null for unknown peer`() { + val extensions = extensionsState.peerSupportedExtensions(peer1) + assertThat(extensions).isNull() + } + + /* + In practice, we should not receive more than one control message from the same peer on + the same connection, but if this ever happens, it makes sense to override the in-memory + config given it most likely has the most up-to-date data for that particular peer + */ + @Test + fun `onControlExtensionsMessage() overwrites previous extensions from same peer`() { + val extension1 = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .setTestExtension(false) + .build() + + val extension2 = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(false) + .setTestExtension(true) + .build() + + extensionsState.onControlExtensionsMessage(extension1, peer1) + extensionsState.onControlExtensionsMessage(extension2, peer1) + + val stored = extensionsState.peerSupportedExtensions(peer1) + assertThat(stored).isNotNull + assertThat(stored!!.partialMessages).isFalse() + assertThat(stored.testExtension).isTrue() + } + + @Test + fun `hasSentControlExtensionsTo() returns false for unknown peer`() { + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isFalse() + } + + @Test + fun `registerControlExtensionMessageSentToPeers() registers peer`() { + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isFalse() + + extensionsState.registerControlExtensionMessageSentToPeers(peer1) + + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isTrue() + } + + @Test + fun `hasSentControlExtensionsTo() returns true after registration`() { + extensionsState.registerControlExtensionMessageSentToPeers(peer1) + + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isTrue() + } + + @Test + fun `registerControlExtensionMessageSentToPeers() can register multiple peers`() { + extensionsState.registerControlExtensionMessageSentToPeers(peer1) + extensionsState.registerControlExtensionMessageSentToPeers(peer2) + + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isTrue() + assertThat(extensionsState.hasSentControlExtensionsTo(peer2)).isTrue() + assertThat(extensionsState.hasSentControlExtensionsTo(peer3)).isFalse() + } + + @Test + fun `sent and received extension tracking are independent`() { + // Register that we sent to peer1 + extensionsState.registerControlExtensionMessageSentToPeers(peer1) + + // Receive from peer2 + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + extensionsState.onControlExtensionsMessage(extension, peer2) + + // Verify sent tracking + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isTrue() + assertThat(extensionsState.hasSentControlExtensionsTo(peer2)).isFalse() + + // Verify received tracking + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isFalse() + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer2)).isTrue() + } + + @Test + fun `peer can be in both sent and received tracking`() { + // Register that we sent to peer1 + extensionsState.registerControlExtensionMessageSentToPeers(peer1) + + // Receive from peer1 + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + extensionsState.onControlExtensionsMessage(extension, peer1) + + // Both should be tracked + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isTrue() + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isTrue() + } + + @Test + fun `tracks multiple peers with different extensions`() { + val extension1 = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .setTestExtension(false) + .build() + + val extension2 = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(false) + .setTestExtension(true) + .build() + + val extension3 = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .setTestExtension(true) + .build() + + extensionsState.onControlExtensionsMessage(extension1, peer1) + extensionsState.onControlExtensionsMessage(extension2, peer2) + extensionsState.onControlExtensionsMessage(extension3, peer3) + + // Verify each peer has correct extensions + val stored1 = extensionsState.peerSupportedExtensions(peer1) + assertThat(stored1!!.partialMessages).isTrue() + assertThat(stored1.testExtension).isFalse() + + val stored2 = extensionsState.peerSupportedExtensions(peer2) + assertThat(stored2!!.partialMessages).isFalse() + assertThat(stored2.testExtension).isTrue() + + val stored3 = extensionsState.peerSupportedExtensions(peer3) + assertThat(stored3!!.partialMessages).isTrue() + assertThat(stored3.testExtension).isTrue() + + // Verify all peers are tracked + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isTrue() + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer2)).isTrue() + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer3)).isTrue() + } + + @Test + fun `tracks many peers simultaneously`() { + val peers = (1..10).map { PeerId.random() } + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + + peers.forEach { peer -> + extensionsState.onControlExtensionsMessage(extension, peer) + } + + // Verify all peers are tracked + peers.forEach { peer -> + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer)).isTrue() + assertThat(extensionsState.peerSupportedExtensions(peer)).isNotNull + } + } + + @Test + fun `onPeerDisconnected() removes peer from received extensions map`() { + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + + extensionsState.onControlExtensionsMessage(extension, peer1) + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isTrue() + + extensionsState.onPeerDisconnected(peer1) + + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isFalse() + assertThat(extensionsState.peerSupportedExtensions(peer1)).isNull() + } + + @Test + fun `onPeerDisconnected() handles unknown peer gracefully`() { + // Should not throw exception for unknown peer + extensionsState.onPeerDisconnected(peer1) + + // State should remain empty + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isFalse() + assertThat(extensionsState.peerSupportedExtensions(peer1)).isNull() + } + + @Test + fun `onPeerDisconnected() only removes specified peer`() { + val extension1 = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + + val extension2 = Rpc.ControlExtensions.newBuilder() + .setTestExtension(true) + .build() + + extensionsState.onControlExtensionsMessage(extension1, peer1) + extensionsState.onControlExtensionsMessage(extension2, peer2) + + // Disconnect peer1 + extensionsState.onPeerDisconnected(peer1) + + // peer1 should be removed + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isFalse() + assertThat(extensionsState.peerSupportedExtensions(peer1)).isNull() + + // peer2 should remain + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer2)).isTrue() + assertThat(extensionsState.peerSupportedExtensions(peer2)).isNotNull + assertThat(extensionsState.peerSupportedExtensions(peer2)!!.testExtension).isTrue() + } + + @Test + fun `multiple disconnects and reconnects work correctly`() { + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + + // Connect + extensionsState.onControlExtensionsMessage(extension, peer1) + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isTrue() + + // Disconnect + extensionsState.onPeerDisconnected(peer1) + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isFalse() + + // Reconnect with different extensions + val newExtension = Rpc.ControlExtensions.newBuilder() + .setTestExtension(true) + .build() + extensionsState.onControlExtensionsMessage(newExtension, peer1) + + val stored = extensionsState.peerSupportedExtensions(peer1) + assertThat(stored).isNotNull + assertThat(stored!!.hasPartialMessages()).isFalse() + assertThat(stored.testExtension).isTrue() + } + + @Test + fun `onPeerDisconnected() removes peer from sent extensions list`() { + extensionsState.registerControlExtensionMessageSentToPeers(peer1) + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isTrue() + + extensionsState.onPeerDisconnected(peer1) + + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isFalse() + } + + @Test + fun `onPeerDisconnected() removes peer from both sent and received tracking`() { + // Register sent + extensionsState.registerControlExtensionMessageSentToPeers(peer1) + + // Register received + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + extensionsState.onControlExtensionsMessage(extension, peer1) + + // Verify both tracked + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isTrue() + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isTrue() + + // Disconnect + extensionsState.onPeerDisconnected(peer1) + + // Both should be removed + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isFalse() + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isFalse() + assertThat(extensionsState.peerSupportedExtensions(peer1)).isNull() + } + + @Test + fun `onPeerDisconnected() only removes specified peer from sent list`() { + extensionsState.registerControlExtensionMessageSentToPeers(peer1) + extensionsState.registerControlExtensionMessageSentToPeers(peer2) + + extensionsState.onPeerDisconnected(peer1) + + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isFalse() + assertThat(extensionsState.hasSentControlExtensionsTo(peer2)).isTrue() + } + + @Test + fun `reconnecting peer can have sent extension registered again`() { + // First connection - register sent + extensionsState.registerControlExtensionMessageSentToPeers(peer1) + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isTrue() + + // Disconnect + extensionsState.onPeerDisconnected(peer1) + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isFalse() + + // Reconnect - register sent again + extensionsState.registerControlExtensionMessageSentToPeers(peer1) + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isTrue() + } + + @Test + fun `querying empty state returns expected values`() { + extensionsState = GossipExtensionsState() + assertThat(extensionsState.hasReceivedControlExtensionsFrom(peer1)).isFalse() + assertThat(extensionsState.hasSentControlExtensionsTo(peer1)).isFalse() + assertThat(extensionsState.peerSupportedExtensions(peer1)).isNull() + } +} diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueueTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueueTest.kt index 5b6b35e5..e978877d 100644 --- a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueueTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/GossipRpcPartsQueueTest.kt @@ -306,4 +306,191 @@ class GossipRpcPartsQueueTest { .addMessageIDs("2222".toWBytes().toProtobuf()).build(), ) } + + @Test + fun `addControlExtensions() sets testExtension flag in control message`() { + val partsQueue = TestGossipQueue(gossipParamsNoLimits) + + val extension = Rpc.ControlExtensions.newBuilder() + .setTestExtension(true) + .build() + + partsQueue.addControlExtensions(extension) + + val res = partsQueue.takeMerged().first() + + assertThat(res.hasControl()).isTrue() + assertThat(res.control.hasExtensions()).isTrue() + assertThat(res.control.extensions.testExtension).isTrue() + } + + @Test + fun `addControlExtensions() sets partialMessages flag in control message`() { + val partsQueue = TestGossipQueue(gossipParamsNoLimits) + + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + + partsQueue.addControlExtensions(extension) + + val res = partsQueue.takeMerged().first() + + assertThat(res.hasControl()).isTrue() + assertThat(res.control.hasExtensions()).isTrue() + assertThat(res.control.extensions.partialMessages).isTrue() + } + + @Test + fun `addControlExtensions() sets all extension flags`() { + val partsQueue = TestGossipQueue(gossipParamsNoLimits) + + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .setTestExtension(true) + .build() + + partsQueue.addControlExtensions(extension) + + val res = partsQueue.takeMerged().first() + + assertThat(res.hasControl()).isTrue() + assertThat(res.control.hasExtensions()).isTrue() + assertThat(res.control.extensions.partialMessages).isTrue() + assertThat(res.control.extensions.testExtension).isTrue() + } + + @Test + fun `control extensions message works with other control messages`() { + val partsQueue = TestGossipQueue(gossipParamsNoLimits) + + // Add various control messages + partsQueue.addIHave(byteArrayOf(1).toWBytes(), "topic1") + partsQueue.addIWant(byteArrayOf(2).toWBytes()) + partsQueue.addGraft("topic2") + partsQueue.addPrune("topic3") + + // Add extension + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + partsQueue.addControlExtensions(extension) + + val res = partsQueue.takeMerged().first() + + // Verify all control messages are present + assertThat(res.hasControl()).isTrue() + assertThat(res.control.ihaveList).hasSize(1) + assertThat(res.control.iwantList).hasSize(1) + assertThat(res.control.graftList).hasSize(1) + assertThat(res.control.pruneList).hasSize(1) + + // Verify extension is present + assertThat(res.control.hasExtensions()).isTrue() + assertThat(res.control.extensions.partialMessages).isTrue() + } + + @Test + fun `control extensions message with subscriptions and publishes`() { + val partsQueue = TestGossipQueue(gossipParamsNoLimits) + + partsQueue.addSubscribe("topic1") + partsQueue.addPublish(createRpcMessage("topic1", "data1")) + + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + partsQueue.addControlExtensions(extension) + + val res = partsQueue.takeMerged().first() + + // Verify subscriptions and publishes + assertThat(res.subscriptionsList).hasSize(1) + assertThat(res.publishList).hasSize(1) + + // Verify extension + assertThat(res.control.hasExtensions()).isTrue() + assertThat(res.control.extensions.partialMessages).isTrue() + } + + @Test + fun `control extensions message works with message splitting`() { + val partsQueue = TestGossipQueue(gossipParamsWithLimits) + + // Add enough messages to force splitting + (1..20).forEach { + partsQueue.addPublish(createRpcMessage("topic-$it", "data")) + } + + // Add extension + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + partsQueue.addControlExtensions(extension) + + val merged = partsQueue.takeMerged() + + // Should be split into multiple RPCs due to maxPublishedMessages limit + assertThat(merged.size).isGreaterThan(1) + + // Extension should be in the last RPC (since it's added last) + val lastRpc = merged.last() + assertThat(lastRpc.hasControl()).isTrue() + assertThat(lastRpc.control.hasExtensions()).isTrue() + assertThat(lastRpc.control.extensions.partialMessages).isTrue() + } + + @Test + fun `multiple control extensions messages - last one wins`() { + val partsQueue = TestGossipQueue(gossipParamsNoLimits) + + // Add first extension + val extension1 = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .setTestExtension(false) + .build() + partsQueue.addControlExtensions(extension1) + + // Add second extension (should overwrite first) + val extension2 = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(false) + .setTestExtension(true) + .build() + partsQueue.addControlExtensions(extension2) + + val res = partsQueue.takeMerged().first() + + // Verify only the last extension is present + assertThat(res.control.hasExtensions()).isTrue() + // Note: false flags may or may not be serialized depending on protobuf default behavior + // But testExtension should definitely be true + assertThat(res.control.extensions.testExtension).isTrue() + } + + @Test + fun `control extensions message does not count toward limits but may be split`() { + val partsQueue = TestGossipQueue(gossipParamsWithLimits) + + // Add exactly maxPublishedMessages messages + (1..maxPublishedMessages).forEach { + partsQueue.addPublish(createRpcMessage("topic-$it", "data")) + } + + // Add extension + val extension = Rpc.ControlExtensions.newBuilder() + .setPartialMessages(true) + .build() + partsQueue.addControlExtensions(extension) + + val merged = partsQueue.takeMerged() + + // Extension doesn't count toward limits, but it may end up in a separate RPC + // if it comes after parts that exhaust a limit + assertThat(merged).hasSize(2) + assertThat(merged[0].publishList).hasSize(maxPublishedMessages) + + // Extension should be in the second RPC + assertThat(merged[1].control.hasExtensions()).isTrue() + assertThat(merged[1].control.extensions.partialMessages).isTrue() + } } diff --git a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/extensions/GossipExtensionsMessageHandlingTest.kt b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/extensions/GossipExtensionsMessageHandlingTest.kt index e9c97ce8..391e7a76 100644 --- a/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/extensions/GossipExtensionsMessageHandlingTest.kt +++ b/libp2p/src/test/kotlin/io/libp2p/pubsub/gossip/extensions/GossipExtensionsMessageHandlingTest.kt @@ -2,11 +2,16 @@ package io.libp2p.pubsub.gossip.extensions import io.libp2p.pubsub.PubsubProtocol import io.libp2p.pubsub.gossip.GossipTestsBase +import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource import pubsub.pb.Rpc import java.util.concurrent.TimeoutException +private const val DEFAULT_WAIT_TIMEOUT_IN_MILLIS = 500L + class GossipExtensionsMessageHandlingTest : GossipTestsBase() { @Test @@ -15,87 +20,195 @@ class GossipExtensionsMessageHandlingTest : GossipTestsBase() { protocol = PubsubProtocol.Gossip_V_1_2 ) - val rpcMessageWithControlExtensionAndTestExtensionMessages = Rpc.RPC.newBuilder() - .setControl( - Rpc.ControlMessage.newBuilder() - .setExtensions(Rpc.ControlExtensions.newBuilder().setTestExtension(true)) - .build() - ) - .setTestExtension(Rpc.TestExtension.newBuilder().build()) - .build() - test.mockRouter.sendToSingle(rpcMessageWithControlExtensionAndTestExtensionMessages) - + test.mockRouter.sendToSingle(rpcMsgWithCtrlExtensionsAndTestExtension) assertNoResponseFromTestExtension(test) } @Test - fun `extension messages sent to peer prior to sending extension control messages are ignored`() { + fun `extension messages sent to peer prior to sending control extensions messages are ignored`() { val test = TwoRoutersTest( protocol = PubsubProtocol.Gossip_V_1_3 ) - val rpcMessageWithTestExtension = - Rpc.RPC.newBuilder().setTestExtension(testExtensionMessage).build() test.mockRouter.sendToSingle(rpcMessageWithTestExtension) - assertNoResponseFromTestExtension(test) } @Test - fun `extension message flow with extension control message before actual extension message`() { + fun `extension message flow with control extensions message before actual extension message`() { val test = TwoRoutersTest( protocol = PubsubProtocol.Gossip_V_1_3 ) - val rpcMessageWithControl = Rpc.RPC.newBuilder().setControl( - Rpc.ControlMessage.newBuilder().setExtensions(controlExtensionMessage()) - ).build() - test.mockRouter.sendToSingle(rpcMessageWithControl) + test.mockRouter.sendToSingle(rpcMessageWithControlExtensions) + assertThat(test.gossipRouter.gossipExtensionsState.peerSupportedExtensions(test.router2.peerId)).isEqualTo( + rpcMessageWithControlExtensions.control.extensions + ) - val rpcMessageWithTestExtension = - Rpc.RPC.newBuilder().setTestExtension(testExtensionMessage).build() test.mockRouter.sendToSingle(rpcMessageWithTestExtension) + test.mockRouter.waitForMessage { it.hasTestExtension() } + } + + @Test + fun `extension message flow with control extensions and extension message in the same rpc message`() { + val test = TwoRoutersTest( + protocol = PubsubProtocol.Gossip_V_1_3 + ) + + test.mockRouter.sendToSingle(rpcMsgWithCtrlExtensionsAndTestExtension) + test.mockRouter.waitForMessage { it.hasTestExtension() } + } + + @Test + fun `remove peer control extensions map when disconnecting`() { + val test = TwoRoutersTest( + protocol = PubsubProtocol.Gossip_V_1_3 + ) + + test.mockRouter.sendToSingle(rpcMsgWithCtrlExtensionsAndTestExtension) + + assertThat(test.gossipRouter.gossipExtensionsState.peerSupportedExtensions(test.router2.peerId)).isEqualTo( + rpcMsgWithCtrlExtensionsAndTestExtension.control.extensions + ) test.mockRouter.waitForMessage { it.hasTestExtension() } + + // Successfully registered peer2 extensions support + + assertThat(test.gossipRouter.gossipExtensionsState.peerSupportedExtensions(test.router2.peerId)).isNotNull() + + test.connection.disconnect() + + // After disconnecting removes peer2 from extensions support map + assertThat(test.gossipRouter.gossipExtensionsState.peerSupportedExtensions(test.router2.peerId)).isNull() + } + + @ParameterizedTest + @MethodSource("protocolVersionsWithExtensionSupport") + fun `control extension message sent to peer on connection with extension support`(protocol: PubsubProtocol) { + val test = TwoRoutersTest(protocol = protocol) + + val receivedMessage = test.mockRouter.waitForMessage( + { it.hasControl() && it.control.hasExtensions() }, + DEFAULT_WAIT_TIMEOUT_IN_MILLIS + ) + + assertThat(receivedMessage.control.extensions.partialMessages).isTrue() + assertThat(receivedMessage.control.extensions.testExtension).isTrue() + } + + @ParameterizedTest + @MethodSource("protocolVersionsWithoutExtensionSupport") + fun `control extension message not sent to peer on connection without extension support`(protocol: PubsubProtocol) { + val test = TwoRoutersTest(protocol = protocol) + + // Should not receive control extension message on versions without extension support + assertThrows { + test.mockRouter.waitForMessage( + { it.hasControl() && it.control.hasExtensions() }, + DEFAULT_WAIT_TIMEOUT_IN_MILLIS + ) + } + } + + @Test + fun `control extension message contains all supported extensions flags`() { + val test = TwoRoutersTest( + protocol = PubsubProtocol.Gossip_V_1_3 + ) + + val receivedMessage = test.mockRouter.waitForMessage( + { it.hasControl() && it.control.hasExtensions() }, + 2000L + ) + + val extensions = receivedMessage.control.extensions + + // Verify both extension flags are set + assertThat(extensions.hasPartialMessages()).isTrue() + assertThat(extensions.partialMessages).isTrue() + assertThat(extensions.hasTestExtension()).isTrue() + assertThat(extensions.testExtension).isTrue() } @Test - fun `extension message flow with extension control and extension message in the same rpc message`() { + fun `extension state tracks that we sent control extension to peer`() { val test = TwoRoutersTest( protocol = PubsubProtocol.Gossip_V_1_3 ) - val rpcMessageWithControlExtensionAndTestExtensionMessages = Rpc.RPC.newBuilder() + // Wait for control extension message to be sent + test.mockRouter.waitForMessage( + { it.hasControl() && it.control.hasExtensions() }, + DEFAULT_WAIT_TIMEOUT_IN_MILLIS + ) + + // Should be tracked in state + assertThat(test.gossipRouter.gossipExtensionsState.hasSentControlExtensionsTo(test.router2.peerId)).isTrue() + } + + @Test + fun `control extension sent state cleared on peer disconnect`() { + val test = TwoRoutersTest( + protocol = PubsubProtocol.Gossip_V_1_3 + ) + + // Wait for control extension message + test.mockRouter.waitForMessage( + { it.hasControl() && it.control.hasExtensions() }, + DEFAULT_WAIT_TIMEOUT_IN_MILLIS + ) + + // Verify it's tracked + assertThat(test.gossipRouter.gossipExtensionsState.hasSentControlExtensionsTo(test.router2.peerId)).isTrue() + + // Disconnect + test.connection.disconnect() + + // Should be cleared from sent tracking + assertThat(test.gossipRouter.gossipExtensionsState.hasSentControlExtensionsTo(test.router2.peerId)).isFalse() + } + + companion object { + @JvmStatic + fun protocolVersionsWithExtensionSupport() = listOf( + PubsubProtocol.Gossip_V_1_3 + ) + + @JvmStatic + fun protocolVersionsWithoutExtensionSupport() = listOf( + PubsubProtocol.Gossip_V_1_1, + PubsubProtocol.Gossip_V_1_2 + ) + + val testExtensionMessage: Rpc.TestExtension = Rpc.TestExtension.newBuilder().build() + + val rpcMessageWithControlExtensions = Rpc.RPC.newBuilder().setControl( + Rpc.ControlMessage.newBuilder().setExtensions(controlExtensionMessage()) + ).build()!! + + val rpcMessageWithTestExtension = + Rpc.RPC.newBuilder().setTestExtension(testExtensionMessage).build()!! + + // An RPC message with both ControlExtensions and TestExtension message (test extension enabled on control) + val rpcMsgWithCtrlExtensionsAndTestExtension = Rpc.RPC.newBuilder() .setControl( Rpc.ControlMessage.newBuilder() .setExtensions(Rpc.ControlExtensions.newBuilder().setTestExtension(true)) .build() ) .setTestExtension(Rpc.TestExtension.newBuilder().build()) - .build() - test.mockRouter.sendToSingle(rpcMessageWithControlExtensionAndTestExtensionMessages) + .build()!! - test.mockRouter.waitForMessage { it.hasTestExtension() } - } - - companion object { - val testExtensionControlEnabledMessage: Rpc.RPC = Rpc.RPC.newBuilder().setControl( - Rpc.ControlMessage.newBuilder() - .setExtensions(Rpc.ControlExtensions.newBuilder().setTestExtension(true).build()) - .build() - ).build() - - fun controlExtensionMessage(testExtensionEnabled: Boolean = false): Rpc.ControlExtensions { + fun controlExtensionMessage(testExtensionEnabled: Boolean = true): Rpc.ControlExtensions { return Rpc.ControlExtensions.newBuilder().setTestExtension(testExtensionEnabled).build() } - val testExtensionMessage: Rpc.TestExtension = Rpc.TestExtension.newBuilder().build() - fun assertNoResponseFromTestExtension(test: TwoRoutersTest) { assertThrows { test.mockRouter.waitForMessage( { it.hasTestExtension() }, - 500L + DEFAULT_WAIT_TIMEOUT_IN_MILLIS ) } }