diff --git a/DisCatSharp/Clients/DiscordClient.WebSocket.cs b/DisCatSharp/Clients/DiscordClient.WebSocket.cs
index 77700cf6e..2d3025618 100644
--- a/DisCatSharp/Clients/DiscordClient.WebSocket.cs
+++ b/DisCatSharp/Clients/DiscordClient.WebSocket.cs
@@ -1,626 +1,628 @@
using System;
using System.Collections.Concurrent;
using System.IO;
+using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using DisCatSharp.Entities;
using DisCatSharp.Enums;
using DisCatSharp.EventArgs;
using DisCatSharp.Net.Abstractions;
using DisCatSharp.Net.WebSocket;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using Sentry;
namespace DisCatSharp;
///
/// Represents a discord websocket client.
///
public sealed partial class DiscordClient
{
#region Private Fields
///
/// Gets the heartbeat interval.
///
private int _heartbeatInterval;
///
/// Gets when the last heartbeat was sent.
///
private DateTimeOffset _lastHeartbeat;
///
/// Gets the heartbeat task.
///
private Task _heartbeatTask;
///
/// Gets the default discord epoch.
///
internal static DateTimeOffset DiscordEpoch = new(2015, 1, 1, 0, 0, 0, TimeSpan.Zero);
///
/// Gets the count of skipped heartbeats.
///
private int _skippedHeartbeats = 0;
///
/// Gets the last sequence number.
///
private long _lastSequence = 0;
///
/// Gets the websocket client.
///
internal IWebSocketClient WebSocketClient;
///
/// Gets the payload decompressor.
///
private PayloadDecompressor? _payloadDecompressor;
///
/// Gets the cancel token source.
///
private CancellationTokenSource _cancelTokenSource;
///
/// Gets the cancel token.
///
private CancellationToken _cancelToken;
#endregion
#region Connection Semaphore
///
/// Gets the socket locks.
///
private static ConcurrentDictionary s_socketLocks { get; } = [];
///
/// Gets the session lock.
///
private readonly ManualResetEventSlim _sessionLock = new(true);
#endregion
#region Internal Connection Methods
///
/// Reconnects the websocket client.
///
/// Whether to start a new session.
/// The reconnect code.
/// The reconnect message.
private Task InternalReconnectAsync(bool startNewSession = false, int code = 1000, string message = "")
{
if (startNewSession)
this._sessionId = null;
_ = this.WebSocketClient.DisconnectAsync(code, message);
return Task.CompletedTask;
}
///
/// Connects the websocket client.
///
internal async Task InternalConnectAsync()
{
SocketLock? socketLock = null;
try
{
if (this.GatewayInfo is null)
await this.InternalUpdateGatewayAsync().ConfigureAwait(false);
await this.InitializeAsync().ConfigureAwait(false);
socketLock = this.GetSocketLock();
await socketLock.LockAsync().ConfigureAwait(false);
}
catch
{
socketLock?.UnlockAfter(TimeSpan.Zero);
throw;
}
if (!this.Presences.ContainsKey(this.CurrentUser.Id))
this.PresencesInternal[this.CurrentUser.Id] = new()
{
Discord = this,
RawActivity = new(),
Activity = new(),
Status = UserStatus.Online,
InternalUser = new()
{
Id = this.CurrentUser.Id
}
};
else
{
var pr = this.PresencesInternal[this.CurrentUser.Id];
pr.RawActivity = new();
pr.Activity = new();
pr.Status = UserStatus.Online;
}
Volatile.Write(ref this._skippedHeartbeats, 0);
this.WebSocketClient = this.Configuration.WebSocketClientFactory(this.Configuration.Proxy, this.ServiceProvider);
this._payloadDecompressor = this.Configuration.GatewayCompressionLevel is not GatewayCompressionLevel.None
? new PayloadDecompressor(this.Configuration.GatewayCompressionLevel)
: null;
this._cancelTokenSource = new();
this._cancelToken = this._cancelTokenSource.Token;
this.WebSocketClient.Connected += SocketOnConnect;
this.WebSocketClient.Disconnected += SocketOnDisconnect;
this.WebSocketClient.MessageReceived += SocketOnMessage;
this.WebSocketClient.ExceptionThrown += SocketOnException;
var gwuri = this.GatewayUri.AddParameter("v", this.Configuration.ApiVersion).AddParameter("encoding", "json");
if (this.Configuration.GatewayCompressionLevel == GatewayCompressionLevel.Stream)
gwuri = gwuri.AddParameter("compress", "zlib-stream");
this.GatewayUri = gwuri;
this.Logger.LogDebug(LoggerEvents.Startup, "Connecting to {gw}", this.GatewayUri.AbsoluteUri);
await this.WebSocketClient.ConnectAsync(this.GatewayUri).ConfigureAwait(false);
return;
Task SocketOnConnect(IWebSocketClient sender, SocketEventArgs e)
=> this._socketOpened.InvokeAsync(this, e);
async Task SocketOnMessage(IWebSocketClient sender, SocketMessageEventArgs e)
{
string? msg = null;
switch (e)
{
case SocketTextMessageEventArgs etext:
msg = etext.Message;
break;
case SocketBinaryMessageEventArgs ebin:
{
using var ms = new MemoryStream();
if (this._payloadDecompressor?.TryDecompress(new(ebin.Message), ms) is false or null)
{
this.Logger.LogError(LoggerEvents.WebSocketReceiveFailure, "Payload decompression failed");
return;
}
ms.Position = 0;
using var sr = new StreamReader(ms, Utilities.UTF8);
msg = await sr.ReadToEndAsync(this._cancelToken).ConfigureAwait(false);
break;
}
}
try
{
if (msg is not null)
{
this.Logger.LogTrace(LoggerEvents.GatewayWsRx, "{Message}", msg);
await this.HandleSocketMessageAsync(msg).ConfigureAwait(false);
}
}
catch (Exception ex)
{
this.Logger.LogError(LoggerEvents.WebSocketReceiveFailure, ex, "Socket handler suppressed an exception");
if (this.Configuration.EnableSentry)
this.Sentry.CaptureException(ex);
}
}
Task SocketOnException(IWebSocketClient sender, SocketErrorEventArgs e)
=> this._socketErrored.InvokeAsync(this, e);
async Task SocketOnDisconnect(IWebSocketClient sender, SocketCloseEventArgs e)
{
// release session and connection
this._connectionLock.Set();
this._sessionLock.Set();
if (!this._disposed)
this._cancelTokenSource.Cancel();
this.Logger.LogDebug(LoggerEvents.ConnectionClose, "Connection closed ({CloseCode}, '{Reason}')", e.CloseCode, e.CloseMessage ?? "No reason given");
await this._socketClosed.InvokeAsync(this, e).ConfigureAwait(false);
// TODO: We might need to include more 400X codes
if (this.Configuration.AutoReconnect && e.CloseCode is < 4001 or >= 5000)
{
this.Logger.LogCritical(LoggerEvents.ConnectionClose, "Connection terminated ({CloseCode}, '{Reason}'), reconnecting", e.CloseCode, e.CloseMessage ?? "No reason given");
if (this._status is null)
await this.ConnectAsync().ConfigureAwait(false);
else if (this._status.IdleSince.HasValue)
await this.ConnectAsync(this._status.ActivityInternal, this._status.Status, Utilities.GetDateTimeOffsetFromMilliseconds(this._status.IdleSince.Value)).ConfigureAwait(false);
else
await this.ConnectAsync(this._status.ActivityInternal, this._status.Status).ConfigureAwait(false);
}
else
this.Logger.LogCritical(LoggerEvents.ConnectionClose, "Connection terminated ({CloseCode}, '{Reason}')", e.CloseCode, e.CloseMessage ?? "No reason given");
}
}
#endregion
#region WebSocket (Events)
///
/// Handles the socket message.
///
/// The data.
internal async Task HandleSocketMessageAsync(string data)
{
var payload = JsonConvert.DeserializeObject(data)!;
this._lastSequence = payload.Sequence ?? this._lastSequence;
switch (payload.OpCode)
{
case GatewayOpCode.Dispatch:
_ = Task.Run(async () => await this.HandleDispatchAsync(payload).ConfigureAwait(false), this._cancelToken);
break;
case GatewayOpCode.Heartbeat:
await this.OnHeartbeatAsync((long)payload.Data).ConfigureAwait(false);
break;
case GatewayOpCode.Reconnect:
await this.OnReconnectAsync().ConfigureAwait(false);
break;
case GatewayOpCode.InvalidSession:
await this.OnInvalidateSessionAsync((bool)payload.Data).ConfigureAwait(false);
break;
case GatewayOpCode.Hello:
await this.OnHelloAsync((payload.Data as JObject).ToObject()).ConfigureAwait(false);
break;
case GatewayOpCode.HeartbeatAck:
await this.OnHeartbeatAckAsync().ConfigureAwait(false);
break;
case GatewayOpCode.Identify:
case GatewayOpCode.StatusUpdate:
case GatewayOpCode.VoiceStateUpdate:
case GatewayOpCode.VoiceServerPing:
case GatewayOpCode.Resume:
case GatewayOpCode.RequestGuildMembers:
case GatewayOpCode.GuildSync:
this.Logger.LogTrace(LoggerEvents.WebSocketReceive, "Received op code for non-bot event");
break;
default:
this.Logger.LogWarning(LoggerEvents.WebSocketReceive, "Unknown Discord opcode: {OpCode}\nPayload: {Payload}", payload.OpCode, payload.Data);
break;
}
}
///
/// Handles the heartbeat.
///
/// The sequence.
internal async Task OnHeartbeatAsync(long seq)
{
this.Logger.LogTrace(LoggerEvents.WebSocketReceive, "Received HEARTBEAT (OP1)");
await this.SendHeartbeatAsync(seq).ConfigureAwait(false);
}
///
/// Handles the reconnect event.
///
internal async Task OnReconnectAsync()
{
this.Logger.LogTrace(LoggerEvents.WebSocketReceive, "Received RECONNECT (OP7)");
await this.InternalReconnectAsync(code: 4000, message: "OP7 acknowledged").ConfigureAwait(false);
}
///
/// Handles the invalidate session event
///
/// Unknown. Please fill documentation.
internal async Task OnInvalidateSessionAsync(bool data)
{
// begin a session if one is not open already
if (this._sessionLock.Wait(0))
this._sessionLock.Reset();
// we are sending a fresh resume/identify, so lock the socket
var socketLock = this.GetSocketLock();
await socketLock.LockAsync().ConfigureAwait(false);
socketLock.UnlockAfter(TimeSpan.FromSeconds(5));
if (data)
{
this.Logger.LogTrace(LoggerEvents.WebSocketReceive, "Received INVALID_SESSION (OP9, true)");
await Task.Delay(6000, this._cancelToken).ConfigureAwait(false);
await this.SendResumeAsync().ConfigureAwait(false);
}
else
{
this.Logger.LogTrace(LoggerEvents.WebSocketReceive, "Received INVALID_SESSION (OP9, false)");
this._sessionId = null;
- await this.SendIdentifyAsync(this._status).ConfigureAwait(false);
+ await this.SendIdentifyAsync(this._status, "invalidate").ConfigureAwait(false);
}
}
///
/// Handles the hello event.
///
/// The gateway hello payload.
internal async Task OnHelloAsync(GatewayHello hello)
{
this.Logger.LogTrace(LoggerEvents.WebSocketReceive, "Received HELLO (OP10)");
if (this._sessionLock.Wait(0))
{
this._sessionLock.Reset();
this.GetSocketLock().UnlockAfter(TimeSpan.FromSeconds(5));
}
else
{
this.Logger.LogWarning(LoggerEvents.SessionUpdate, "Attempt to start a session while another session is active");
return;
}
Interlocked.CompareExchange(ref this._skippedHeartbeats, 0, 0);
this._heartbeatInterval = hello.HeartbeatInterval;
this._heartbeatTask = Task.Run(this.HeartbeatLoopAsync, this._cancelToken);
if (string.IsNullOrEmpty(this._sessionId))
- await this.SendIdentifyAsync(this._status).ConfigureAwait(false);
+ await this.SendIdentifyAsync(this._status, "hello").ConfigureAwait(false);
else
await this.SendResumeAsync().ConfigureAwait(false);
}
///
/// Handles the heartbeat acknowledge event.
///
internal async Task OnHeartbeatAckAsync()
{
Interlocked.Decrement(ref this._skippedHeartbeats);
var ping = (int)(DateTime.Now - this._lastHeartbeat).TotalMilliseconds;
this.Logger.LogTrace(LoggerEvents.WebSocketReceive, "Received HEARTBEAT_ACK (OP11, {0}ms)", ping);
Volatile.Write(ref this._ping, ping);
var args = new HeartbeatEventArgs(this.ServiceProvider)
{
Ping = this.Ping,
Timestamp = DateTimeOffset.Now
};
await this._heartbeated.InvokeAsync(this, args).ConfigureAwait(false);
}
///
/// Handles the heartbeat loop.
///
internal async Task HeartbeatLoopAsync()
{
this.Logger.LogDebug(LoggerEvents.Heartbeat, "Heartbeat task started");
var token = this._cancelToken;
try
{
while (true)
{
await this.SendHeartbeatAsync(this._lastSequence).ConfigureAwait(false);
await Task.Delay(this._heartbeatInterval, token).ConfigureAwait(false);
token.ThrowIfCancellationRequested();
}
}
catch (OperationCanceledException)
{ }
}
#endregion
#region Internal Gateway Methods
///
/// Updates the status.
///
/// The activity.
/// The optional user status.
/// Since when is the client performing the specified activity.
internal async Task InternalUpdateStatusAsync(DiscordActivity activity, UserStatus? userStatus, DateTimeOffset? idleSince)
{
if (activity is { Name.Length: > 128 })
throw new("Game name can't be longer than 128 characters!");
var sinceUnix = idleSince != null ? (long?)Utilities.GetUnixTime(idleSince.Value) : null;
var act = activity ?? new DiscordActivity();
var status = new StatusUpdate
{
Activity = new(act),
IdleSince = sinceUnix,
IsAfk = idleSince != null,
Status = userStatus ?? UserStatus.Online
};
// Solution to have status persist between sessions
this._status = status;
var statusUpdate = new GatewayPayload
{
OpCode = GatewayOpCode.StatusUpdate,
Data = status
};
var statusstr = JsonConvert.SerializeObject(statusUpdate);
await this.WsSendAsync(statusstr).ConfigureAwait(false);
if (!this.PresencesInternal.TryGetValue(this.CurrentUser.Id, out var value))
this.PresencesInternal[this.CurrentUser.Id] = new()
{
Discord = this,
Activity = act,
Status = userStatus ?? UserStatus.Online,
InternalUser = new()
{
Id = this.CurrentUser.Id
}
};
else
{
value.Activity = act;
value.Status = userStatus ?? value.Status;
}
}
///
/// Sends the heartbeat.
///
/// The sequenze.
internal async Task SendHeartbeatAsync(long seq)
{
var moreThan5 = Volatile.Read(ref this._skippedHeartbeats) > 5;
var guildsComp = Volatile.Read(ref this._guildDownloadCompleted);
switch (guildsComp)
{
case true when moreThan5:
{
this.Logger.LogCritical(LoggerEvents.HeartbeatFailure, "Server failed to acknowledge more than 5 heartbeats - connection is zombie");
var args = new ZombiedEventArgs(this.ServiceProvider)
{
Failures = Volatile.Read(ref this._skippedHeartbeats),
GuildDownloadCompleted = true
};
await this._zombied.InvokeAsync(this, args).ConfigureAwait(false);
await this.InternalReconnectAsync(code: 4001, message: "Too many heartbeats missed").ConfigureAwait(false);
return;
}
case false when moreThan5:
{
var args = new ZombiedEventArgs(this.ServiceProvider)
{
Failures = Volatile.Read(ref this._skippedHeartbeats),
GuildDownloadCompleted = false
};
await this._zombied.InvokeAsync(this, args).ConfigureAwait(false);
this.Logger.LogWarning(LoggerEvents.HeartbeatFailure, "Server failed to acknowledge more than 5 heartbeats, but the guild download is still running - check your connection speed");
break;
}
}
Volatile.Write(ref this._lastSequence, seq);
this.Logger.LogTrace(LoggerEvents.Heartbeat, "Sending heartbeat");
var heartbeat = new GatewayPayload
{
OpCode = GatewayOpCode.Heartbeat,
Data = seq
};
var heartbeatStr = JsonConvert.SerializeObject(heartbeat);
await this.WsSendAsync(heartbeatStr).ConfigureAwait(false);
this._lastHeartbeat = DateTimeOffset.Now;
Interlocked.Increment(ref this._skippedHeartbeats);
}
///
/// Sends the identify payload.
///
/// The status update payload.
- internal async Task SendIdentifyAsync(StatusUpdate? status)
+ /// The origin of the call.
+ internal async Task SendIdentifyAsync(StatusUpdate? status, string origin)
{
var identify = new GatewayIdentify
{
Token = Utilities.GetFormattedToken(this),
Compress = this.Configuration.GatewayCompressionLevel == GatewayCompressionLevel.Payload,
LargeThreshold = this.Configuration.LargeThreshold,
ShardInfo = new()
{
ShardId = this.Configuration.ShardId,
ShardCount = this.Configuration.ShardCount
},
Presence = status,
Intents = this.Configuration.Intents,
Discord = this
};
var payload = new GatewayPayload
{
OpCode = GatewayOpCode.Identify,
Data = identify
};
var payloadstr = JsonConvert.SerializeObject(payload);
await this.WsSendAsync(payloadstr).ConfigureAwait(false);
- this.Logger.LogDebug(LoggerEvents.Intents, "Registered gateway intents ({Intents})", this.Configuration.Intents);
+ this.Logger.LogDebug(LoggerEvents.Intents, "Registered gateway intents ({Intents}, {origin})", this.Configuration.Intents, origin);
}
///
/// Sends the resume payload.
///
internal async Task SendResumeAsync()
{
ArgumentNullException.ThrowIfNull(this._sessionId);
ArgumentNullException.ThrowIfNull(this._resumeGatewayUrl);
var resume = new GatewayResume
{
Token = Utilities.GetFormattedToken(this),
SessionId = this._sessionId,
SequenceNumber = Volatile.Read(ref this._lastSequence)
};
var resumePayload = new GatewayPayload
{
OpCode = GatewayOpCode.Resume,
Data = resume
};
var resumestr = JsonConvert.SerializeObject(resumePayload);
this.GatewayUri = new(this._resumeGatewayUrl);
this.Logger.LogDebug(LoggerEvents.ConnectionClose, "Request to resume via {gw}", this.GatewayUri.AbsoluteUri);
await this.WsSendAsync(resumestr).ConfigureAwait(false);
}
///
/// Internals the update gateway async.
///
/// A Task.
internal async Task InternalUpdateGatewayAsync()
{
var info = await this.GetGatewayInfoAsync().ConfigureAwait(false);
this.GatewayInfo = info;
this.GatewayUri = new(info.Url);
}
///
/// Sends a websocket message.
///
/// The payload to send.
internal async Task WsSendAsync(string payload)
{
this.Logger.LogTrace(LoggerEvents.GatewayWsTx, "{Payload}", payload);
await this.WebSocketClient.SendMessageAsync(payload).ConfigureAwait(false);
}
#endregion
#region Semaphore Methods
///
/// Gets the socket lock.
///
/// The added socket lock.
private SocketLock GetSocketLock()
=> s_socketLocks.GetOrAdd(this.CurrentApplication.Id, new SocketLock(this.CurrentApplication.Id, this.GatewayInfo!.SessionBucket.MaxConcurrency));
#endregion
}