Skip to content

Commit

Permalink
Fix unit tests on MacOs
Browse files Browse the repository at this point in the history
  • Loading branch information
alex268 committed Jan 24, 2025
1 parent 28bee3a commit d34acdc
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import java.util.Map;
import java.util.stream.Collectors;

import javax.net.SocketFactory;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Ticker;
import org.slf4j.Logger;
Expand Down Expand Up @@ -105,7 +107,7 @@ static String detectLocalDC(List<EndpointRecord> endpoints, Ticker ticker) {
}

private static long tcpPing(InetSocketAddress socketAddress, Ticker ticker) {
try (Socket socket = new Socket()) {
try (Socket socket = SocketFactory.getDefault().createSocket()) {
final long startConnection = ticker.read();
socket.connect(socketAddress, DETECT_DC_TCP_PING_TIMEOUT_MS);
final long stopConnection = ticker.read();
Expand Down
153 changes: 75 additions & 78 deletions core/src/test/java/tech/ydb/core/impl/pool/EndpointPoolTest.java
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package tech.ydb.core.impl.pool;

import java.io.IOException;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketAddress;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;

import javax.net.ServerSocketFactory;
import javax.net.SocketFactory;

import com.google.common.base.Ticker;
import org.junit.After;
Expand All @@ -20,29 +21,32 @@
import tech.ydb.core.grpc.BalancingSettings;
import tech.ydb.core.timer.TestTicker;

import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
* @author Aleksandr Gorshenin
* @author Kirill Kurdyukov
*/
public class EndpointPoolTest {
private AutoCloseable mocks;
private final MockedStatic<ThreadLocalRandom> threadLocalStaticMock = mockStatic(ThreadLocalRandom.class);
private final MockedStatic<Ticker> tickerStaticMock = mockStatic(Ticker.class);
private final MockedStatic<ThreadLocalRandom> threadLocalStaticMock = Mockito.mockStatic(ThreadLocalRandom.class);
private final MockedStatic<Ticker> tickerStaticMock = Mockito.mockStatic(Ticker.class);
private final MockedStatic<SocketFactory> socketFactoryStaticMock = Mockito.mockStatic(SocketFactory.class);

private final Socket socket = Mockito.mock(Socket.class);
private final SocketFactory socketFactory = Mockito.mock(SocketFactory.class);
private final ThreadLocalRandom random = Mockito.mock(ThreadLocalRandom.class);

@Before
public void setUp() {
public void setUp() throws IOException {
mocks = MockitoAnnotations.openMocks(this);
threadLocalStaticMock.when(ThreadLocalRandom::current).thenReturn(random);
socketFactoryStaticMock.when(SocketFactory::getDefault).thenReturn(socketFactory);
Mockito.when(socketFactory.createSocket()).thenReturn(socket);
Mockito.doNothing().when(socket).connect(Mockito.any(SocketAddress.class));
}

@After
public void tearDown() throws Exception {
socketFactoryStaticMock.close();
tickerStaticMock.close();
threadLocalStaticMock.close();
mocks.close();
Expand Down Expand Up @@ -77,7 +81,7 @@ public void useAllNodesTest() {

check(pool).records(3).knownNodes(3).needToReDiscovery(false).bestEndpointsCount(3);

when(random.nextInt(3)).thenReturn(2, 0, 2, 1);
Mockito.when(random.nextInt(3)).thenReturn(2, 0, 2, 1);

check(pool.getEndpoint(null)).hostname("n3.ydb.tech").nodeID(3).port(12345); // random choice
check(pool.getEndpoint(0)).hostname("n1.ydb.tech").nodeID(1).port(12345); // random choose
Expand All @@ -87,7 +91,7 @@ public void useAllNodesTest() {
check(pool.getEndpoint(4)).hostname("n3.ydb.tech").nodeID(3).port(12345); // random choose
check(pool.getEndpoint(5)).hostname("n2.ydb.tech").nodeID(2).port(12345); // random choose

verify(random, times(4)).nextInt(3);
Mockito.verify(random, Mockito.times(4)).nextInt(3);
}

@Test
Expand All @@ -103,7 +107,7 @@ public void localDcTest() {

check(pool).records(3).knownNodes(3).needToReDiscovery(false).bestEndpointsCount(1);

when(random.nextInt(1)).thenReturn(0, 0, 0);
Mockito.when(random.nextInt(1)).thenReturn(0, 0, 0);

check(pool.getEndpoint(null)).hostname("n2.ydb.tech").nodeID(2).port(12345); // random from local DC
check(pool.getEndpoint(0)).hostname("n2.ydb.tech").nodeID(2).port(12345); // random from local DC
Expand All @@ -112,7 +116,7 @@ public void localDcTest() {
check(pool.getEndpoint(3)).hostname("n3.ydb.tech").nodeID(3).port(12345); // preferred
check(pool.getEndpoint(4)).hostname("n2.ydb.tech").nodeID(2).port(12345); // random from local DC

verify(random, times(3)).nextInt(1);
Mockito.verify(random, Mockito.times(3)).nextInt(1);
}

@Test
Expand All @@ -128,7 +132,7 @@ public void preferredDcTest() {

check(pool).records(3).knownNodes(3).needToReDiscovery(false).bestEndpointsCount(1);

when(random.nextInt(1)).thenReturn(0, 0, 0);
Mockito.when(random.nextInt(1)).thenReturn(0, 0, 0);

check(pool.getEndpoint(null)).hostname("n1.ydb.tech").nodeID(1).port(12345); // random from DC1
check(pool.getEndpoint(0)).hostname("n1.ydb.tech").nodeID(1).port(12345); // random from DC1
Expand All @@ -137,7 +141,7 @@ public void preferredDcTest() {
check(pool.getEndpoint(3)).hostname("n3.ydb.tech").nodeID(3).port(12345); // preferred
check(pool.getEndpoint(4)).hostname("n1.ydb.tech").nodeID(1).port(12345); // random from DC1

verify(random, times(3)).nextInt(1);
Mockito.verify(random, Mockito.times(3)).nextInt(1);
}

@Test
Expand All @@ -153,7 +157,7 @@ public void preferredEndpointsTest() {

check(pool).records(3).knownNodes(3).needToReDiscovery(false).bestEndpointsCount(3);

when(random.nextInt(3)).thenReturn(2, 0, 2, 1);
Mockito.when(random.nextInt(3)).thenReturn(2, 0, 2, 1);

// If node is known
check(pool.getEndpoint(1)).hostname("n1.ydb.tech").nodeID(1).port(12341);
Expand All @@ -167,7 +171,7 @@ public void preferredEndpointsTest() {
check(pool.getEndpoint(6)).hostname("n3.ydb.tech").nodeID(3).port(12343);
check(pool.getEndpoint(7)).hostname("n2.ydb.tech").nodeID(2).port(12342);

verify(random, times(4)).nextInt(3);
Mockito.verify(random, Mockito.times(4)).nextInt(3);
}

@Test
Expand All @@ -185,24 +189,24 @@ public void nodePessimizationTest() {

check(pool).records(5).knownNodes(5).needToReDiscovery(false).bestEndpointsCount(5);

when(random.nextInt(5)).thenReturn(0, 1, 3, 2, 4);
Mockito.when(random.nextInt(5)).thenReturn(0, 1, 3, 2, 4);
check(pool.getEndpoint(null)).hostname("n1.ydb.tech").nodeID(1).port(12341);
check(pool.getEndpoint(null)).hostname("n2.ydb.tech").nodeID(2).port(12342);
check(pool.getEndpoint(null)).hostname("n4.ydb.tech").nodeID(4).port(12344);
check(pool.getEndpoint(null)).hostname("n3.ydb.tech").nodeID(3).port(12343);
check(pool.getEndpoint(null)).hostname("n5.ydb.tech").nodeID(5).port(12345);
verify(random, times(5)).nextInt(5);
Mockito.verify(random, Mockito.times(5)).nextInt(5);

// Pessimize one node - four left in use
pool.pessimizeEndpoint(pool.getEndpoint(2));
check(pool).records(5).knownNodes(5).needToReDiscovery(false).bestEndpointsCount(4);

when(random.nextInt(4)).thenReturn(0, 2, 1, 3);
Mockito.when(random.nextInt(4)).thenReturn(0, 2, 1, 3);
check(pool.getEndpoint(null)).hostname("n1.ydb.tech").nodeID(1).port(12341);
check(pool.getEndpoint(null)).hostname("n4.ydb.tech").nodeID(4).port(12344);
check(pool.getEndpoint(null)).hostname("n3.ydb.tech").nodeID(3).port(12343);
check(pool.getEndpoint(null)).hostname("n5.ydb.tech").nodeID(5).port(12345);
verify(random, times(4)).nextInt(4);
Mockito.verify(random, Mockito.times(4)).nextInt(4);

// but we can use pessimized node if specify it as preferred
check(pool.getEndpoint(2)).hostname("n2.ydb.tech").nodeID(2).port(12342);
Expand All @@ -217,25 +221,25 @@ public void nodePessimizationTest() {
pool.pessimizeEndpoint(pool.getEndpoint(2));
check(pool).records(5).knownNodes(5).needToReDiscovery(false).bestEndpointsCount(4);

when(random.nextInt(4)).thenReturn(3, 1, 2, 0);
Mockito.when(random.nextInt(4)).thenReturn(3, 1, 2, 0);
check(pool.getEndpoint(null)).hostname("n5.ydb.tech").nodeID(5).port(12345);
check(pool.getEndpoint(null)).hostname("n3.ydb.tech").nodeID(3).port(12343);
check(pool.getEndpoint(null)).hostname("n4.ydb.tech").nodeID(4).port(12344);
check(pool.getEndpoint(null)).hostname("n1.ydb.tech").nodeID(1).port(12341);
verify(random, times(8)).nextInt(4); // Mockito counts also previous 4
Mockito.verify(random, Mockito.times(8)).nextInt(4); // Mockito counts also previous 4

// Pessimize two nodes - then we need to discovery
pool.pessimizeEndpoint(pool.getEndpoint(3));
check(pool).records(5).knownNodes(5).needToReDiscovery(false).bestEndpointsCount(3);
pool.pessimizeEndpoint(pool.getEndpoint(5));
check(pool).records(5).knownNodes(5).needToReDiscovery(true).bestEndpointsCount(2);

when(random.nextInt(2)).thenReturn(1, 1, 0, 0);
Mockito.when(random.nextInt(2)).thenReturn(1, 1, 0, 0);
check(pool.getEndpoint(null)).hostname("n4.ydb.tech").nodeID(4).port(12344);
check(pool.getEndpoint(null)).hostname("n4.ydb.tech").nodeID(4).port(12344);
check(pool.getEndpoint(null)).hostname("n1.ydb.tech").nodeID(1).port(12341);
check(pool.getEndpoint(null)).hostname("n1.ydb.tech").nodeID(1).port(12341);
verify(random, times(4)).nextInt(2);
Mockito.verify(random, Mockito.times(4)).nextInt(2);
}

@Test
Expand All @@ -253,39 +257,39 @@ public void nodePessimizationFallbackTest() {
check(pool).records(4).knownNodes(4).needToReDiscovery(false).bestEndpointsCount(2);

// Only local nodes are used
when(random.nextInt(2)).thenReturn(0, 1);
Mockito.when(random.nextInt(2)).thenReturn(0, 1);
check(pool.getEndpoint(null)).hostname("n1.ydb.tech").nodeID(1).port(12341);
check(pool.getEndpoint(null)).hostname("n2.ydb.tech").nodeID(2).port(12342);
verify(random, times(2)).nextInt(2);
Mockito.verify(random, Mockito.times(2)).nextInt(2);

// Pessimize first local node - use second
pool.pessimizeEndpoint(pool.getEndpoint(1));
check(pool).records(4).knownNodes(4).needToReDiscovery(false).bestEndpointsCount(1);

when(random.nextInt(1)).thenReturn(0);
Mockito.when(random.nextInt(1)).thenReturn(0);
check(pool.getEndpoint(null)).hostname("n2.ydb.tech").nodeID(2).port(12342);
verify(random, times(1)).nextInt(1);
Mockito.verify(random, Mockito.times(1)).nextInt(1);

// Pessimize second local node - use unlocal nodes
pool.pessimizeEndpoint(pool.getEndpoint(2));
check(pool).records(4).knownNodes(4).needToReDiscovery(false).bestEndpointsCount(2);

when(random.nextInt(2)).thenReturn(1, 0);
Mockito.when(random.nextInt(2)).thenReturn(1, 0);
check(pool.getEndpoint(null)).hostname("n4.ydb.tech").nodeID(4).port(12344);
check(pool.getEndpoint(null)).hostname("n3.ydb.tech").nodeID(3).port(12343);
verify(random, times(4)).nextInt(2);
Mockito.verify(random, Mockito.times(4)).nextInt(2);

// Pessimize all - fallback to use all nodes
pool.pessimizeEndpoint(pool.getEndpoint(3));
pool.pessimizeEndpoint(pool.getEndpoint(4));
check(pool).records(4).knownNodes(4).needToReDiscovery(true).bestEndpointsCount(4);

when(random.nextInt(4)).thenReturn(3, 2, 1, 0);
Mockito.when(random.nextInt(4)).thenReturn(3, 2, 1, 0);
check(pool.getEndpoint(null)).hostname("n4.ydb.tech").nodeID(4).port(12344);
check(pool.getEndpoint(null)).hostname("n3.ydb.tech").nodeID(3).port(12343);
check(pool.getEndpoint(null)).hostname("n2.ydb.tech").nodeID(2).port(12342);
check(pool.getEndpoint(null)).hostname("n1.ydb.tech").nodeID(1).port(12341);
verify(random, times(4)).nextInt(4);
Mockito.verify(random, Mockito.times(4)).nextInt(4);

// setNewState reset all
pool.setNewState("DC3", list(
Expand Down Expand Up @@ -318,7 +322,7 @@ public void duplicateEndpointsTest() {
check(pool).record(2).hostname("n3.ydb.tech").nodeID(3).port(12343);
check(pool).record(3).hostname("n3.ydb.tech").nodeID(6).port(12344);

when(random.nextInt(4)).thenReturn(2, 0, 3, 1);
Mockito.when(random.nextInt(4)).thenReturn(2, 0, 3, 1);

check(pool.getEndpoint(null)).hostname("n3.ydb.tech").nodeID(3).port(12343); // random
check(pool.getEndpoint(0)).hostname("n1.ydb.tech").nodeID(1).port(12341); // random
Expand All @@ -329,7 +333,7 @@ public void duplicateEndpointsTest() {
check(pool.getEndpoint(5)).hostname("n2.ydb.tech").nodeID(2).port(12342); // random
check(pool.getEndpoint(6)).hostname("n3.ydb.tech").nodeID(6).port(12344);

verify(random, times(4)).nextInt(4);
Mockito.verify(random, Mockito.times(4)).nextInt(4);
}

@Test
Expand All @@ -349,15 +353,15 @@ public void duplicateNodesTest() {
check(pool).record(1).hostname("n2.ydb.tech").nodeID(2).port(12342);
check(pool).record(2).hostname("n3.ydb.tech").nodeID(2).port(12343);

when(random.nextInt(3)).thenReturn(1, 0, 2);
Mockito.when(random.nextInt(3)).thenReturn(1, 0, 2);

check(pool.getEndpoint(null)).hostname("n2.ydb.tech").nodeID(2).port(12342); // random
check(pool.getEndpoint(0)).hostname("n1.ydb.tech").nodeID(1).port(12341); // random
check(pool.getEndpoint(1)).hostname("n1.ydb.tech").nodeID(1).port(12341);
check(pool.getEndpoint(2)).hostname("n3.ydb.tech").nodeID(2).port(12343);
check(pool.getEndpoint(3)).hostname("n3.ydb.tech").nodeID(2).port(12343); // random

verify(random, times(3)).nextInt(3);
Mockito.verify(random, Mockito.times(3)).nextInt(3);
}

@Test
Expand All @@ -377,7 +381,7 @@ public void removeEndpointsTest() {
check(pool).record(1).hostname("n2.ydb.tech").nodeID(2).port(12342);
check(pool).record(2).hostname("n3.ydb.tech").nodeID(3).port(12343);

when(random.nextInt(3)).thenReturn(1, 0, 2);
Mockito.when(random.nextInt(3)).thenReturn(1, 0, 2);

check(pool.getEndpoint(null)).hostname("n2.ydb.tech").nodeID(2).port(12342); // random
check(pool.getEndpoint(0)).hostname("n1.ydb.tech").nodeID(1).port(12341); // random
Expand All @@ -386,7 +390,7 @@ public void removeEndpointsTest() {
check(pool.getEndpoint(3)).hostname("n3.ydb.tech").nodeID(3).port(12343);
check(pool.getEndpoint(4)).hostname("n3.ydb.tech").nodeID(3).port(12343); // random

verify(random, times(3)).nextInt(3);
Mockito.verify(random, Mockito.times(3)).nextInt(3);

pool.setNewState("DC", list(
endpoint(2, "n2.ydb.tech", 12342, "DC"),
Expand All @@ -402,7 +406,7 @@ public void removeEndpointsTest() {
check(pool).record(2).hostname("n5.ydb.tech").nodeID(5).port(12345);
check(pool).record(3).hostname("n6.ydb.tech").nodeID(6).port(12346);

when(random.nextInt(4)).thenReturn(3, 1, 2, 0);
Mockito.when(random.nextInt(4)).thenReturn(3, 1, 2, 0);

check(pool.getEndpoint(null)).hostname("n6.ydb.tech").nodeID(6).port(12346); // random
check(pool.getEndpoint(0)).hostname("n4.ydb.tech").nodeID(4).port(12344); // random
Expand All @@ -413,7 +417,7 @@ public void removeEndpointsTest() {
check(pool.getEndpoint(5)).hostname("n5.ydb.tech").nodeID(5).port(12345);
check(pool.getEndpoint(6)).hostname("n6.ydb.tech").nodeID(6).port(12346);

verify(random, times(4)).nextInt(4);
Mockito.verify(random, Mockito.times(4)).nextInt(4);
}


Expand All @@ -427,42 +431,35 @@ public void detectLocalDCTest() throws IOException {

tickerStaticMock.when(Ticker::systemTicker).thenReturn(testTicker);

try (
ServerSocket s1 = ServerSocketFactory.getDefault().createServerSocket(0);
ServerSocket s2 = ServerSocketFactory.getDefault().createServerSocket(0);
ServerSocket s3 = ServerSocketFactory.getDefault().createServerSocket(0);
) {

EndpointPool pool = new EndpointPool(detectLocalDC());
check(pool).records(0).knownNodes(0).needToReDiscovery(false);

int p1 = s1.getLocalPort();
int p2 = s2.getLocalPort();
int p3 = s3.getLocalPort();

pool.setNewState("DC", list(
endpoint(1, "127.0.0.1", p1, "DC1"),
endpoint(2, "127.0.0.2", p2, "DC2"),
endpoint(3, "127.0.0.3", p3, "DC3")
));

check(pool).records(3).knownNodes(3).needToReDiscovery(false).bestEndpointsCount(1);

check(pool.getEndpoint(null)).hostname("127.0.0.2").nodeID(2).port(p2); // detect local dc
check(pool.getEndpoint(0)).hostname("127.0.0.2").nodeID(2).port(p2); // random from local dc
check(pool.getEndpoint(1)).hostname("127.0.0.1").nodeID(1).port(p1);
check(pool.getEndpoint(2)).hostname("127.0.0.2").nodeID(2).port(p2); // local dc
check(pool.getEndpoint(3)).hostname("127.0.0.3").nodeID(3).port(p3);
check(pool.getEndpoint(4)).hostname("127.0.0.2").nodeID(2).port(p2); // random from local dc

pool.pessimizeEndpoint(pool.getEndpoint(2));
check(pool.getEndpoint(null)).hostname("127.0.0.1").nodeID(1).port(p1); // new local dc
check(pool.getEndpoint(0)).hostname("127.0.0.1").nodeID(1).port(p1); // random from local dc
check(pool.getEndpoint(1)).hostname("127.0.0.1").nodeID(1).port(p1);
check(pool.getEndpoint(2)).hostname("127.0.0.2").nodeID(2).port(p2); // local dc
check(pool.getEndpoint(3)).hostname("127.0.0.3").nodeID(3).port(p3);
check(pool.getEndpoint(4)).hostname("127.0.0.1").nodeID(1).port(p1); // random from local dc
}
EndpointPool pool = new EndpointPool(detectLocalDC());
check(pool).records(0).knownNodes(0).needToReDiscovery(false);

int p1 = 1234;
int p2 = 1235;
int p3 = 1236;

pool.setNewState("DC", list(
endpoint(1, "127.0.0.1", p1, "DC1"),
endpoint(2, "127.0.0.2", p2, "DC2"),
endpoint(3, "127.0.0.3", p3, "DC3")
));

check(pool).records(3).knownNodes(3).needToReDiscovery(false).bestEndpointsCount(1);

check(pool.getEndpoint(null)).hostname("127.0.0.2").nodeID(2).port(p2); // detect local dc
check(pool.getEndpoint(0)).hostname("127.0.0.2").nodeID(2).port(p2); // random from local dc
check(pool.getEndpoint(1)).hostname("127.0.0.1").nodeID(1).port(p1);
check(pool.getEndpoint(2)).hostname("127.0.0.2").nodeID(2).port(p2); // local dc
check(pool.getEndpoint(3)).hostname("127.0.0.3").nodeID(3).port(p3);
check(pool.getEndpoint(4)).hostname("127.0.0.2").nodeID(2).port(p2); // random from local dc

pool.pessimizeEndpoint(pool.getEndpoint(2));
check(pool.getEndpoint(null)).hostname("127.0.0.1").nodeID(1).port(p1); // new local dc
check(pool.getEndpoint(0)).hostname("127.0.0.1").nodeID(1).port(p1); // random from local dc
check(pool.getEndpoint(1)).hostname("127.0.0.1").nodeID(1).port(p1);
check(pool.getEndpoint(2)).hostname("127.0.0.2").nodeID(2).port(p2); // local dc
check(pool.getEndpoint(3)).hostname("127.0.0.3").nodeID(3).port(p3);
check(pool.getEndpoint(4)).hostname("127.0.0.1").nodeID(1).port(p1); // random from local dc
}

private static class PoolChecker {
Expand Down

0 comments on commit d34acdc

Please sign in to comment.