Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Improve logic for retrieving initial shard total
Browse files Browse the repository at this point in the history
MinnDevelopment committed Nov 12, 2023
1 parent 1e18000 commit 7d1b1af
Showing 2 changed files with 177 additions and 43 deletions.
202 changes: 176 additions & 26 deletions src/main/java/net/dv8tion/jda/api/sharding/DefaultShardManager.java
Original file line number Diff line number Diff line change
@@ -22,18 +22,18 @@
import net.dv8tion.jda.api.entities.Guild;
import net.dv8tion.jda.api.entities.SelfUser;
import net.dv8tion.jda.api.exceptions.InvalidTokenException;
import net.dv8tion.jda.api.requests.GatewayIntent;
import net.dv8tion.jda.api.requests.RestConfig;
import net.dv8tion.jda.api.requests.Route;
import net.dv8tion.jda.api.requests.*;
import net.dv8tion.jda.api.utils.ChunkingFilter;
import net.dv8tion.jda.api.utils.MiscUtil;
import net.dv8tion.jda.api.utils.SessionController;
import net.dv8tion.jda.api.utils.cache.ShardCacheView;
import net.dv8tion.jda.api.utils.data.DataObject;
import net.dv8tion.jda.internal.JDAImpl;
import net.dv8tion.jda.internal.entities.SelfUserImpl;
import net.dv8tion.jda.internal.managers.PresenceImpl;
import net.dv8tion.jda.internal.requests.RestActionImpl;
import net.dv8tion.jda.internal.utils.Checks;
import net.dv8tion.jda.internal.utils.IOUtil;
import net.dv8tion.jda.internal.utils.JDALogger;
import net.dv8tion.jda.internal.utils.UnlockHook;
import net.dv8tion.jda.internal.utils.cache.ShardCacheViewImpl;
@@ -42,11 +42,16 @@
import net.dv8tion.jda.internal.utils.config.SessionConfig;
import net.dv8tion.jda.internal.utils.config.ThreadingConfig;
import net.dv8tion.jda.internal.utils.config.sharding.*;
import okhttp3.Call;
import okhttp3.OkHttpClient;
import org.slf4j.Logger;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collection;
import java.util.EnumSet;
@@ -489,12 +494,7 @@ protected JDAImpl buildInstance(final int shardId)
httpClient = sessionConfig.getHttpBuilder().build();
}

// We first initialize the thread-pool here with the known shard total
// If the shard total is not known yet, a rest-only shard is spun up to determine the recommended shard total from the API
// This uses a temporary pool with a single thread to avoid blocking the main thread-pool

if (getShardsTotal() != -1)
threadingConfig.init(getShardsTotal());
retrieveShardTotal(httpClient);

// imagine if we had macros or closures or destructuring :)
ExecutorPair<ScheduledExecutorService> rateLimitSchedulerPair = resolveExecutor(threadingConfig.getRateLimitSchedulerProvider(), shardId);
@@ -566,7 +566,7 @@ protected JDAImpl buildInstance(final int shardId)
if (presenceConfig.getStatusProvider() != null)
presence.setCacheStatus(presenceConfig.getStatusProvider().apply(shardId));

if (this.gatewayURL == null || getShardsTotal() == -1)
if (this.gatewayURL == null)
{
SessionController.ShardedGateway gateway = jda.getShardedGateway();
this.sessionConfig.getSessionController().setConcurrency(gateway.getConcurrency());
@@ -575,22 +575,6 @@ protected JDAImpl buildInstance(final int shardId)
throw new IllegalStateException("Acquired null gateway url from SessionController");
else
LOG.info("Login Successful!");

if (getShardsTotal() == -1)
{
shardingConfig.setShardsTotal(gateway.getShardTotal());
this.shards = new ShardCacheViewImpl(getShardsTotal());

synchronized (queue)
{
for (int i = 0; i < getShardsTotal(); i++)
queue.add(i);
}

// Rebuild instance with new shard total, to allow proper thread scaling
jda.shutdownNow();
return buildInstance(shardId);
}
}

final JDA.ShardInfo shardInfo = new JDA.ShardInfo(shardId, getShardsTotal());
@@ -652,6 +636,44 @@ public void setStatusProvider(IntFunction<OnlineStatus> statusProvider)
presenceConfig.setStatusProvider(statusProvider);
}

private synchronized void retrieveShardTotal(OkHttpClient httpClient)
{
if (getShardsTotal() != -1)
return;

LOG.debug("Fetching shard total using temporary rate-limiter");

CompletableFuture<Integer> future = new CompletableFuture<>();
ScheduledExecutorService pool = Executors.newSingleThreadScheduledExecutor(task -> {
Thread thread = new Thread(task, "DefaultShardManager retrieveShardTotal");
thread.setDaemon(true);
return thread;
});

try
{
RestRateLimiter.RateLimitConfig rateLimitConfig = new RestRateLimiter.RateLimitConfig(pool, RestRateLimiter.GlobalRateLimit.create(), true);
SequentialRestRateLimiter rateLimiter = new SequentialRestRateLimiter(rateLimitConfig);
rateLimiter.enqueue(new ShardTotalTask(future, httpClient));

int shardTotal = future.join();
this.shardingConfig.setShardsTotal(shardTotal);
this.shards = new ShardCacheViewImpl(shardTotal);

synchronized (queue)
{
for (int i = 0; i < shardTotal; i++)
queue.add(i);
}

}
finally
{
future.cancel(false);
pool.shutdownNow();
}
}

/**
* This method creates the internal {@link java.util.concurrent.ScheduledExecutorService ScheduledExecutorService}.
* It is intended as a hook for custom implementations to create their own executor.
@@ -690,4 +712,132 @@ protected ExecutorPair(E executor, boolean automaticShutdown)
this.automaticShutdown = automaticShutdown;
}
}

protected class ShardTotalTask implements RestRateLimiter.Work
{
private final CompletableFuture<Integer> future;
private final OkHttpClient httpClient;
private int failedAttempts = 0;

protected ShardTotalTask(CompletableFuture<Integer> future, OkHttpClient httpClient)
{
this.future = future;
this.httpClient = httpClient;
}

@Nonnull
@Override
public Route.CompiledRoute getRoute()
{
return Route.Misc.GATEWAY_BOT.compile();
}

@Nonnull
@Override
public JDA getJDA()
{
throw new UnsupportedOperationException();
}

@Nullable
@Override
public okhttp3.Response execute()
{
try
{
String url = restConfigProvider.apply(0).getBaseUrl() + getRoute().getCompiledRoute();
LOG.trace("Requesting shard total with url {}", url);

Call call = httpClient.newCall(new okhttp3.Request.Builder()
.get()
.url(url)
.header("authorization", "Bot " + token)
.header("accept-encoding", "gzip")
.build()
);

okhttp3.Response response = call.execute();

try
{
LOG.trace("Received response with code {}", response.code());
InputStream body = IOUtil.getBody(response);

if (response.isSuccessful())
{
DataObject json = DataObject.fromJson(body);
int shardTotal = json.getInt("shards");
future.complete(shardTotal);
}
else if (response.code() != 429 && response.code() < 500 || ++failedAttempts > 4)
{
future.completeExceptionally(new IllegalStateException(
"Failed to fetch recommended shard total! Code: " + response.code() + "\n" +
new String(IOUtil.readFully(body), StandardCharsets.UTF_8)
));
}
else if (response.code() >= 500)
{
int backoff = 1 << failedAttempts;
LOG.warn("Failed to retrieve recommended shard total. Code: {} ... retrying in {}s", response.code(), backoff);
response = response.newBuilder()
.headers(response.headers()
.newBuilder()
.set(RestRateLimiter.RESET_AFTER_HEADER, String.valueOf(backoff))
.set(RestRateLimiter.REMAINING_HEADER, String.valueOf(0))
.set(RestRateLimiter.LIMIT_HEADER, String.valueOf(1))
.set(RestRateLimiter.SCOPE_HEADER, "custom")
.build())
.build();
}

return response;
}
finally
{
response.close();
}
}
catch (IOException e)
{
future.completeExceptionally(e);
throw new UncheckedIOException(e);
}
catch (Throwable e)
{
future.completeExceptionally(e);
throw e;
}
}

@Override
public boolean isSkipped()
{
return isCancelled();
}

@Override
public boolean isDone()
{
return future.isDone();
}

@Override
public boolean isPriority()
{
return true;
}

@Override
public boolean isCancelled()
{
return future.isCancelled();
}

@Override
public void cancel()
{
future.cancel(false);
}
}
}
Original file line number Diff line number Diff line change
@@ -85,8 +85,7 @@ static <T extends ExecutorService> LazySharedProvider<T> lazy(@Nonnull IntFuncti
final class LazySharedProvider<T extends ExecutorService> implements ThreadPoolProvider<T>
{
private final IntFunction<T> initializer;
private volatile T temporaryPool;
private volatile T pool;
private T pool;

LazySharedProvider(@Nonnull IntFunction<T> initializer)
{
@@ -105,10 +104,6 @@ public synchronized void init(int shardTotal)
{
if (pool == null)
pool = initializer.apply(shardTotal);

if (temporaryPool != null && temporaryPool != pool)
temporaryPool.shutdownNow();
temporaryPool = null;
}

/**
@@ -121,10 +116,6 @@ public synchronized void shutdown()
pool.shutdown();
pool = null;
}

if (temporaryPool != null && temporaryPool != pool)
temporaryPool.shutdown();
temporaryPool = null;
}

/**
@@ -139,13 +130,6 @@ public synchronized void shutdown()
@Override
public synchronized T provide(int shardId)
{
if (pool == null)
{
if (temporaryPool == null)
temporaryPool = initializer.apply(1);
return temporaryPool;
}

return pool;
}
}

0 comments on commit 7d1b1af

Please sign in to comment.