diff --git a/DisCatSharp/Clients/BaseDiscordClient.cs b/DisCatSharp/Clients/BaseDiscordClient.cs index fb8aff487..33e1826f0 100644 --- a/DisCatSharp/Clients/BaseDiscordClient.cs +++ b/DisCatSharp/Clients/BaseDiscordClient.cs @@ -1,311 +1,315 @@ // This file is part of the DisCatSharp project. // // Copyright (c) 2021 AITSYS // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in all // copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. #pragma warning disable CS0618 using System; using System.Collections.Concurrent; using System.Collections.Generic; using System.Collections.ObjectModel; using System.Linq; using System.Net.Http; using System.Reflection; using System.Threading.Tasks; using DisCatSharp.Entities; using DisCatSharp.Net; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; namespace DisCatSharp { /// /// Represents a common base for various Discord client implementations. /// public abstract class BaseDiscordClient : IDisposable { /// /// Gets the api client. /// internal protected DiscordApiClient ApiClient { get; } /// /// Gets the configuration. /// internal protected DiscordConfiguration Configuration { get; } /// /// Gets the instance of the logger for this client. /// public ILogger Logger { get; } /// /// Gets the string representing the version of bot lib. /// public string VersionString { get; } /// /// Gets the bot library name. /// public string BotLibrary { get; } /// /// Gets the library team. /// public DisCatSharpTeam LibraryDeveloperTeam => this.ApiClient.GetDisCatSharpTeamAsync().Result; /// /// Gets the current user. /// public DiscordUser CurrentUser { get; internal set; } /// /// Gets the current application. /// public DiscordApplication CurrentApplication { get; internal set; } + public HttpClient RestClient { get; internal set; } + /// /// Exposes a for custom operations. /// public HttpClient RestClient { get; internal set; } /// /// Gets the cached guilds for this client. /// public abstract IReadOnlyDictionary Guilds { get; } /// /// Gets the cached users for this client. /// protected internal ConcurrentDictionary UserCache { get; } /// /// Gets the service provider. /// This allows passing data around without resorting to static members. /// Defaults to null. /// internal IServiceProvider ServiceProvider { get; set; } = new ServiceCollection().BuildServiceProvider(true); /// /// Gets the list of available voice regions. Note that this property will not contain VIP voice regions. /// public IReadOnlyDictionary VoiceRegions => this._voice_regions_lazy.Value; /// /// Gets the list of available voice regions. This property is meant as a way to modify . /// protected internal ConcurrentDictionary InternalVoiceRegions { get; set; } internal Lazy> _voice_regions_lazy; /// /// Initializes this Discord API client. /// /// Configuration for this client. protected BaseDiscordClient(DiscordConfiguration config) { this.Configuration = new DiscordConfiguration(config); if (this.Configuration.LoggerFactory == null) { this.Configuration.LoggerFactory = new DefaultLoggerFactory(); this.Configuration.LoggerFactory.AddProvider(new DefaultLoggerProvider(this)); } this.Logger = this.Configuration.LoggerFactory.CreateLogger(); this.ApiClient = new DiscordApiClient(this); this.UserCache = new ConcurrentDictionary(); this.InternalVoiceRegions = new ConcurrentDictionary(); this._voice_regions_lazy = new Lazy>(() => new ReadOnlyDictionary(this.InternalVoiceRegions)); + this.RestClient = this.ApiClient.Rest.HttpClient; + var a = typeof(DiscordClient).GetTypeInfo().Assembly; var iv = a.GetCustomAttribute(); if (iv != null) { this.VersionString = iv.InformationalVersion; } else { var v = a.GetName().Version; var vs = v.ToString(3); if (v.Revision > 0) this.VersionString = $"{vs}, CI build {v.Revision}"; } this.BotLibrary = "DisCatSharp"; this.ServiceProvider = config.ServiceProvider; } /// /// Gets the current API application. /// /// Current API application. public async Task GetCurrentApplicationAsync() { var tapp = await this.ApiClient.GetCurrentApplicationInfoAsync().ConfigureAwait(false); var app = new DiscordApplication { Discord = this, Id = tapp.Id, Name = tapp.Name, Description = tapp.Description, Summary = tapp.Summary, IconHash = tapp.IconHash, RpcOrigins = tapp.RpcOrigins != null ? new ReadOnlyCollection(tapp.RpcOrigins) : null, Flags = tapp.Flags, RequiresCodeGrant = tapp.BotRequiresCodeGrant, IsPublic = tapp.IsPublicBot, PrivacyPolicyUrl = tapp.PrivacyPolicyUrl, TermsOfServiceUrl = tapp.TermsOfServiceUrl, CustomInstallUrl = tapp.CustomInstallUrl, InstallParams = tapp.InstallParams, Tags = (tapp.Tags ?? Enumerable.Empty()).ToArray() }; // do team and owners // tbh fuck doing this properly if (tapp.Team == null) { // singular owner app.Owners = new ReadOnlyCollection(new[] { new DiscordUser(tapp.Owner) }); app.Team = null; app.TeamName = null; } else { // team owner app.Team = new DiscordTeam(tapp.Team); var members = tapp.Team.Members .Select(x => new DiscordTeamMember(x) { Team = app.Team, User = new DiscordUser(x.User) }) .ToArray(); var owners = members .Where(x => x.MembershipStatus == DiscordTeamMembershipStatus.Accepted) .Select(x => x.User) .ToArray(); app.Owners = new ReadOnlyCollection(owners); app.Team.Owner = owners.FirstOrDefault(x => x.Id == tapp.Team.OwnerId); app.Team.Members = new ReadOnlyCollection(members); app.TeamName = app.Team.Name; } app.GuildId = tapp.GuildId.HasValue ? tapp.GuildId.Value : null; app.Slug = tapp.Slug.HasValue ? tapp.Slug.Value : null; app.PrimarySkuId = tapp.PrimarySkuId.HasValue ? tapp.PrimarySkuId.Value : null; app.VerifyKey = tapp.VerifyKey.HasValue ? tapp.VerifyKey.Value : null; app.CoverImageHash = tapp.CoverImageHash.HasValue ? tapp.CoverImageHash.Value : null; return app; } /// /// Gets a list of regions /// /// /// Thrown when Discord is unable to process the request. public Task> ListVoiceRegionsAsync() => this.ApiClient.ListVoiceRegionsAsync(); /// /// Initializes this client. This method fetches information about current user, application, and voice regions. /// /// public virtual async Task InitializeAsync() { if (this.CurrentUser == null) { this.CurrentUser = await this.ApiClient.GetCurrentUserAsync().ConfigureAwait(false); this.UserCache.AddOrUpdate(this.CurrentUser.Id, this.CurrentUser, (id, xu) => this.CurrentUser); } if (this.Configuration.TokenType == TokenType.Bot && this.CurrentApplication == null) this.CurrentApplication = await this.GetCurrentApplicationAsync().ConfigureAwait(false); if (this.Configuration.TokenType != TokenType.Bearer && this.InternalVoiceRegions.Count == 0) { var vrs = await this.ListVoiceRegionsAsync().ConfigureAwait(false); foreach (var xvr in vrs) this.InternalVoiceRegions.TryAdd(xvr.Id, xvr); } } /// /// Gets the current gateway info for the provided token. /// If no value is provided, the configuration value will be used instead. /// /// A gateway info object. public async Task GetGatewayInfoAsync(string token = null) { if (this.Configuration.TokenType != TokenType.Bot) throw new InvalidOperationException("Only bot tokens can access this info."); if (string.IsNullOrEmpty(this.Configuration.Token)) { if (string.IsNullOrEmpty(token)) throw new InvalidOperationException("Could not locate a valid token."); this.Configuration.Token = token; var res = await this.ApiClient.GetGatewayInfoAsync().ConfigureAwait(false); this.Configuration.Token = null; return res; } return await this.ApiClient.GetGatewayInfoAsync().ConfigureAwait(false); } /// /// Gets a cached user. /// /// The user_id. internal DiscordUser GetCachedOrEmptyUserInternal(ulong user_id) { this.TryGetCachedUserInternal(user_id, out var user); return user; } /// /// Tries the get a cached user. /// /// The user_id. /// The user. internal bool TryGetCachedUserInternal(ulong user_id, out DiscordUser user) { if (this.UserCache.TryGetValue(user_id, out user)) return true; user = new DiscordUser { Id = user_id, Discord = this }; return false; } /// /// Disposes this client. /// public abstract void Dispose(); } } diff --git a/DisCatSharp/Net/Rest/RestClient.cs b/DisCatSharp/Net/Rest/RestClient.cs index 1495f7348..9ebe93897 100644 --- a/DisCatSharp/Net/Rest/RestClient.cs +++ b/DisCatSharp/Net/Rest/RestClient.cs @@ -1,860 +1,860 @@ // This file is part of the DisCatSharp project. // // Copyright (c) 2021 AITSYS // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in all // copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. using System; using System.Collections.Concurrent; using System.Collections.Generic; using System.Globalization; using System.Linq; using System.Net; using System.Net.Http; using System.Net.Http.Headers; using System.Reflection; using System.Text; using System.Text.RegularExpressions; using System.Threading; using System.Threading.Tasks; using DisCatSharp.Exceptions; using Microsoft.Extensions.Logging; namespace DisCatSharp.Net { /// /// Represents a client used to make REST requests. /// internal sealed class RestClient : IDisposable { /// /// Gets the route argument regex. /// private static Regex RouteArgumentRegex { get; } = new Regex(@":([a-z_]+)"); /// /// Gets the http client. /// - private HttpClient HttpClient { get; } + internal HttpClient HttpClient { get; } /// /// Gets the discord client. /// private BaseDiscordClient Discord { get; } /// /// Gets a value indicating whether debug is enabled. /// internal bool Debug { get; set; } /// /// Gets the logger. /// private ILogger Logger { get; } /// /// Gets the routes to hashes. /// private ConcurrentDictionary RoutesToHashes { get; } /// /// Gets the hashes to buckets. /// private ConcurrentDictionary HashesToBuckets { get; } /// /// Gets the request queue. /// private ConcurrentDictionary RequestQueue { get; } /// /// Gets the global rate limit event. /// private AsyncManualResetEvent GlobalRateLimitEvent { get; } /// /// Gets a value indicating whether use reset after. /// private bool UseResetAfter { get; } private CancellationTokenSource _bucketCleanerTokenSource; private TimeSpan _bucketCleanupDelay = TimeSpan.FromSeconds(60); private volatile bool _cleanerRunning; private Task _cleanerTask; private volatile bool _disposed; /// /// Initializes a new instance of the class. /// /// The client. internal RestClient(BaseDiscordClient client) : this(client.Configuration.Proxy, client.Configuration.HttpTimeout, client.Configuration.UseRelativeRatelimit, client.Logger) { this.Discord = client; this.HttpClient.DefaultRequestHeaders.TryAddWithoutValidation("Authorization", Utilities.GetFormattedToken(client)); if (client.Configuration.Override != null) { this.HttpClient.DefaultRequestHeaders.TryAddWithoutValidation("x-super-properties", client.Configuration.Override); } } /// /// Initializes a new instance of the class. /// /// The proxy. /// The timeout. /// If true, use relative ratelimit. /// The logger. internal RestClient(IWebProxy proxy, TimeSpan timeout, bool useRelativeRatelimit, ILogger logger) // This is for meta-clients, such as the webhook client { this.Logger = logger; var httphandler = new HttpClientHandler { UseCookies = false, AutomaticDecompression = DecompressionMethods.Deflate | DecompressionMethods.GZip, UseProxy = proxy != null, Proxy = proxy }; this.HttpClient = new HttpClient(httphandler) { BaseAddress = new Uri(Utilities.GetApiBaseUri(this.Discord?.Configuration)), Timeout = timeout }; this.HttpClient.DefaultRequestHeaders.TryAddWithoutValidation("User-Agent", Utilities.GetUserAgent()); if (this.Discord != null && this.Discord.Configuration != null && this.Discord.Configuration.Override != null) { this.HttpClient.DefaultRequestHeaders.TryAddWithoutValidation("x-super-properties", this.Discord.Configuration.Override); } this.RoutesToHashes = new ConcurrentDictionary(); this.HashesToBuckets = new ConcurrentDictionary(); this.RequestQueue = new ConcurrentDictionary(); this.GlobalRateLimitEvent = new AsyncManualResetEvent(true); this.UseResetAfter = useRelativeRatelimit; } /// /// Gets a bucket. /// /// The method. /// The route. /// The route paramaters. /// The url. /// A ratelimit bucket. public RateLimitBucket GetBucket(RestRequestMethod method, string route, object route_params, out string url) { var rparams_props = route_params.GetType() .GetTypeInfo() .DeclaredProperties; var rparams = new Dictionary(); foreach (var xp in rparams_props) { var val = xp.GetValue(route_params); rparams[xp.Name] = val is string xs ? xs : val is DateTime dt ? dt.ToString("yyyy-MM-ddTHH:mm:sszzz", CultureInfo.InvariantCulture) : val is DateTimeOffset dto ? dto.ToString("yyyy-MM-ddTHH:mm:sszzz", CultureInfo.InvariantCulture) : val is IFormattable xf ? xf.ToString(null, CultureInfo.InvariantCulture) : val.ToString(); } var guild_id = rparams.ContainsKey("guild_id") ? rparams["guild_id"] : ""; var channel_id = rparams.ContainsKey("channel_id") ? rparams["channel_id"] : ""; var webhook_id = rparams.ContainsKey("webhook_id") ? rparams["webhook_id"] : ""; // Create a generic route (minus major params) key // ex: POST:/channels/channel_id/messages var hashKey = RateLimitBucket.GenerateHashKey(method, route); // We check if the hash is present, using our generic route (without major params) // ex: in POST:/channels/channel_id/messages, out 80c17d2f203122d936070c88c8d10f33 // If it doesn't exist, we create an unlimited hash as our initial key in the form of the hash key + the unlimited constant // and assign this to the route to hash cache // ex: this.RoutesToHashes[POST:/channels/channel_id/messages] = POST:/channels/channel_id/messages:unlimited var hash = this.RoutesToHashes.GetOrAdd(hashKey, RateLimitBucket.GenerateUnlimitedHash(method, route)); // Next we use the hash to generate the key to obtain the bucket. // ex: 80c17d2f203122d936070c88c8d10f33:guild_id:506128773926879242:webhook_id // or if unlimited: POST:/channels/channel_id/messages:unlimited:guild_id:506128773926879242:webhook_id var bucketId = RateLimitBucket.GenerateBucketId(hash, guild_id, channel_id, webhook_id); // If it's not in cache, create a new bucket and index it by its bucket id. var bucket = this.HashesToBuckets.GetOrAdd(bucketId, new RateLimitBucket(hash, guild_id, channel_id, webhook_id)); bucket.LastAttemptAt = DateTimeOffset.UtcNow; // Cache the routes for each bucket so it can be used for GC later. if (!bucket.RouteHashes.Contains(bucketId)) bucket.RouteHashes.Add(bucketId); // Add the current route to the request queue, which indexes the amount // of requests occurring to the bucket id. _ = this.RequestQueue.TryGetValue(bucketId, out var count); // Increment by one atomically due to concurrency this.RequestQueue[bucketId] = Interlocked.Increment(ref count); // Start bucket cleaner if not already running. if (!this._cleanerRunning) { this._cleanerRunning = true; this._bucketCleanerTokenSource = new CancellationTokenSource(); this._cleanerTask = Task.Run(this.CleanupBucketsAsync, this._bucketCleanerTokenSource.Token); this.Logger.LogDebug(LoggerEvents.RestCleaner, "Bucket cleaner task started."); } url = RouteArgumentRegex.Replace(route, xm => rparams[xm.Groups[1].Value]); return bucket; } /// /// Executes the request async. /// /// The request to be executed. public Task ExecuteRequestAsync(BaseRestRequest request) => request == null ? throw new ArgumentNullException(nameof(request)) : this.ExecuteRequestAsync(request, null, null); /// /// Executes the request async. /// This is to allow proper rescheduling of the first request from a bucket. /// /// The request to be executed. /// The bucket. /// The ratelimit task completion source. private async Task ExecuteRequestAsync(BaseRestRequest request, RateLimitBucket bucket, TaskCompletionSource ratelimitTcs) { if (this._disposed) return; HttpResponseMessage res = default; try { await this.GlobalRateLimitEvent.WaitAsync().ConfigureAwait(false); if (bucket == null) bucket = request.RateLimitBucket; if (ratelimitTcs == null) ratelimitTcs = await this.WaitForInitialRateLimit(bucket).ConfigureAwait(false); if (ratelimitTcs == null) // ckeck rate limit only if we are not the probe request { var now = DateTimeOffset.UtcNow; await bucket.TryResetLimitAsync(now).ConfigureAwait(false); // Decrement the remaining number of requests as there can be other concurrent requests before this one finishes and has a chance to update the bucket if (Interlocked.Decrement(ref bucket._remaining) < 0) { this.Logger.LogDebug(LoggerEvents.RatelimitDiag, "Request for {0} is blocked", bucket.ToString()); var delay = bucket.Reset - now; var resetDate = bucket.Reset; if (this.UseResetAfter) { delay = bucket.ResetAfter.Value; resetDate = bucket.ResetAfterOffset; } if (delay < new TimeSpan(-TimeSpan.TicksPerMinute)) { this.Logger.LogError(LoggerEvents.RatelimitDiag, "Failed to retrieve ratelimits - giving up and allowing next request for bucket"); bucket._remaining = 1; } if (delay < TimeSpan.Zero) delay = TimeSpan.FromMilliseconds(100); this.Logger.LogWarning(LoggerEvents.RatelimitPreemptive, "Pre-emptive ratelimit triggered - waiting until {0:yyyy-MM-dd HH:mm:ss zzz} ({1:c}).", resetDate, delay); Task.Delay(delay) .ContinueWith(_ => this.ExecuteRequestAsync(request, null, null)) .LogTaskFault(this.Logger, LogLevel.Error, LoggerEvents.RestError, "Error while executing request"); return; } this.Logger.LogDebug(LoggerEvents.RatelimitDiag, "Request for {0} is allowed", bucket.ToString()); } else this.Logger.LogDebug(LoggerEvents.RatelimitDiag, "Initial request for {0} is allowed", bucket.ToString()); var req = this.BuildRequest(request); if(this.Debug) this.Logger.LogTrace(LoggerEvents.Misc, await req.Content.ReadAsStringAsync()); var response = new RestResponse(); try { if (this._disposed) return; res = await this.HttpClient.SendAsync(req, HttpCompletionOption.ResponseContentRead, CancellationToken.None).ConfigureAwait(false); var bts = await res.Content.ReadAsByteArrayAsync().ConfigureAwait(false); var txt = Utilities.UTF8.GetString(bts, 0, bts.Length); this.Logger.LogTrace(LoggerEvents.RestRx, txt); response.Headers = res.Headers.ToDictionary(xh => xh.Key, xh => string.Join("\n", xh.Value), StringComparer.OrdinalIgnoreCase); response.Response = txt; response.ResponseCode = (int)res.StatusCode; } catch (HttpRequestException httpex) { this.Logger.LogError(LoggerEvents.RestError, httpex, "Request to {0} triggered an HttpException", request.Url); request.SetFaulted(httpex); this.FailInitialRateLimitTest(request, ratelimitTcs); return; } this.UpdateBucket(request, response, ratelimitTcs); Exception ex = null; switch (response.ResponseCode) { case 400: case 405: ex = new BadRequestException(request, response); break; case 401: case 403: ex = new UnauthorizedException(request, response); break; case 404: ex = new NotFoundException(request, response); break; case 413: ex = new RequestSizeException(request, response); break; case 429: ex = new RateLimitException(request, response); // check the limit info and requeue this.Handle429(response, out var wait, out var global); if (wait != null) { if (global) { bucket.IsGlobal = true; this.Logger.LogError(LoggerEvents.RatelimitHit, "Global ratelimit hit, cooling down"); try { this.GlobalRateLimitEvent.Reset(); await wait.ConfigureAwait(false); } finally { // we don't want to wait here until all the blocked requests have been run, additionally Set can never throw an exception that could be suppressed here _ = this.GlobalRateLimitEvent.SetAsync(); } this.ExecuteRequestAsync(request, bucket, ratelimitTcs) .LogTaskFault(this.Logger, LogLevel.Error, LoggerEvents.RestError, "Error while retrying request"); } else { this.Logger.LogError(LoggerEvents.RatelimitHit, "Ratelimit hit, requeueing request to {0}", request.Url); await wait.ConfigureAwait(false); this.ExecuteRequestAsync(request, bucket, ratelimitTcs) .LogTaskFault(this.Logger, LogLevel.Error, LoggerEvents.RestError, "Error while retrying request"); } return; } break; case 500: case 502: case 503: case 504: ex = new ServerErrorException(request, response); break; } if (ex != null) request.SetFaulted(ex); else request.SetCompleted(response); } catch (Exception ex) { this.Logger.LogError(LoggerEvents.RestError, ex, "Request to {0} triggered an exception", request.Url); // if something went wrong and we couldn't get rate limits for the first request here, allow the next request to run if (bucket != null && ratelimitTcs != null && bucket._limitTesting != 0) this.FailInitialRateLimitTest(request, ratelimitTcs); if (!request.TrySetFaulted(ex)) throw; } finally { res?.Dispose(); // Get and decrement active requests in this bucket by 1. _ = this.RequestQueue.TryGetValue(bucket.BucketId, out var count); this.RequestQueue[bucket.BucketId] = Interlocked.Decrement(ref count); // If it's 0 or less, we can remove the bucket from the active request queue, // along with any of its past routes. if (count <= 0) { foreach (var r in bucket.RouteHashes) { if (this.RequestQueue.ContainsKey(r)) { _ = this.RequestQueue.TryRemove(r, out _); } } } } } /// /// Fails the initial rate limit test. /// /// The request. /// The ratelimit task completion source. /// If true, reset to initial. private void FailInitialRateLimitTest(BaseRestRequest request, TaskCompletionSource ratelimitTcs, bool resetToInitial = false) { if (ratelimitTcs == null && !resetToInitial) return; var bucket = request.RateLimitBucket; bucket._limitValid = false; bucket._limitTestFinished = null; bucket._limitTesting = 0; //Reset to initial values. if (resetToInitial) { this.UpdateHashCaches(request, bucket); bucket.Maximum = 0; bucket._remaining = 0; return; } // no need to wait on all the potentially waiting tasks _ = Task.Run(() => ratelimitTcs.TrySetResult(false)); } /// /// Waits for the initial rate limit. /// /// The bucket. private async Task> WaitForInitialRateLimit(RateLimitBucket bucket) { while (!bucket._limitValid) { if (bucket._limitTesting == 0) { if (Interlocked.CompareExchange(ref bucket._limitTesting, 1, 0) == 0) { // if we got here when the first request was just finishing, we must not create the waiter task as it would signel ExecureRequestAsync to bypass rate limiting if (bucket._limitValid) return null; // allow exactly one request to go through without having rate limits available var ratelimitsTcs = new TaskCompletionSource(); bucket._limitTestFinished = ratelimitsTcs.Task; return ratelimitsTcs; } } // it can take a couple of cycles for the task to be allocated, so wait until it happens or we are no longer probing for the limits Task waitTask = null; while (bucket._limitTesting != 0 && (waitTask = bucket._limitTestFinished) == null) await Task.Yield(); if (waitTask != null) await waitTask.ConfigureAwait(false); // if the request failed and the response did not have rate limit headers we have allow the next request and wait again, thus this is a loop here } return null; } /// /// Builds the request. /// /// The request. /// A http request message. private HttpRequestMessage BuildRequest(BaseRestRequest request) { var req = new HttpRequestMessage(new HttpMethod(request.Method.ToString()), request.Url); if (request.Headers != null && request.Headers.Any()) foreach (var kvp in request.Headers) req.Headers.Add(kvp.Key, kvp.Value); if (request is RestRequest nmprequest && !string.IsNullOrWhiteSpace(nmprequest.Payload)) { this.Logger.LogTrace(LoggerEvents.RestTx, nmprequest.Payload); req.Content = new StringContent(nmprequest.Payload); req.Content.Headers.ContentType = new MediaTypeHeaderValue("application/json"); } if (request is MultipartWebRequest mprequest) { this.Logger.LogTrace(LoggerEvents.RestTx, ""); var boundary = "---------------------------" + DateTime.Now.Ticks.ToString("x"); req.Headers.Add("Connection", "keep-alive"); req.Headers.Add("Keep-Alive", "600"); var content = new MultipartFormDataContent(boundary); if (mprequest.Values != null && mprequest.Values.Any()) foreach (var kvp in mprequest.Values) content.Add(new StringContent(kvp.Value), kvp.Key); var fileId = mprequest.OverwriteFileIdStart ?? 0; if (mprequest.Files != null && mprequest.Files.Any()) { foreach (var f in mprequest.Files) { var name = $"files[{fileId.ToString(CultureInfo.InvariantCulture)}]"; content.Add(new StreamContent(f.Value), name, f.Key); fileId++; } } req.Content = content; } if (request is MultipartStickerWebRequest mpsrequest) { this.Logger.LogTrace(LoggerEvents.RestTx, ""); var boundary = "---------------------------" + DateTime.Now.Ticks.ToString("x"); req.Headers.Add("Connection", "keep-alive"); req.Headers.Add("Keep-Alive", "600"); var sc = new StreamContent(mpsrequest.File.Stream); if (mpsrequest.File.ContentType != null) sc.Headers.ContentType = new MediaTypeHeaderValue(mpsrequest.File.ContentType); var fileName = mpsrequest.File.FileName; if (mpsrequest.File.FileType != null) fileName += '.' + mpsrequest.File.FileType; var content = new MultipartFormDataContent(boundary) { { new StringContent(mpsrequest.Name), "name" }, { new StringContent(mpsrequest.Tags), "tags" }, { new StringContent(mpsrequest.Description), "description" }, { sc, "file", fileName } }; req.Content = content; } return req; } /// /// Handles the http 429 status. /// /// The response. /// The wait task. /// If true, global. private void Handle429(RestResponse response, out Task wait_task, out bool global) { wait_task = null; global = false; if (response.Headers == null) return; var hs = response.Headers; // handle the wait if (hs.TryGetValue("Retry-After", out var retry_after_raw)) { var retry_after = TimeSpan.FromSeconds(int.Parse(retry_after_raw, CultureInfo.InvariantCulture)); wait_task = Task.Delay(retry_after); } // check if global b1nzy if (hs.TryGetValue("X-RateLimit-Global", out var isglobal) && isglobal.ToLowerInvariant() == "true") { // global global = true; } } /// /// Updates the bucket. /// /// The request. /// The response. /// The ratelimit task completion source. private void UpdateBucket(BaseRestRequest request, RestResponse response, TaskCompletionSource ratelimitTcs) { var bucket = request.RateLimitBucket; if (response.Headers == null) { if (response.ResponseCode != 429) // do not fail when ratelimit was or the next request will be scheduled hitting the rate limit again this.FailInitialRateLimitTest(request, ratelimitTcs); return; } var hs = response.Headers; if (hs.TryGetValue("X-RateLimit-Scope", out var scope)) { bucket.Scope = scope; } if (hs.TryGetValue("X-RateLimit-Global", out var isglobal) && isglobal.ToLowerInvariant() == "true") { if (response.ResponseCode != 429) { bucket.IsGlobal = true; this.FailInitialRateLimitTest(request, ratelimitTcs); } return; } var r1 = hs.TryGetValue("X-RateLimit-Limit", out var usesmax); var r2 = hs.TryGetValue("X-RateLimit-Remaining", out var usesleft); var r3 = hs.TryGetValue("X-RateLimit-Reset", out var reset); var r4 = hs.TryGetValue("X-Ratelimit-Reset-After", out var resetAfter); var r5 = hs.TryGetValue("X-Ratelimit-Bucket", out var hash); if (!r1 || !r2 || !r3 || !r4) { //If the limits were determined before this request, make the bucket initial again. if (response.ResponseCode != 429) this.FailInitialRateLimitTest(request, ratelimitTcs, ratelimitTcs == null); return; } var clienttime = DateTimeOffset.UtcNow; var resettime = new DateTimeOffset(1970, 1, 1, 0, 0, 0, TimeSpan.Zero).AddSeconds(double.Parse(reset, CultureInfo.InvariantCulture)); var servertime = clienttime; if (hs.TryGetValue("Date", out var raw_date)) servertime = DateTimeOffset.Parse(raw_date, CultureInfo.InvariantCulture).ToUniversalTime(); var resetdelta = resettime - servertime; //var difference = clienttime - servertime; //if (Math.Abs(difference.TotalSeconds) >= 1) //// this.Logger.LogMessage(LogLevel.DebugBaseDiscordClient.RestEventId, $"Difference between machine and server time: {difference.TotalMilliseconds.ToString("#,##0.00", CultureInfo.InvariantCulture)}ms", DateTime.Now); //else // difference = TimeSpan.Zero; if (request.RateLimitWaitOverride.HasValue) resetdelta = TimeSpan.FromSeconds(request.RateLimitWaitOverride.Value); var newReset = clienttime + resetdelta; if (this.UseResetAfter) { bucket.ResetAfter = TimeSpan.FromSeconds(double.Parse(resetAfter, CultureInfo.InvariantCulture)); newReset = clienttime + bucket.ResetAfter.Value + (request.RateLimitWaitOverride.HasValue ? resetdelta : TimeSpan.Zero); bucket.ResetAfterOffset = newReset; } else bucket.Reset = newReset; var maximum = int.Parse(usesmax, CultureInfo.InvariantCulture); var remaining = int.Parse(usesleft, CultureInfo.InvariantCulture); if (ratelimitTcs != null) { // initial population of the ratelimit data bucket.SetInitialValues(maximum, remaining, newReset); _ = Task.Run(() => ratelimitTcs.TrySetResult(true)); } else { // only update the bucket values if this request was for a newer interval than the one // currently in the bucket, to avoid issues with concurrent requests in one bucket // remaining is reset by TryResetLimit and not the response, just allow that to happen when it is time if (bucket._nextReset == 0) bucket._nextReset = newReset.UtcTicks; } this.UpdateHashCaches(request, bucket, hash); } /// /// Updates the hash caches. /// /// The request. /// The bucket. /// The new hash. private void UpdateHashCaches(BaseRestRequest request, RateLimitBucket bucket, string newHash = null) { var hashKey = RateLimitBucket.GenerateHashKey(request.Method, request.Route); if (!this.RoutesToHashes.TryGetValue(hashKey, out var oldHash)) return; // This is an unlimited bucket, which we don't need to keep track of. if (newHash == null) { _ = this.RoutesToHashes.TryRemove(hashKey, out _); _ = this.HashesToBuckets.TryRemove(bucket.BucketId, out _); return; } // Only update the hash once, due to a bug on Discord's end. // This will cause issues if the bucket hashes are dynamically changed from the API while running, // in which case, Dispose will need to be called to clear the caches. if (bucket._isUnlimited && newHash != oldHash) { this.Logger.LogDebug(LoggerEvents.RestHashMover, "Updating hash in {0}: \"{1}\" -> \"{2}\"", hashKey, oldHash, newHash); var bucketId = RateLimitBucket.GenerateBucketId(newHash, bucket.GuildId, bucket.ChannelId, bucket.WebhookId); _ = this.RoutesToHashes.AddOrUpdate(hashKey, newHash, (key, oldHash) => { bucket.Hash = newHash; var oldBucketId = RateLimitBucket.GenerateBucketId(oldHash, bucket.GuildId, bucket.ChannelId, bucket.WebhookId); // Remove the old unlimited bucket. _ = this.HashesToBuckets.TryRemove(oldBucketId, out _); _ = this.HashesToBuckets.AddOrUpdate(bucketId, bucket, (key, oldBucket) => bucket); return newHash; }); } return; } /// /// Cleanups the buckets. /// private async Task CleanupBucketsAsync() { while (!this._bucketCleanerTokenSource.IsCancellationRequested) { try { await Task.Delay(this._bucketCleanupDelay, this._bucketCleanerTokenSource.Token).ConfigureAwait(false); } catch { } if (this._disposed) return; //Check and clean request queue first in case it wasn't removed properly during requests. foreach (var key in this.RequestQueue.Keys) { var bucket = this.HashesToBuckets.Values.FirstOrDefault(x => x.RouteHashes.Contains(key)); if (bucket == null || (bucket != null && bucket.LastAttemptAt.AddSeconds(5) < DateTimeOffset.UtcNow)) _ = this.RequestQueue.TryRemove(key, out _); } var removedBuckets = 0; StringBuilder bucketIdStrBuilder = default; foreach (var kvp in this.HashesToBuckets) { if (bucketIdStrBuilder == null) bucketIdStrBuilder = new StringBuilder(); var key = kvp.Key; var value = kvp.Value; // Don't remove the bucket if it's currently being handled by the rest client, unless it's an unlimited bucket. if (this.RequestQueue.ContainsKey(value.BucketId) && !value._isUnlimited) continue; var resetOffset = this.UseResetAfter ? value.ResetAfterOffset : value.Reset; // Don't remove the bucket if it's reset date is less than now + the additional wait time, unless it's an unlimited bucket. if (resetOffset != null && !value._isUnlimited && (resetOffset > DateTimeOffset.UtcNow || DateTimeOffset.UtcNow - resetOffset < this._bucketCleanupDelay)) continue; _ = this.HashesToBuckets.TryRemove(key, out _); removedBuckets++; bucketIdStrBuilder.Append(value.BucketId + ", "); } if (removedBuckets > 0) this.Logger.LogDebug(LoggerEvents.RestCleaner, "Removed {0} unused bucket{1}: [{2}]", removedBuckets, removedBuckets > 1 ? "s" : string.Empty, bucketIdStrBuilder.ToString().TrimEnd(',', ' ')); if (this.HashesToBuckets.Count == 0) break; } if (!this._bucketCleanerTokenSource.IsCancellationRequested) this._bucketCleanerTokenSource.Cancel(); this._cleanerRunning = false; this.Logger.LogDebug(LoggerEvents.RestCleaner, "Bucket cleaner task stopped."); } ~RestClient() => this.Dispose(); /// /// Disposes the rest client. /// public void Dispose() { if (this._disposed) return; this._disposed = true; this.GlobalRateLimitEvent.Reset(); if (this._bucketCleanerTokenSource?.IsCancellationRequested == false) { this._bucketCleanerTokenSource?.Cancel(); this.Logger.LogDebug(LoggerEvents.RestCleaner, "Bucket cleaner task stopped."); } try { this._cleanerTask?.Dispose(); this._bucketCleanerTokenSource?.Dispose(); this.HttpClient?.Dispose(); } catch { } this.RoutesToHashes.Clear(); this.HashesToBuckets.Clear(); this.RequestQueue.Clear(); } } }