diff --git a/DisCatSharp/Clients/DiscordClient.WebSocket.cs b/DisCatSharp/Clients/DiscordClient.WebSocket.cs index 77700cf6e..2d3025618 100644 --- a/DisCatSharp/Clients/DiscordClient.WebSocket.cs +++ b/DisCatSharp/Clients/DiscordClient.WebSocket.cs @@ -1,626 +1,628 @@ using System; using System.Collections.Concurrent; using System.IO; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using DisCatSharp.Entities; using DisCatSharp.Enums; using DisCatSharp.EventArgs; using DisCatSharp.Net.Abstractions; using DisCatSharp.Net.WebSocket; using Microsoft.Extensions.Logging; using Newtonsoft.Json; using Newtonsoft.Json.Linq; using Sentry; namespace DisCatSharp; /// /// Represents a discord websocket client. /// public sealed partial class DiscordClient { #region Private Fields /// /// Gets the heartbeat interval. /// private int _heartbeatInterval; /// /// Gets when the last heartbeat was sent. /// private DateTimeOffset _lastHeartbeat; /// /// Gets the heartbeat task. /// private Task _heartbeatTask; /// /// Gets the default discord epoch. /// internal static DateTimeOffset DiscordEpoch = new(2015, 1, 1, 0, 0, 0, TimeSpan.Zero); /// /// Gets the count of skipped heartbeats. /// private int _skippedHeartbeats = 0; /// /// Gets the last sequence number. /// private long _lastSequence = 0; /// /// Gets the websocket client. /// internal IWebSocketClient WebSocketClient; /// /// Gets the payload decompressor. /// private PayloadDecompressor? _payloadDecompressor; /// /// Gets the cancel token source. /// private CancellationTokenSource _cancelTokenSource; /// /// Gets the cancel token. /// private CancellationToken _cancelToken; #endregion #region Connection Semaphore /// /// Gets the socket locks. /// private static ConcurrentDictionary s_socketLocks { get; } = []; /// /// Gets the session lock. /// private readonly ManualResetEventSlim _sessionLock = new(true); #endregion #region Internal Connection Methods /// /// Reconnects the websocket client. /// /// Whether to start a new session. /// The reconnect code. /// The reconnect message. private Task InternalReconnectAsync(bool startNewSession = false, int code = 1000, string message = "") { if (startNewSession) this._sessionId = null; _ = this.WebSocketClient.DisconnectAsync(code, message); return Task.CompletedTask; } /// /// Connects the websocket client. /// internal async Task InternalConnectAsync() { SocketLock? socketLock = null; try { if (this.GatewayInfo is null) await this.InternalUpdateGatewayAsync().ConfigureAwait(false); await this.InitializeAsync().ConfigureAwait(false); socketLock = this.GetSocketLock(); await socketLock.LockAsync().ConfigureAwait(false); } catch { socketLock?.UnlockAfter(TimeSpan.Zero); throw; } if (!this.Presences.ContainsKey(this.CurrentUser.Id)) this.PresencesInternal[this.CurrentUser.Id] = new() { Discord = this, RawActivity = new(), Activity = new(), Status = UserStatus.Online, InternalUser = new() { Id = this.CurrentUser.Id } }; else { var pr = this.PresencesInternal[this.CurrentUser.Id]; pr.RawActivity = new(); pr.Activity = new(); pr.Status = UserStatus.Online; } Volatile.Write(ref this._skippedHeartbeats, 0); this.WebSocketClient = this.Configuration.WebSocketClientFactory(this.Configuration.Proxy, this.ServiceProvider); this._payloadDecompressor = this.Configuration.GatewayCompressionLevel is not GatewayCompressionLevel.None ? new PayloadDecompressor(this.Configuration.GatewayCompressionLevel) : null; this._cancelTokenSource = new(); this._cancelToken = this._cancelTokenSource.Token; this.WebSocketClient.Connected += SocketOnConnect; this.WebSocketClient.Disconnected += SocketOnDisconnect; this.WebSocketClient.MessageReceived += SocketOnMessage; this.WebSocketClient.ExceptionThrown += SocketOnException; var gwuri = this.GatewayUri.AddParameter("v", this.Configuration.ApiVersion).AddParameter("encoding", "json"); if (this.Configuration.GatewayCompressionLevel == GatewayCompressionLevel.Stream) gwuri = gwuri.AddParameter("compress", "zlib-stream"); this.GatewayUri = gwuri; this.Logger.LogDebug(LoggerEvents.Startup, "Connecting to {gw}", this.GatewayUri.AbsoluteUri); await this.WebSocketClient.ConnectAsync(this.GatewayUri).ConfigureAwait(false); return; Task SocketOnConnect(IWebSocketClient sender, SocketEventArgs e) => this._socketOpened.InvokeAsync(this, e); async Task SocketOnMessage(IWebSocketClient sender, SocketMessageEventArgs e) { string? msg = null; switch (e) { case SocketTextMessageEventArgs etext: msg = etext.Message; break; case SocketBinaryMessageEventArgs ebin: { using var ms = new MemoryStream(); if (this._payloadDecompressor?.TryDecompress(new(ebin.Message), ms) is false or null) { this.Logger.LogError(LoggerEvents.WebSocketReceiveFailure, "Payload decompression failed"); return; } ms.Position = 0; using var sr = new StreamReader(ms, Utilities.UTF8); msg = await sr.ReadToEndAsync(this._cancelToken).ConfigureAwait(false); break; } } try { if (msg is not null) { this.Logger.LogTrace(LoggerEvents.GatewayWsRx, "{Message}", msg); await this.HandleSocketMessageAsync(msg).ConfigureAwait(false); } } catch (Exception ex) { this.Logger.LogError(LoggerEvents.WebSocketReceiveFailure, ex, "Socket handler suppressed an exception"); if (this.Configuration.EnableSentry) this.Sentry.CaptureException(ex); } } Task SocketOnException(IWebSocketClient sender, SocketErrorEventArgs e) => this._socketErrored.InvokeAsync(this, e); async Task SocketOnDisconnect(IWebSocketClient sender, SocketCloseEventArgs e) { // release session and connection this._connectionLock.Set(); this._sessionLock.Set(); if (!this._disposed) this._cancelTokenSource.Cancel(); this.Logger.LogDebug(LoggerEvents.ConnectionClose, "Connection closed ({CloseCode}, '{Reason}')", e.CloseCode, e.CloseMessage ?? "No reason given"); await this._socketClosed.InvokeAsync(this, e).ConfigureAwait(false); // TODO: We might need to include more 400X codes if (this.Configuration.AutoReconnect && e.CloseCode is < 4001 or >= 5000) { this.Logger.LogCritical(LoggerEvents.ConnectionClose, "Connection terminated ({CloseCode}, '{Reason}'), reconnecting", e.CloseCode, e.CloseMessage ?? "No reason given"); if (this._status is null) await this.ConnectAsync().ConfigureAwait(false); else if (this._status.IdleSince.HasValue) await this.ConnectAsync(this._status.ActivityInternal, this._status.Status, Utilities.GetDateTimeOffsetFromMilliseconds(this._status.IdleSince.Value)).ConfigureAwait(false); else await this.ConnectAsync(this._status.ActivityInternal, this._status.Status).ConfigureAwait(false); } else this.Logger.LogCritical(LoggerEvents.ConnectionClose, "Connection terminated ({CloseCode}, '{Reason}')", e.CloseCode, e.CloseMessage ?? "No reason given"); } } #endregion #region WebSocket (Events) /// /// Handles the socket message. /// /// The data. internal async Task HandleSocketMessageAsync(string data) { var payload = JsonConvert.DeserializeObject(data)!; this._lastSequence = payload.Sequence ?? this._lastSequence; switch (payload.OpCode) { case GatewayOpCode.Dispatch: _ = Task.Run(async () => await this.HandleDispatchAsync(payload).ConfigureAwait(false), this._cancelToken); break; case GatewayOpCode.Heartbeat: await this.OnHeartbeatAsync((long)payload.Data).ConfigureAwait(false); break; case GatewayOpCode.Reconnect: await this.OnReconnectAsync().ConfigureAwait(false); break; case GatewayOpCode.InvalidSession: await this.OnInvalidateSessionAsync((bool)payload.Data).ConfigureAwait(false); break; case GatewayOpCode.Hello: await this.OnHelloAsync((payload.Data as JObject).ToObject()).ConfigureAwait(false); break; case GatewayOpCode.HeartbeatAck: await this.OnHeartbeatAckAsync().ConfigureAwait(false); break; case GatewayOpCode.Identify: case GatewayOpCode.StatusUpdate: case GatewayOpCode.VoiceStateUpdate: case GatewayOpCode.VoiceServerPing: case GatewayOpCode.Resume: case GatewayOpCode.RequestGuildMembers: case GatewayOpCode.GuildSync: this.Logger.LogTrace(LoggerEvents.WebSocketReceive, "Received op code for non-bot event"); break; default: this.Logger.LogWarning(LoggerEvents.WebSocketReceive, "Unknown Discord opcode: {OpCode}\nPayload: {Payload}", payload.OpCode, payload.Data); break; } } /// /// Handles the heartbeat. /// /// The sequence. internal async Task OnHeartbeatAsync(long seq) { this.Logger.LogTrace(LoggerEvents.WebSocketReceive, "Received HEARTBEAT (OP1)"); await this.SendHeartbeatAsync(seq).ConfigureAwait(false); } /// /// Handles the reconnect event. /// internal async Task OnReconnectAsync() { this.Logger.LogTrace(LoggerEvents.WebSocketReceive, "Received RECONNECT (OP7)"); await this.InternalReconnectAsync(code: 4000, message: "OP7 acknowledged").ConfigureAwait(false); } /// /// Handles the invalidate session event /// /// Unknown. Please fill documentation. internal async Task OnInvalidateSessionAsync(bool data) { // begin a session if one is not open already if (this._sessionLock.Wait(0)) this._sessionLock.Reset(); // we are sending a fresh resume/identify, so lock the socket var socketLock = this.GetSocketLock(); await socketLock.LockAsync().ConfigureAwait(false); socketLock.UnlockAfter(TimeSpan.FromSeconds(5)); if (data) { this.Logger.LogTrace(LoggerEvents.WebSocketReceive, "Received INVALID_SESSION (OP9, true)"); await Task.Delay(6000, this._cancelToken).ConfigureAwait(false); await this.SendResumeAsync().ConfigureAwait(false); } else { this.Logger.LogTrace(LoggerEvents.WebSocketReceive, "Received INVALID_SESSION (OP9, false)"); this._sessionId = null; - await this.SendIdentifyAsync(this._status).ConfigureAwait(false); + await this.SendIdentifyAsync(this._status, "invalidate").ConfigureAwait(false); } } /// /// Handles the hello event. /// /// The gateway hello payload. internal async Task OnHelloAsync(GatewayHello hello) { this.Logger.LogTrace(LoggerEvents.WebSocketReceive, "Received HELLO (OP10)"); if (this._sessionLock.Wait(0)) { this._sessionLock.Reset(); this.GetSocketLock().UnlockAfter(TimeSpan.FromSeconds(5)); } else { this.Logger.LogWarning(LoggerEvents.SessionUpdate, "Attempt to start a session while another session is active"); return; } Interlocked.CompareExchange(ref this._skippedHeartbeats, 0, 0); this._heartbeatInterval = hello.HeartbeatInterval; this._heartbeatTask = Task.Run(this.HeartbeatLoopAsync, this._cancelToken); if (string.IsNullOrEmpty(this._sessionId)) - await this.SendIdentifyAsync(this._status).ConfigureAwait(false); + await this.SendIdentifyAsync(this._status, "hello").ConfigureAwait(false); else await this.SendResumeAsync().ConfigureAwait(false); } /// /// Handles the heartbeat acknowledge event. /// internal async Task OnHeartbeatAckAsync() { Interlocked.Decrement(ref this._skippedHeartbeats); var ping = (int)(DateTime.Now - this._lastHeartbeat).TotalMilliseconds; this.Logger.LogTrace(LoggerEvents.WebSocketReceive, "Received HEARTBEAT_ACK (OP11, {0}ms)", ping); Volatile.Write(ref this._ping, ping); var args = new HeartbeatEventArgs(this.ServiceProvider) { Ping = this.Ping, Timestamp = DateTimeOffset.Now }; await this._heartbeated.InvokeAsync(this, args).ConfigureAwait(false); } /// /// Handles the heartbeat loop. /// internal async Task HeartbeatLoopAsync() { this.Logger.LogDebug(LoggerEvents.Heartbeat, "Heartbeat task started"); var token = this._cancelToken; try { while (true) { await this.SendHeartbeatAsync(this._lastSequence).ConfigureAwait(false); await Task.Delay(this._heartbeatInterval, token).ConfigureAwait(false); token.ThrowIfCancellationRequested(); } } catch (OperationCanceledException) { } } #endregion #region Internal Gateway Methods /// /// Updates the status. /// /// The activity. /// The optional user status. /// Since when is the client performing the specified activity. internal async Task InternalUpdateStatusAsync(DiscordActivity activity, UserStatus? userStatus, DateTimeOffset? idleSince) { if (activity is { Name.Length: > 128 }) throw new("Game name can't be longer than 128 characters!"); var sinceUnix = idleSince != null ? (long?)Utilities.GetUnixTime(idleSince.Value) : null; var act = activity ?? new DiscordActivity(); var status = new StatusUpdate { Activity = new(act), IdleSince = sinceUnix, IsAfk = idleSince != null, Status = userStatus ?? UserStatus.Online }; // Solution to have status persist between sessions this._status = status; var statusUpdate = new GatewayPayload { OpCode = GatewayOpCode.StatusUpdate, Data = status }; var statusstr = JsonConvert.SerializeObject(statusUpdate); await this.WsSendAsync(statusstr).ConfigureAwait(false); if (!this.PresencesInternal.TryGetValue(this.CurrentUser.Id, out var value)) this.PresencesInternal[this.CurrentUser.Id] = new() { Discord = this, Activity = act, Status = userStatus ?? UserStatus.Online, InternalUser = new() { Id = this.CurrentUser.Id } }; else { value.Activity = act; value.Status = userStatus ?? value.Status; } } /// /// Sends the heartbeat. /// /// The sequenze. internal async Task SendHeartbeatAsync(long seq) { var moreThan5 = Volatile.Read(ref this._skippedHeartbeats) > 5; var guildsComp = Volatile.Read(ref this._guildDownloadCompleted); switch (guildsComp) { case true when moreThan5: { this.Logger.LogCritical(LoggerEvents.HeartbeatFailure, "Server failed to acknowledge more than 5 heartbeats - connection is zombie"); var args = new ZombiedEventArgs(this.ServiceProvider) { Failures = Volatile.Read(ref this._skippedHeartbeats), GuildDownloadCompleted = true }; await this._zombied.InvokeAsync(this, args).ConfigureAwait(false); await this.InternalReconnectAsync(code: 4001, message: "Too many heartbeats missed").ConfigureAwait(false); return; } case false when moreThan5: { var args = new ZombiedEventArgs(this.ServiceProvider) { Failures = Volatile.Read(ref this._skippedHeartbeats), GuildDownloadCompleted = false }; await this._zombied.InvokeAsync(this, args).ConfigureAwait(false); this.Logger.LogWarning(LoggerEvents.HeartbeatFailure, "Server failed to acknowledge more than 5 heartbeats, but the guild download is still running - check your connection speed"); break; } } Volatile.Write(ref this._lastSequence, seq); this.Logger.LogTrace(LoggerEvents.Heartbeat, "Sending heartbeat"); var heartbeat = new GatewayPayload { OpCode = GatewayOpCode.Heartbeat, Data = seq }; var heartbeatStr = JsonConvert.SerializeObject(heartbeat); await this.WsSendAsync(heartbeatStr).ConfigureAwait(false); this._lastHeartbeat = DateTimeOffset.Now; Interlocked.Increment(ref this._skippedHeartbeats); } /// /// Sends the identify payload. /// /// The status update payload. - internal async Task SendIdentifyAsync(StatusUpdate? status) + /// The origin of the call. + internal async Task SendIdentifyAsync(StatusUpdate? status, string origin) { var identify = new GatewayIdentify { Token = Utilities.GetFormattedToken(this), Compress = this.Configuration.GatewayCompressionLevel == GatewayCompressionLevel.Payload, LargeThreshold = this.Configuration.LargeThreshold, ShardInfo = new() { ShardId = this.Configuration.ShardId, ShardCount = this.Configuration.ShardCount }, Presence = status, Intents = this.Configuration.Intents, Discord = this }; var payload = new GatewayPayload { OpCode = GatewayOpCode.Identify, Data = identify }; var payloadstr = JsonConvert.SerializeObject(payload); await this.WsSendAsync(payloadstr).ConfigureAwait(false); - this.Logger.LogDebug(LoggerEvents.Intents, "Registered gateway intents ({Intents})", this.Configuration.Intents); + this.Logger.LogDebug(LoggerEvents.Intents, "Registered gateway intents ({Intents}, {origin})", this.Configuration.Intents, origin); } /// /// Sends the resume payload. /// internal async Task SendResumeAsync() { ArgumentNullException.ThrowIfNull(this._sessionId); ArgumentNullException.ThrowIfNull(this._resumeGatewayUrl); var resume = new GatewayResume { Token = Utilities.GetFormattedToken(this), SessionId = this._sessionId, SequenceNumber = Volatile.Read(ref this._lastSequence) }; var resumePayload = new GatewayPayload { OpCode = GatewayOpCode.Resume, Data = resume }; var resumestr = JsonConvert.SerializeObject(resumePayload); this.GatewayUri = new(this._resumeGatewayUrl); this.Logger.LogDebug(LoggerEvents.ConnectionClose, "Request to resume via {gw}", this.GatewayUri.AbsoluteUri); await this.WsSendAsync(resumestr).ConfigureAwait(false); } /// /// Internals the update gateway async. /// /// A Task. internal async Task InternalUpdateGatewayAsync() { var info = await this.GetGatewayInfoAsync().ConfigureAwait(false); this.GatewayInfo = info; this.GatewayUri = new(info.Url); } /// /// Sends a websocket message. /// /// The payload to send. internal async Task WsSendAsync(string payload) { this.Logger.LogTrace(LoggerEvents.GatewayWsTx, "{Payload}", payload); await this.WebSocketClient.SendMessageAsync(payload).ConfigureAwait(false); } #endregion #region Semaphore Methods /// /// Gets the socket lock. /// /// The added socket lock. private SocketLock GetSocketLock() => s_socketLocks.GetOrAdd(this.CurrentApplication.Id, new SocketLock(this.CurrentApplication.Id, this.GatewayInfo!.SessionBucket.MaxConcurrency)); #endregion }