diff --git a/src/main/java/net/dv8tion/jda/api/sharding/DefaultShardManager.java b/src/main/java/net/dv8tion/jda/api/sharding/DefaultShardManager.java index 11f4d6e71bf..3f8c7ec416e 100644 --- a/src/main/java/net/dv8tion/jda/api/sharding/DefaultShardManager.java +++ b/src/main/java/net/dv8tion/jda/api/sharding/DefaultShardManager.java @@ -22,18 +22,18 @@ import net.dv8tion.jda.api.entities.Guild; import net.dv8tion.jda.api.entities.SelfUser; import net.dv8tion.jda.api.exceptions.InvalidTokenException; -import net.dv8tion.jda.api.requests.GatewayIntent; -import net.dv8tion.jda.api.requests.RestConfig; -import net.dv8tion.jda.api.requests.Route; +import net.dv8tion.jda.api.requests.*; import net.dv8tion.jda.api.utils.ChunkingFilter; import net.dv8tion.jda.api.utils.MiscUtil; import net.dv8tion.jda.api.utils.SessionController; import net.dv8tion.jda.api.utils.cache.ShardCacheView; +import net.dv8tion.jda.api.utils.data.DataObject; import net.dv8tion.jda.internal.JDAImpl; import net.dv8tion.jda.internal.entities.SelfUserImpl; import net.dv8tion.jda.internal.managers.PresenceImpl; import net.dv8tion.jda.internal.requests.RestActionImpl; import net.dv8tion.jda.internal.utils.Checks; +import net.dv8tion.jda.internal.utils.IOUtil; import net.dv8tion.jda.internal.utils.JDALogger; import net.dv8tion.jda.internal.utils.UnlockHook; import net.dv8tion.jda.internal.utils.cache.ShardCacheViewImpl; @@ -42,11 +42,16 @@ import net.dv8tion.jda.internal.utils.config.SessionConfig; import net.dv8tion.jda.internal.utils.config.ThreadingConfig; import net.dv8tion.jda.internal.utils.config.sharding.*; +import okhttp3.Call; import okhttp3.OkHttpClient; import org.slf4j.Logger; import javax.annotation.Nonnull; import javax.annotation.Nullable; +import java.io.IOException; +import java.io.InputStream; +import java.io.UncheckedIOException; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.Collection; import java.util.EnumSet; @@ -489,12 +494,7 @@ protected JDAImpl buildInstance(final int shardId) httpClient = sessionConfig.getHttpBuilder().build(); } - // We first initialize the thread-pool here with the known shard total - // If the shard total is not known yet, a rest-only shard is spun up to determine the recommended shard total from the API - // This uses a temporary pool with a single thread to avoid blocking the main thread-pool - - if (getShardsTotal() != -1) - threadingConfig.init(getShardsTotal()); + retrieveShardTotal(httpClient); // imagine if we had macros or closures or destructuring :) ExecutorPair rateLimitSchedulerPair = resolveExecutor(threadingConfig.getRateLimitSchedulerProvider(), shardId); @@ -566,7 +566,7 @@ protected JDAImpl buildInstance(final int shardId) if (presenceConfig.getStatusProvider() != null) presence.setCacheStatus(presenceConfig.getStatusProvider().apply(shardId)); - if (this.gatewayURL == null || getShardsTotal() == -1) + if (this.gatewayURL == null) { SessionController.ShardedGateway gateway = jda.getShardedGateway(); this.sessionConfig.getSessionController().setConcurrency(gateway.getConcurrency()); @@ -575,22 +575,6 @@ protected JDAImpl buildInstance(final int shardId) throw new IllegalStateException("Acquired null gateway url from SessionController"); else LOG.info("Login Successful!"); - - if (getShardsTotal() == -1) - { - shardingConfig.setShardsTotal(gateway.getShardTotal()); - this.shards = new ShardCacheViewImpl(getShardsTotal()); - - synchronized (queue) - { - for (int i = 0; i < getShardsTotal(); i++) - queue.add(i); - } - - // Rebuild instance with new shard total, to allow proper thread scaling - jda.shutdownNow(); - return buildInstance(shardId); - } } final JDA.ShardInfo shardInfo = new JDA.ShardInfo(shardId, getShardsTotal()); @@ -652,6 +636,44 @@ public void setStatusProvider(IntFunction statusProvider) presenceConfig.setStatusProvider(statusProvider); } + private synchronized void retrieveShardTotal(OkHttpClient httpClient) + { + if (getShardsTotal() != -1) + return; + + LOG.debug("Fetching shard total using temporary rate-limiter"); + + CompletableFuture future = new CompletableFuture<>(); + ScheduledExecutorService pool = Executors.newSingleThreadScheduledExecutor(task -> { + Thread thread = new Thread(task, "DefaultShardManager retrieveShardTotal"); + thread.setDaemon(true); + return thread; + }); + + try + { + RestRateLimiter.RateLimitConfig rateLimitConfig = new RestRateLimiter.RateLimitConfig(pool, RestRateLimiter.GlobalRateLimit.create(), true); + SequentialRestRateLimiter rateLimiter = new SequentialRestRateLimiter(rateLimitConfig); + rateLimiter.enqueue(new ShardTotalTask(future, httpClient)); + + int shardTotal = future.join(); + this.shardingConfig.setShardsTotal(shardTotal); + this.shards = new ShardCacheViewImpl(shardTotal); + + synchronized (queue) + { + for (int i = 0; i < shardTotal; i++) + queue.add(i); + } + + } + finally + { + future.cancel(false); + pool.shutdownNow(); + } + } + /** * This method creates the internal {@link java.util.concurrent.ScheduledExecutorService ScheduledExecutorService}. * It is intended as a hook for custom implementations to create their own executor. @@ -690,4 +712,132 @@ protected ExecutorPair(E executor, boolean automaticShutdown) this.automaticShutdown = automaticShutdown; } } + + protected class ShardTotalTask implements RestRateLimiter.Work + { + private final CompletableFuture future; + private final OkHttpClient httpClient; + private int failedAttempts = 0; + + protected ShardTotalTask(CompletableFuture future, OkHttpClient httpClient) + { + this.future = future; + this.httpClient = httpClient; + } + + @Nonnull + @Override + public Route.CompiledRoute getRoute() + { + return Route.Misc.GATEWAY_BOT.compile(); + } + + @Nonnull + @Override + public JDA getJDA() + { + throw new UnsupportedOperationException(); + } + + @Nullable + @Override + public okhttp3.Response execute() + { + try + { + String url = restConfigProvider.apply(0).getBaseUrl() + getRoute().getCompiledRoute(); + LOG.trace("Requesting shard total with url {}", url); + + Call call = httpClient.newCall(new okhttp3.Request.Builder() + .get() + .url(url) + .header("authorization", "Bot " + token) + .header("accept-encoding", "gzip") + .build() + ); + + okhttp3.Response response = call.execute(); + + try + { + LOG.trace("Received response with code {}", response.code()); + InputStream body = IOUtil.getBody(response); + + if (response.isSuccessful()) + { + DataObject json = DataObject.fromJson(body); + int shardTotal = json.getInt("shards"); + future.complete(shardTotal); + } + else if (response.code() != 429 && response.code() < 500 || ++failedAttempts > 4) + { + future.completeExceptionally(new IllegalStateException( + "Failed to fetch recommended shard total! Code: " + response.code() + "\n" + + new String(IOUtil.readFully(body), StandardCharsets.UTF_8) + )); + } + else if (response.code() >= 500) + { + int backoff = 1 << failedAttempts; + LOG.warn("Failed to retrieve recommended shard total. Code: {} ... retrying in {}s", response.code(), backoff); + response = response.newBuilder() + .headers(response.headers() + .newBuilder() + .set(RestRateLimiter.RESET_AFTER_HEADER, String.valueOf(backoff)) + .set(RestRateLimiter.REMAINING_HEADER, String.valueOf(0)) + .set(RestRateLimiter.LIMIT_HEADER, String.valueOf(1)) + .set(RestRateLimiter.SCOPE_HEADER, "custom") + .build()) + .build(); + } + + return response; + } + finally + { + response.close(); + } + } + catch (IOException e) + { + future.completeExceptionally(e); + throw new UncheckedIOException(e); + } + catch (Throwable e) + { + future.completeExceptionally(e); + throw e; + } + } + + @Override + public boolean isSkipped() + { + return isCancelled(); + } + + @Override + public boolean isDone() + { + return future.isDone(); + } + + @Override + public boolean isPriority() + { + return true; + } + + @Override + public boolean isCancelled() + { + return future.isCancelled(); + } + + @Override + public void cancel() + { + future.cancel(false); + } + } } diff --git a/src/main/java/net/dv8tion/jda/api/sharding/ThreadPoolProvider.java b/src/main/java/net/dv8tion/jda/api/sharding/ThreadPoolProvider.java index 9ccd4da0163..e7e9174f708 100644 --- a/src/main/java/net/dv8tion/jda/api/sharding/ThreadPoolProvider.java +++ b/src/main/java/net/dv8tion/jda/api/sharding/ThreadPoolProvider.java @@ -85,8 +85,7 @@ static LazySharedProvider lazy(@Nonnull IntFuncti final class LazySharedProvider implements ThreadPoolProvider { private final IntFunction initializer; - private volatile T temporaryPool; - private volatile T pool; + private T pool; LazySharedProvider(@Nonnull IntFunction initializer) { @@ -105,10 +104,6 @@ public synchronized void init(int shardTotal) { if (pool == null) pool = initializer.apply(shardTotal); - - if (temporaryPool != null && temporaryPool != pool) - temporaryPool.shutdownNow(); - temporaryPool = null; } /** @@ -121,10 +116,6 @@ public synchronized void shutdown() pool.shutdown(); pool = null; } - - if (temporaryPool != null && temporaryPool != pool) - temporaryPool.shutdown(); - temporaryPool = null; } /** @@ -139,13 +130,6 @@ public synchronized void shutdown() @Override public synchronized T provide(int shardId) { - if (pool == null) - { - if (temporaryPool == null) - temporaryPool = initializer.apply(1); - return temporaryPool; - } - return pool; } }