small refactorings

- use brackets when possible
- set client.Id and client.Name earlier
- use client.Id instead of header.Id
- rename firstConn to isClientNew
This commit is contained in:
Robin C. Ladiges 2024-04-27 17:23:34 +02:00
parent 61e6fcf2a3
commit f51ef765e5
No known key found for this signature in database
GPG Key ID: B494D3DF92661B99
3 changed files with 73 additions and 42 deletions

View File

@ -41,8 +41,9 @@ public class Client : IDisposable {
} }
public void Dispose() { public void Dispose() {
if (Socket?.Connected is true) if (Socket?.Connected is true) {
Socket.Disconnect(false); Socket.Disconnect(false);
}
} }
@ -52,8 +53,8 @@ public class Client : IDisposable {
PacketAttribute packetAttribute = Constants.PacketMap[typeof(T)]; PacketAttribute packetAttribute = Constants.PacketMap[typeof(T)];
try { try {
Server.FillPacket(new PacketHeader { Server.FillPacket(new PacketHeader {
Id = sender?.Id ?? Id, Id = sender?.Id ?? Id,
Type = packetAttribute.Type, Type = packetAttribute.Type,
PacketSize = packet.Size PacketSize = packet.Size
}, packet, memory.Memory); }, packet, memory.Memory);
} }
@ -69,6 +70,7 @@ public class Client : IDisposable {
public async Task Send(Memory<byte> data, Client? sender) { public async Task Send(Memory<byte> data, Client? sender) {
PacketHeader header = new PacketHeader(); PacketHeader header = new PacketHeader();
header.Deserialize(data.Span); header.Deserialize(data.Span);
if (!Connected && header.Type is not PacketType.Connect) { if (!Connected && header.Type is not PacketType.Connect) {
Server.Logger.Error($"Didn't send {header.Type} to {Id} because they weren't connected yet"); Server.Logger.Error($"Didn't send {header.Type} to {Id} because they weren't connected yet");
return; return;

View File

@ -120,11 +120,13 @@ void logError(Task x) {
server.PacketHandler = (c, p) => { server.PacketHandler = (c, p) => {
switch (p) { switch (p) {
case GamePacket gamePacket: { case GamePacket gamePacket: {
// crash player entering a banned stage
if (BanLists.Enabled && BanLists.IsStageBanned(gamePacket.Stage)) { if (BanLists.Enabled && BanLists.IsStageBanned(gamePacket.Stage)) {
c.Logger.Warn($"Crashing player for entering banned stage {gamePacket.Stage}."); c.Logger.Warn($"Crashing player for entering banned stage {gamePacket.Stage}.");
BanLists.Crash(c, false, false, 500); BanLists.Crash(c, false, false, 500);
return false; return false;
} }
c.Logger.Info($"Got game packet {gamePacket.Stage}->{gamePacket.ScenarioNum}"); c.Logger.Info($"Got game packet {gamePacket.Stage}->{gamePacket.ScenarioNum}");
// reset lastPlayerPacket on stage changes // reset lastPlayerPacket on stage changes
@ -184,7 +186,7 @@ server.PacketHandler = (c, p) => {
break; break;
} }
case CostumePacket costumePacket: case CostumePacket costumePacket: {
c.Logger.Info($"Got costume packet: {costumePacket.BodyName}, {costumePacket.CapName}"); c.Logger.Info($"Got costume packet: {costumePacket.BodyName}, {costumePacket.CapName}");
c.Metadata["lastCostumePacket"] = costumePacket; c.Metadata["lastCostumePacket"] = costumePacket;
c.CurrentCostume = costumePacket; c.CurrentCostume = costumePacket;
@ -193,6 +195,7 @@ server.PacketHandler = (c, p) => {
#pragma warning restore CS4014 #pragma warning restore CS4014
c.Metadata["loadedSave"] = true; c.Metadata["loadedSave"] = true;
break; break;
}
case ShinePacket shinePacket: { case ShinePacket shinePacket: {
if (!Settings.Instance.Shines.Enabled) return false; if (!Settings.Instance.Shines.Enabled) return false;

View File

@ -29,6 +29,7 @@ public class Server {
Socket socket = token.HasValue ? await serverSocket.AcceptAsync(token.Value) : await serverSocket.AcceptAsync(); Socket socket = token.HasValue ? await serverSocket.AcceptAsync(token.Value) : await serverSocket.AcceptAsync();
socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, true); socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, true);
// is the IPv4 address banned?
if (BanLists.Enabled && BanLists.IsIPv4Banned(((IPEndPoint) socket.RemoteEndPoint!).Address!)) { if (BanLists.Enabled && BanLists.IsIPv4Banned(((IPEndPoint) socket.RemoteEndPoint!).Address!)) {
Logger.Warn($"Ignoring banned IPv4 address {socket.RemoteEndPoint}"); Logger.Warn($"Ignoring banned IPv4 address {socket.RemoteEndPoint}");
continue; continue;
@ -36,6 +37,7 @@ public class Server {
Logger.Warn($"Accepted connection for client {socket.RemoteEndPoint}"); Logger.Warn($"Accepted connection for client {socket.RemoteEndPoint}");
// start sub thread to handle client
try { try {
#pragma warning disable CS4014 #pragma warning disable CS4014
Task.Run(() => HandleSocket(socket)) Task.Run(() => HandleSocket(socket))
@ -78,15 +80,17 @@ public class Server {
public delegate void PacketReplacer<in T>(Client from, Client to, T value); // replacer must send public delegate void PacketReplacer<in T>(Client from, Client to, T value); // replacer must send
public void BroadcastReplace<T>(T packet, Client sender, PacketReplacer<T> packetReplacer) where T : struct, IPacket { public void BroadcastReplace<T>(T packet, Client sender, PacketReplacer<T> packetReplacer) where T : struct, IPacket {
foreach (Client client in Clients.Where(client => client.Connected && sender.Id != client.Id)) packetReplacer(sender, client, packet); foreach (Client client in Clients.Where(client => client.Connected && sender.Id != client.Id)) {
packetReplacer(sender, client, packet);
}
} }
public async Task Broadcast<T>(T packet, Client sender) where T : struct, IPacket { public async Task Broadcast<T>(T packet, Client sender) where T : struct, IPacket {
IMemoryOwner<byte> memory = MemoryPool<byte>.Shared.RentZero(Constants.HeaderSize + packet.Size); IMemoryOwner<byte> memory = MemoryPool<byte>.Shared.RentZero(Constants.HeaderSize + packet.Size);
PacketHeader header = new PacketHeader { PacketHeader header = new PacketHeader {
Id = sender?.Id ?? Guid.Empty, Id = sender?.Id ?? Guid.Empty,
Type = Constants.PacketMap[typeof(T)].Type, Type = Constants.PacketMap[typeof(T)].Type,
PacketSize = packet.Size PacketSize = packet.Size,
}; };
FillPacket(header, packet, memory.Memory); FillPacket(header, packet, memory.Memory);
await Broadcast(memory, sender); await Broadcast(memory, sender);
@ -96,9 +100,9 @@ public class Server {
return Task.WhenAll(Clients.Where(c => c.Connected).Select(async client => { return Task.WhenAll(Clients.Where(c => c.Connected).Select(async client => {
IMemoryOwner<byte> memory = MemoryPool<byte>.Shared.RentZero(Constants.HeaderSize + packet.Size); IMemoryOwner<byte> memory = MemoryPool<byte>.Shared.RentZero(Constants.HeaderSize + packet.Size);
PacketHeader header = new PacketHeader { PacketHeader header = new PacketHeader {
Id = client.Id, Id = client.Id,
Type = Constants.PacketMap[typeof(T)].Type, Type = Constants.PacketMap[typeof(T)].Type,
PacketSize = packet.Size PacketSize = packet.Size,
}; };
FillPacket(header, packet, memory.Memory); FillPacket(header, packet, memory.Memory);
await client.Send(memory.Memory, client); await client.Send(memory.Memory, client);
@ -137,6 +141,7 @@ public class Server {
await client.Send(new InitPacket { await client.Send(new InitPacket {
MaxPlayers = Settings.Instance.Server.MaxPlayers MaxPlayers = Settings.Instance.Server.MaxPlayers
}); });
bool first = true; bool first = true;
try { try {
while (true) { while (true) {
@ -159,8 +164,9 @@ public class Server {
return true; return true;
} }
if (!await Read(memory.Memory[..Constants.HeaderSize], Constants.HeaderSize, 0)) if (!await Read(memory.Memory[..Constants.HeaderSize], Constants.HeaderSize, 0)) {
break; break;
}
PacketHeader header = GetHeader(memory.Memory.Span[..Constants.HeaderSize]); PacketHeader header = GetHeader(memory.Memory.Span[..Constants.HeaderSize]);
Range packetRange = Constants.HeaderSize..(Constants.HeaderSize + header.PacketSize); Range packetRange = Constants.HeaderSize..(Constants.HeaderSize + header.PacketSize);
if (header.PacketSize > 0) { if (header.PacketSize > 0) {
@ -168,10 +174,10 @@ public class Server {
memory = memoryPool.Rent(Constants.HeaderSize + header.PacketSize); memory = memoryPool.Rent(Constants.HeaderSize + header.PacketSize);
memTemp.Memory.Span[..Constants.HeaderSize].CopyTo(memory.Memory.Span[..Constants.HeaderSize]); memTemp.Memory.Span[..Constants.HeaderSize].CopyTo(memory.Memory.Span[..Constants.HeaderSize]);
memTemp.Dispose(); memTemp.Dispose();
if (!await Read(memory.Memory, header.PacketSize, Constants.HeaderSize)) if (!await Read(memory.Memory, header.PacketSize, Constants.HeaderSize)) {
break; break;
}
} }
if (client.Ignored) { if (client.Ignored) {
memory.Dispose(); memory.Dispose();
continue; continue;
@ -179,60 +185,71 @@ public class Server {
// connection initialization // connection initialization
if (first) { if (first) {
first = false; first = false; // only do this once
if (header.Type != PacketType.Connect) throw new Exception($"First packet was not init, instead it was {header.Type}");
// first client packet has to be the client init
if (header.Type != PacketType.Connect) {
throw new Exception($"First packet was not init, instead it was {header.Type} ({remote})");
}
ConnectPacket connect = new ConnectPacket(); ConnectPacket connect = new ConnectPacket();
connect.Deserialize(memory.Memory.Span[packetRange]); connect.Deserialize(memory.Memory.Span[packetRange]);
bool wasFirst = connect.ConnectionType == ConnectPacket.ConnectionTypes.FirstConnection; client.Id = header.Id;
client.Name = connect.ClientName;
if (BanLists.Enabled && BanLists.IsProfileBanned(header.Id)) { // is the profile ID banned?
client.Id = header.Id; if (BanLists.Enabled && BanLists.IsProfileBanned(client.Id)) {
client.Name = connect.ClientName;
client.Ignored = true; client.Ignored = true;
client.Logger.Warn($"Ignoring banned profile ID {header.Id}"); client.Logger.Warn($"Ignoring banned profile ID {client.Id}");
memory.Dispose(); memory.Dispose();
continue; continue;
} }
bool wasFirst = connect.ConnectionType == ConnectPacket.ConnectionTypes.FirstConnection;
// add client to the set of connected players
lock (Clients) { lock (Clients) {
// is the server full?
if (Clients.Count(x => x.Connected) == Settings.Instance.Server.MaxPlayers) { if (Clients.Count(x => x.Connected) == Settings.Instance.Server.MaxPlayers) {
client.Logger.Error($"Turned away as server is at max clients"); client.Logger.Error($"Turned away as server is at max clients");
memory.Dispose(); memory.Dispose();
goto disconnect; goto disconnect;
} }
bool firstConn = true; // detect and handle reconnections
bool isClientNew = true;
switch (connect.ConnectionType) { switch (connect.ConnectionType) {
case ConnectPacket.ConnectionTypes.FirstConnection: case ConnectPacket.ConnectionTypes.FirstConnection:
case ConnectPacket.ConnectionTypes.Reconnecting: { case ConnectPacket.ConnectionTypes.Reconnecting: {
client.Id = header.Id; if (FindExistingClient(client.Id) is { } oldClient) {
if (FindExistingClient(header.Id) is { } oldClient) { isClientNew = false;
firstConn = false;
client = new Client(oldClient, socket); client = new Client(oldClient, socket);
client.Name = connect.ClientName;
Clients.Remove(oldClient); Clients.Remove(oldClient);
Clients.Add(client); Clients.Add(client);
if (oldClient.Connected) { if (oldClient.Connected) {
oldClient.Logger.Info($"Disconnecting already connected client {oldClient.Socket?.RemoteEndPoint} for {client.Socket?.RemoteEndPoint}"); oldClient.Logger.Info($"Disconnecting already connected client {oldClient.Socket?.RemoteEndPoint} for {client.Socket?.RemoteEndPoint}");
oldClient.Dispose(); oldClient.Dispose();
} }
} else { }
else {
connect.ConnectionType = ConnectPacket.ConnectionTypes.FirstConnection; connect.ConnectionType = ConnectPacket.ConnectionTypes.FirstConnection;
} }
break; break;
} }
default: default: {
throw new Exception($"Invalid connection type {connect.ConnectionType}"); throw new Exception($"Invalid connection type {connect.ConnectionType} for {client.Name} ({client.Id}/{remote})");
}
} }
client.Name = connect.ClientName;
client.Connected = true; client.Connected = true;
if (firstConn) {
if (isClientNew) {
// do any cleanup required when it comes to new clients // do any cleanup required when it comes to new clients
List<Client> toDisconnect = Clients.FindAll(c => c.Id == header.Id && c.Connected && c.Socket != null); List<Client> toDisconnect = Clients.FindAll(c => c.Id == client.Id && c.Connected && c.Socket != null);
Clients.RemoveAll(c => c.Id == header.Id); Clients.RemoveAll(c => c.Id == client.Id);
Clients.Add(client); Clients.Add(client);
@ -240,18 +257,19 @@ public class Server {
// done disconnecting and removing stale clients with the same id // done disconnecting and removing stale clients with the same id
ClientJoined?.Invoke(client, connect); ClientJoined?.Invoke(client, connect);
// a new connection, not a reconnect, for an existing client }
} else if (wasFirst) { // a known client reconnects, but with a new first connection (e.g. after a restart)
else if (wasFirst) {
client.CleanMetadataOnNewConnection(); client.CleanMetadataOnNewConnection();
} }
} }
// for all other clients that are already connected // for all other clients that are already connected
List<Client> otherConnectedPlayers = Clients.FindAll(c => c.Id != header.Id && c.Connected && c.Socket != null); List<Client> otherConnectedPlayers = Clients.FindAll(c => c.Id != client.Id && c.Connected && c.Socket != null);
await Parallel.ForEachAsync(otherConnectedPlayers, async (other, _) => { await Parallel.ForEachAsync(otherConnectedPlayers, async (other, _) => {
IMemoryOwner<byte> tempBuffer = MemoryPool<byte>.Shared.RentZero(Constants.HeaderSize + (other.CurrentCostume.HasValue ? Math.Max(connect.Size, other.CurrentCostume.Value.Size) : connect.Size)); IMemoryOwner<byte> tempBuffer = MemoryPool<byte>.Shared.RentZero(Constants.HeaderSize + (other.CurrentCostume.HasValue ? Math.Max(connect.Size, other.CurrentCostume.Value.Size) : connect.Size));
// make the other client known to the (new) client // make the other client known to the new client
PacketHeader connectHeader = new PacketHeader { PacketHeader connectHeader = new PacketHeader {
Id = other.Id, Id = other.Id,
Type = PacketType.Connect, Type = PacketType.Connect,
@ -266,7 +284,7 @@ public class Server {
connectPacket.Serialize(tempBuffer.Memory.Span[Constants.HeaderSize..]); connectPacket.Serialize(tempBuffer.Memory.Span[Constants.HeaderSize..]);
await client.Send(tempBuffer.Memory[..(Constants.HeaderSize + connect.Size)], null); await client.Send(tempBuffer.Memory[..(Constants.HeaderSize + connect.Size)], null);
// tell the (new) client what costume the other client has // tell the new client what costume the other client has
if (other.CurrentCostume.HasValue) { if (other.CurrentCostume.HasValue) {
connectHeader.Type = PacketType.Costume; connectHeader.Type = PacketType.Costume;
connectHeader.PacketSize = other.CurrentCostume.Value.Size; connectHeader.PacketSize = other.CurrentCostume.Value.Size;
@ -277,7 +295,7 @@ public class Server {
tempBuffer.Dispose(); tempBuffer.Dispose();
// make the other client reset their puppet cache for this client, if it is a new connection (after restart) // make the other client reset their puppet cache for this new client, if it is a new connection (after restart)
if (wasFirst) { if (wasFirst) {
await SendEmptyPackets(client, other); await SendEmptyPackets(client, other);
} }
@ -287,14 +305,19 @@ public class Server {
// send missing or outdated packets from others to the new client // send missing or outdated packets from others to the new client
await ResendPackets(client); await ResendPackets(client);
} else if (header.Id != client.Id && client.Id != Guid.Empty) { }
else if (header.Id != client.Id && client.Id != Guid.Empty) {
throw new Exception($"Client {client.Name} sent packet with invalid client id {header.Id} instead of {client.Id}"); throw new Exception($"Client {client.Name} sent packet with invalid client id {header.Id} instead of {client.Id}");
} }
try { try {
// parse the packet
IPacket packet = (IPacket) Activator.CreateInstance(Constants.PacketIdMap[header.Type])!; IPacket packet = (IPacket) Activator.CreateInstance(Constants.PacketIdMap[header.Type])!;
packet.Deserialize(memory.Memory.Span[Constants.HeaderSize..(Constants.HeaderSize + packet.Size)]); packet.Deserialize(memory.Memory.Span[Constants.HeaderSize..(Constants.HeaderSize + packet.Size)]);
// process the packet
if (PacketHandler?.Invoke(client, packet) is false) { if (PacketHandler?.Invoke(client, packet) is false) {
// don't broadcast the packet to everyone
memory.Dispose(); memory.Dispose();
continue; continue;
} }
@ -302,7 +325,9 @@ public class Server {
catch (Exception e) { catch (Exception e) {
client.Logger.Error($"Packet handler warning: {e}"); client.Logger.Error($"Packet handler warning: {e}");
} }
#pragma warning disable CS4014 #pragma warning disable CS4014
// broadcast the packet to everyone
Broadcast(memory, client) Broadcast(memory, client)
.ContinueWith(x => { if (x.Exception != null) { Logger.Error(x.Exception.ToString()); } }); .ContinueWith(x => { if (x.Exception != null) { Logger.Error(x.Exception.ToString()); } });
#pragma warning restore CS4014 #pragma warning restore CS4014
@ -311,7 +336,8 @@ public class Server {
catch (Exception e) { catch (Exception e) {
if (e is SocketException {SocketErrorCode: SocketError.ConnectionReset}) { if (e is SocketException {SocketErrorCode: SocketError.ConnectionReset}) {
client.Logger.Info($"Disconnected from the server: Connection reset"); client.Logger.Info($"Disconnected from the server: Connection reset");
} else { }
else {
client.Logger.Error($"Disconnecting due to exception: {e}"); client.Logger.Error($"Disconnecting due to exception: {e}");
if (socket.Connected) { if (socket.Connected) {
#pragma warning disable CS4014 #pragma warning disable CS4014
@ -324,6 +350,7 @@ public class Server {
memory?.Dispose(); memory?.Dispose();
} }
// client disconnected
disconnect: disconnect:
if (client.Name != "Unknown User" && client.Id != Guid.Parse("00000000-0000-0000-0000-000000000000")) { if (client.Name != "Unknown User" && client.Id != Guid.Parse("00000000-0000-0000-0000-000000000000")) {
Logger.Info($"Client {remote} ({client.Name}/{client.Id}) disconnected from the server"); Logger.Info($"Client {remote} ({client.Name}/{client.Id}) disconnected from the server");
@ -333,7 +360,6 @@ public class Server {
} }
bool wasConnected = client.Connected; bool wasConnected = client.Connected;
// Clients.Remove(client)
client.Connected = false; client.Connected = false;
try { try {
client.Dispose(); client.Dispose();
@ -359,7 +385,7 @@ public class Server {
} }
}; };
async Task trySendMeta<T>(Client other, string packetType) where T : struct, IPacket { async Task trySendMeta<T>(Client other, string packetType) where T : struct, IPacket {
if (! other.Metadata.ContainsKey(packetType)) { return; } if (!other.Metadata.ContainsKey(packetType)) { return; }
await trySendPack<T>(other, (T) other.Metadata[packetType]!); await trySendPack<T>(other, (T) other.Metadata[packetType]!);
}; };
await Parallel.ForEachAsync(this.ClientsConnected, async (other, _) => { await Parallel.ForEachAsync(this.ClientsConnected, async (other, _) => {