From bf33269faf2dd1c27578e060d8d8391ee5cd224e Mon Sep 17 00:00:00 2001
From: Dmytro Vyazelenko <696855+vyazelenko@users.noreply.github.com>
Date: Thu, 11 Apr 2024 20:03:09 +0200
Subject: [PATCH] [Java] Fix connection race + Mockito error.

---
 .../artio/engine/framer/FramerTest.java       | 34 +++++++++++--------
 1 file changed, 19 insertions(+), 15 deletions(-)

diff --git a/artio-core/src/test/java/uk/co/real_logic/artio/engine/framer/FramerTest.java b/artio-core/src/test/java/uk/co/real_logic/artio/engine/framer/FramerTest.java
index d4fa4a1d98..103539100d 100644
--- a/artio-core/src/test/java/uk/co/real_logic/artio/engine/framer/FramerTest.java
+++ b/artio-core/src/test/java/uk/co/real_logic/artio/engine/framer/FramerTest.java
@@ -22,6 +22,7 @@
 import org.agrona.DirectBuffer;
 import org.agrona.ErrorHandler;
 import org.agrona.LangUtil;
+import org.agrona.collections.MutableLong;
 import org.agrona.concurrent.AgentInvoker;
 import org.agrona.concurrent.QueuedPipe;
 import org.agrona.concurrent.status.CountersReader;
@@ -32,6 +33,7 @@
 import org.junit.Test;
 import org.mockito.ArgumentCaptor;
 import org.mockito.Mockito;
+import org.mockito.stubbing.Answer;
 import org.mockito.verification.VerificationMode;
 import uk.co.real_logic.artio.CloseChecker;
 import uk.co.real_logic.artio.FixCounters;
@@ -67,6 +69,7 @@
 import static org.junit.Assert.*;
 import static org.mockito.Mockito.*;
 import static uk.co.real_logic.artio.CommonConfiguration.DEFAULT_NAME_PREFIX;
+import static uk.co.real_logic.artio.GatewayProcess.NO_CONNECTION_ID;
 import static uk.co.real_logic.artio.Timing.assertEventuallyTrue;
 import static uk.co.real_logic.artio.engine.FixEngine.ENGINE_LIBRARY_ID;
 import static uk.co.real_logic.artio.library.FixLibrary.NO_MESSAGE_REPLAY;
@@ -135,7 +138,7 @@ public class FramerTest
 
     private Framer framer;
 
-    private final ArgumentCaptor<Long> connectionId = ArgumentCaptor.forClass(Long.class);
+    private final MutableLong connectionId = new MutableLong(NO_CONNECTION_ID);
     private final ErrorHandler errorHandler = mock(ErrorHandler.class);
 
     @Before
@@ -150,15 +153,19 @@ public void setUp() throws IOException
         when(outboundLibrarySubscription.imageBySessionId(anyInt())).thenReturn(normalImage);
 
         when(mockEndPointFactory.receiverEndPoint(
-            any(), connectionId.capture(), anyLong(), anyInt(), anyInt(), any()))
-            .thenReturn(mockReceiverEndPoint);
+            any(), anyLong(), anyLong(), anyInt(), anyInt(), any()))
+            .thenAnswer((Answer<FixReceiverEndPoint>)invocationOnMock ->
+            {
+                connectionId.set(invocationOnMock.getArgument(1));
+                return mockReceiverEndPoint;
+            });
 
         when(mockEndPointFactory.senderEndPoint(any(), anyLong(), anyInt(), any(), any()))
             .thenReturn(mockSenderEndPoint);
 
-        when(mockReceiverEndPoint.connectionId()).then((inv) -> connectionId.getValue());
+        when(mockReceiverEndPoint.connectionId()).then((inv) -> connectionId.get());
 
-        when(mockSenderEndPoint.connectionId()).then((inv) -> connectionId.getValue());
+        when(mockSenderEndPoint.connectionId()).then((inv) -> connectionId.get());
 
         when(gatewaySession.session()).thenReturn(session);
         when(gatewaySession.fixDictionary()).thenReturn(fixDictionary);
@@ -278,7 +285,7 @@ public void shouldCloseSocketUponDisconnect() throws Exception
         aClientConnects();
         framer.doWork();
 
-        framer.onDisconnect(LIBRARY_ID, connectionId.getValue(), APPLICATION_DISCONNECT);
+        framer.onDisconnect(LIBRARY_ID, connectionId.get(), APPLICATION_DISCONNECT);
         framer.doWork();
 
         verifyEndPointsDisconnected(APPLICATION_DISCONNECT);
@@ -298,6 +305,7 @@ public void shouldNotConnectIfLibraryUnknown() throws Exception
         framer.doWork();
 
         assertNull("Sender has connected to server", server.accept());
+        assertEquals(NO_CONNECTION_ID, connectionId.get());
         verifyErrorPublished(UNKNOWN_LIBRARY);
     }
 
@@ -357,7 +365,6 @@ public void shouldIdentifyDuplicateInitiatedSessions() throws Exception
         assertEquals(CONTINUE, onInitiateConnection());
 
         verifyErrorPublished(DUPLICATE_SESSION);
-        assertNull(server.accept());
     }
 
     @Test
@@ -730,7 +737,7 @@ private void releaseConnection(final Action expectedResult)
     {
         assertEquals(expectedResult, framer.onReleaseSession(
             LIBRARY_ID,
-            connectionId.getValue(),
+            connectionId.get(),
             SESSION_ID,
             CORR_ID,
             ACTIVE,
@@ -749,7 +756,7 @@ private Action onLibraryConnect()
 
     private void givenAGatewayToManage()
     {
-        when(gatewaySession.connectionId()).thenReturn(connectionId.getValue());
+        when(gatewaySession.connectionId()).thenReturn(connectionId.get());
         when(gatewaySession.sessionKey()).thenReturn(mock(CompositeKey.class));
         when(gatewaySessions.sessions()).thenReturn(singletonList(gatewaySession));
     }
@@ -843,13 +850,10 @@ private void initiateConnection() throws Exception
 
         assertEquals(CONTINUE, onInitiateConnection());
 
-        do
+        while (NO_CONNECTION_ID == connectionId.get())
         {
             framer.doWork();
         }
-        while (server.accept() == null);
-
-        assertNotNull("Connection not completed yet", connectionId.getValue());
     }
 
     private Action onInitiateConnection()
@@ -900,7 +904,7 @@ private void notifyLibraryOfConnection()
     private void notifyLibraryOfConnection(final VerificationMode times)
     {
         verify(inboundPublication, times).saveManageSession(eq(LIBRARY_ID),
-            eq(connectionId.getValue()),
+            eq(connectionId.get()),
             anyLong(),
             anyInt(),
             anyInt(),
@@ -943,7 +947,7 @@ private void notifyLibraryOfConnection(final VerificationMode times)
     private void verifySessionExistsSaved(final VerificationMode times, final SessionStatus status)
     {
         verify(inboundPublication, times).saveManageSession(eq(LIBRARY_ID),
-            eq(connectionId.getValue()),
+            eq(connectionId.get()),
             anyLong(),
             anyInt(),
             anyInt(),