From d6a8df448c08fc9dc3572701fde8a7399aadbeed Mon Sep 17 00:00:00 2001 From: "Robin C. Ladiges" Date: Wed, 6 Sep 2023 01:14:54 +0200 Subject: [PATCH] Refactoring ban command (#48) * rename command: `ban ...` => `ban player ...` To enable adding other subcommands starting with `ban`. Moving ban list and crash related code into its own class to tidy the Program class up. Change Id values of the crash cmds, to fit into the 16 byte max length imposed by ChangeStagePacket.IdSize. * add command: `ban ip ` To add an IPv4 address to the ban list. * add command: `ban profile ` To add a profile ID to the ban list. * add command: `unban ip ` To remove a banned IPv4 address from the ban list. * add command: `unban profile ` To remove a banned profile ID from the ban list. * add commands: `ban enable` and `ban disable` To set the value of `BanList.Enabled` to `true` or `false` without editing the `settings.json` file. * add command: `ban list` To show the current ban list settings. * fix: actually working ban functionality Changes: - ignore new sockets from banned IP addresses way earlier. - ignore all packets by banned profiles. Intentionally keeping the connection open instead of d/c banned clients. This is to prevent endless server logs due to automatically reconnecting clients. Before: Reconnecting clients aren't entering `ClientJoined` and therefore the d/c is only working on first connections. Effectively banned clients got a d/c and then automatically reconnected again without getting a d/c again. Therefore allowing them to play normally. * use SortedSet instead of List for settings To enforce unique entries and maintain a stable order inside of the `settings.json`. * add commands: `ban stage ` and `unban stage ` To kick players from the server when they enter a banned stage. can also be a kingdom alias, which bans/unbans all stages in that kingdom. Because we aren't banning the player, d/c them would be no good, because of the client auto reconnect. Instead send them the crash and ignore all packets by them until they d/c on their own. This is an alternative solution for issue #43. * Update Server.cs --------- Co-authored-by: Sanae <32604996+Sanae6@users.noreply.github.com> --- Server/BanLists.cs | 361 +++++++++++++++++++++++++++++++++++++++++++++ Server/Client.cs | 1 + Server/Program.cs | 87 ++++------- Server/Server.cs | 20 +++ Server/Settings.cs | 17 ++- Shared/Stages.cs | 28 +++- 6 files changed, 447 insertions(+), 67 deletions(-) create mode 100644 Server/BanLists.cs diff --git a/Server/BanLists.cs b/Server/BanLists.cs new file mode 100644 index 0000000..dbe0d12 --- /dev/null +++ b/Server/BanLists.cs @@ -0,0 +1,361 @@ +using System.Net; +using System.Net.Sockets; +using System.Text; + +using Shared; +using Shared.Packet.Packets; + +namespace Server; + +using MUCH = Func failToFind, HashSet toActUpon, List<(string arg, IEnumerable amb)> ambig)>; + +public static class BanLists { + public static bool Enabled { + get { + return Settings.Instance.BanList.Enabled; + } + private set { + Settings.Instance.BanList.Enabled = value; + } + } + + private static ISet IPs { + get { + return Settings.Instance.BanList.IpAddresses; + } + } + + private static ISet Profiles { + get { + return Settings.Instance.BanList.Players; + } + } + + private static ISet Stages { + get { + return Settings.Instance.BanList.Stages; + } + } + + + private static bool IsIPv4(string str) { + return IPAddress.TryParse(str, out IPAddress? ip) + && ip != null + && ip.AddressFamily == AddressFamily.InterNetwork; + ; + } + + + public static bool IsIPv4Banned(Client user) { + IPEndPoint? ipv4 = (IPEndPoint?) user.Socket?.RemoteEndPoint; + if (ipv4 == null) { return false; } + return IsIPv4Banned(ipv4.Address); + } + public static bool IsIPv4Banned(IPAddress ipv4) { + return IsIPv4Banned(ipv4.ToString()); + } + public static bool IsIPv4Banned(string ipv4) { + return IPs.Contains(ipv4); + } + + public static bool IsProfileBanned(Client user) { + return IsProfileBanned(user.Id); + } + public static bool IsProfileBanned(string str) { + if (!Guid.TryParse(str, out Guid id)) { return false; } + return IsProfileBanned(id); + } + public static bool IsProfileBanned(Guid id) { + return Profiles.Contains(id); + } + + public static bool IsStageBanned(string stage) { + return Stages.Contains(stage); + } + + public static bool IsClientBanned(Client user) { + return IsProfileBanned(user) || IsIPv4Banned(user); + } + + + private static void BanIPv4(Client user) { + IPEndPoint? ipv4 = (IPEndPoint?) user.Socket?.RemoteEndPoint; + if (ipv4 != null) { + BanIPv4(ipv4.Address); + } + } + private static void BanIPv4(IPAddress ipv4) { + BanIPv4(ipv4.ToString()); + } + private static void BanIPv4(string ipv4) { + IPs.Add(ipv4); + } + + private static void BanProfile(Client user) { + BanProfile(user.Id); + } + private static void BanProfile(string str) { + if (!Guid.TryParse(str, out Guid id)) { return; } + BanProfile(id); + } + private static void BanProfile(Guid id) { + Profiles.Add(id); + } + + private static void BanStage(string stage) { + Stages.Add(stage); + } + + private static void BanClient(Client user) { + BanProfile(user); + BanIPv4(user); + } + + + private static void UnbanIPv4(Client user) { + IPEndPoint? ipv4 = (IPEndPoint?) user.Socket?.RemoteEndPoint; + if (ipv4 != null) { + UnbanIPv4(ipv4.Address); + } + } + private static void UnbanIPv4(IPAddress ipv4) { + UnbanIPv4(ipv4.ToString()); + } + private static void UnbanIPv4(string ipv4) { + IPs.Remove(ipv4); + } + + private static void UnbanProfile(Client user) { + UnbanProfile(user.Id); + } + private static void UnbanProfile(string str) { + if (!Guid.TryParse(str, out Guid id)) { return; } + UnbanProfile(id); + } + private static void UnbanProfile(Guid id) { + Profiles.Remove(id); + } + + private static void UnbanStage(string stage) { + Stages.Remove(stage); + } + + + private static void Save() { + Settings.SaveSettings(true); + } + + + public static void Crash( + Client user, + bool permanent = false, + bool dispose_user = true, + int delay_ms = 0 + ) { + user.Ignored = true; + Task.Run(async () => { + if (delay_ms > 0) { + await Task.Delay(delay_ms); + } + await user.Send(new ChangeStagePacket { + Id = (permanent ? "$agogus/ban4lyfe" : "$among$us/cr4sh%"), + Stage = (permanent ? "$ejected" : "$agogusStage"), + Scenario = (sbyte) (permanent ? 69 : 21), + SubScenarioType = (byte) (permanent ? 21 : 69), + }); + if (dispose_user) { + user.Dispose(); + } + }); + } + + private static void CrashMultiple(string[] args, MUCH much) { + foreach (Client user in much(args).toActUpon) { + Crash(user, true); + } + } + + + public static string HandleBanCommand(string[] args, MUCH much) { + if (args.Length == 0) { + return "Usage: ban {list|enable|disable|player|profile|ip|stage} ..."; + } + + string cmd = args[0]; + args = args.Skip(1).ToArray(); + + switch (cmd) { + default: + return "Usage: ban {list|enable|disable|player|profile|ip|stage} ..."; + + case "list": + if (args.Length != 0) { + return "Usage: ban list"; + } + StringBuilder list = new StringBuilder(); + list.Append("BanList: " + (Enabled ? "enabled" : "disabled")); + + if (IPs.Count > 0) { + list.Append("\nBanned IPv4 addresses:\n- "); + list.Append(string.Join("\n- ", IPs)); + } + + if (Profiles.Count > 0) { + list.Append("\nBanned profile IDs:\n- "); + list.Append(string.Join("\n- ", Profiles)); + } + + if (Stages.Count > 0) { + list.Append("\nBanned stages:\n- "); + list.Append(string.Join("\n- ", Stages)); + } + + return list.ToString(); + + case "enable": + if (args.Length != 0) { + return "Usage: ban enable"; + } + Enabled = true; + Save(); + return "BanList enabled."; + + case "disable": + if (args.Length != 0) { + return "Usage: ban disable"; + } + Enabled = false; + Save(); + return "BanList disabled."; + + case "player": + if (args.Length == 0) { + return "Usage: ban player <* | !* (usernames to not ban...) | (usernames to ban...)>"; + } + + var res = much(args); + + StringBuilder sb = new StringBuilder(); + sb.Append(res.toActUpon.Count > 0 ? "Banned players: " + string.Join(", ", res.toActUpon.Select(x => $"\"{x.Name}\"")) : ""); + sb.Append(res.failToFind.Count > 0 ? "\nFailed to find matches for: " + string.Join(", ", res.failToFind.Select(x => $"\"{x.ToLower()}\"")) : ""); + if (res.ambig.Count > 0) { + res.ambig.ForEach(x => { + sb.Append($"\nAmbiguous for \"{x.arg}\": {string.Join(", ", x.amb.Select(x => $"\"{x}\""))}"); + }); + } + + foreach (Client user in res.toActUpon) { + BanClient(user); + Crash(user, true); + } + + Save(); + return sb.ToString(); + + case "profile": + if (args.Length != 1) { + return "Usage: ban profile "; + } + if (!Guid.TryParse(args[0], out Guid id)) { + return "Invalid profile ID value!"; + } + if (IsProfileBanned(id)) { + return "Profile " + id.ToString() + " is already banned."; + } + BanProfile(id); + CrashMultiple(args, much); + Save(); + return "Banned profile: " + id.ToString(); + + case "ip": + if (args.Length != 1) { + return "Usage: ban ip "; + } + if (!IsIPv4(args[0])) { + return "Invalid IPv4 address!"; + } + if (IsIPv4Banned(args[0])) { + return "IP " + args[0] + " is already banned."; + } + BanIPv4(args[0]); + CrashMultiple(args, much); + Save(); + return "Banned ip: " + args[0]; + + case "stage": + if (args.Length != 1) { + return "Usage: ban stage "; + } + string? stage = Shared.Stages.Input2Stage(args[0]); + if (stage == null) { + return "Invalid stage name!"; + } + if (IsStageBanned(stage)) { + return "Stage " + stage + " is already banned."; + } + var stages = Shared.Stages + .StagesByInput(args[0]) + .Where(s => !IsStageBanned(s)) + .ToList() + ; + foreach (string s in stages) { + BanStage(s); + } + Save(); + return "Banned stage: " + string.Join(", ", stages); + } + } + + + public static string HandleUnbanCommand(string[] args) { + if (args.Length != 2) { + return "Usage: unban {profile|ip|stage} "; + } + + string cmd = args[0]; + string val = args[1]; + + switch (cmd) { + default: + return "Usage: unban {profile|ip|stage} "; + + case "profile": + if (!Guid.TryParse(val, out Guid id)) { + return "Invalid profile ID value!"; + } + if (!IsProfileBanned(id)) { + return "Profile " + id.ToString() + " is not banned."; + } + UnbanProfile(id); + Save(); + return "Unbanned profile: " + id.ToString(); + + case "ip": + if (!IsIPv4(val)) { + return "Invalid IPv4 address!"; + } + if (!IsIPv4Banned(val)) { + return "IP " + val + " is not banned."; + } + UnbanIPv4(val); + Save(); + return "Unbanned ip: " + val; + + case "stage": + string stage = Shared.Stages.Input2Stage(val) ?? val; + if (!IsStageBanned(stage)) { + return "Stage " + stage + " is not banned."; + } + var stages = Shared.Stages + .StagesByInput(val) + .Where(IsStageBanned) + .ToList() + ; + foreach (string s in stages) { + UnbanStage(s); + } + Save(); + return "Unbanned stage: " + string.Join(", ", stages); + } + } +} diff --git a/Server/Client.cs b/Server/Client.cs index 7e9434f..4b53525 100644 --- a/Server/Client.cs +++ b/Server/Client.cs @@ -12,6 +12,7 @@ namespace Server; public class Client : IDisposable { public readonly ConcurrentDictionary Metadata = new ConcurrentDictionary(); // can be used to store any information about a player public bool Connected = false; + public bool Ignored = false; public CostumePacket? CurrentCostume = null; // required for proper client sync public string Name { get => Logger.Name; diff --git a/Server/Program.cs b/Server/Program.cs index 230459c..c9728d8 100644 --- a/Server/Program.cs +++ b/Server/Program.cs @@ -62,11 +62,6 @@ async Task LoadShines() await LoadShines(); server.ClientJoined += (c, _) => { - if (Settings.Instance.BanList.Enabled - && (Settings.Instance.BanList.Players.Contains(c.Id) - || Settings.Instance.BanList.IpAddresses.Contains( - ((IPEndPoint) c.Socket!.RemoteEndPoint!).Address.ToString()))) - throw new Exception($"Banned player attempted join: {c.Name}"); c.Metadata["shineSync"] = new ConcurrentBag(); c.Metadata["loadedSave"] = false; c.Metadata["scenario"] = (byte?) 0; @@ -125,6 +120,11 @@ void logError(Task x) { server.PacketHandler = (c, p) => { switch (p) { case GamePacket gamePacket: { + if (BanLists.Enabled && BanLists.IsStageBanned(gamePacket.Stage)) { + c.Logger.Warn($"Crashing player for entering banned stage {gamePacket.Stage}."); + BanLists.Crash(c, false, false, 500); + return false; + } c.Logger.Info($"Got game packet {gamePacket.Stage}->{gamePacket.ScenarioNum}"); // reset lastPlayerPacket on stage changes @@ -246,38 +246,49 @@ server.PacketHandler = (c, p) => { HashSet failToFind = new(); HashSet toActUpon; List<(string arg, IEnumerable amb)> ambig = new(); - if (args[0] == "*") + if (args[0] == "*") { toActUpon = new(server.Clients.Where(c => c.Connected)); + } else { toActUpon = args[0] == "!*" ? new(server.Clients.Where(c => c.Connected)) : new(); for (int i = (args[0] == "!*" ? 1 : 0); i < args.Length; i++) { string arg = args[i]; - IEnumerable search = server.Clients.Where(c => c.Connected && - (c.Name.ToLower().StartsWith(arg.ToLower()) || (Guid.TryParse(arg, out Guid res) && res == c.Id))); - if (!search.Any()) + IEnumerable search = server.Clients.Where(c => c.Connected && ( + c.Name.ToLower().StartsWith(arg.ToLower()) + || (Guid.TryParse(arg, out Guid res) && res == c.Id) + || (IPAddress.TryParse(arg, out IPAddress? ip) && ip.Equals(((IPEndPoint) c.Socket!.RemoteEndPoint!).Address)) + )); + if (!search.Any()) { failToFind.Add(arg); //none found + } else if (search.Count() > 1) { Client? exact = search.FirstOrDefault(x => x.Name == arg); if (!ReferenceEquals(exact, null)) { //even though multiple matches, since exact match, it isn't ambiguous - if (args[0] == "!*") + if (args[0] == "!*") { toActUpon.Remove(exact); - else + } + else { toActUpon.Add(exact); + } } else { - if (!ambig.Any(x => x.arg == arg)) + if (!ambig.Any(x => x.arg == arg)) { ambig.Add((arg, search.Select(x => x.Name))); //more than one match - foreach (var rem in search.ToList()) //need copy because can't remove from list while iterating over it + } + foreach (var rem in search.ToList()) { //need copy because can't remove from list while iterating over it toActUpon.Remove(rem); + } } } else { //only one match, so autocomplete - if (args[0] == "!*") + if (args[0] == "!*") { toActUpon.Remove(search.First()); - else + } + else { toActUpon.Add(search.First()); + } } } } @@ -324,54 +335,14 @@ CommandHandler.RegisterCommand("crash", args => { } foreach (Client user in res.toActUpon) { - Task.Run(async () => { - await user.Send(new ChangeStagePacket { - Id = "$among$us/SubArea", - Stage = "$agogusStage", - Scenario = 21, - SubScenarioType = 69 // invalid id - }); - user.Dispose(); - }); + BanLists.Crash(user); } return sb.ToString(); }); -CommandHandler.RegisterCommand("ban", args => { - if (args.Length == 0) { - return "Usage: ban <* | !* (usernames to not ban...) | (usernames to ban...)>"; - } - - var res = MultiUserCommandHelper(args); - - StringBuilder sb = new StringBuilder(); - sb.Append(res.toActUpon.Count > 0 ? "Banned: " + string.Join(", ", res.toActUpon.Select(x => $"\"{x.Name}\"")) : ""); - sb.Append(res.failToFind.Count > 0 ? "\nFailed to find matches for: " + string.Join(", ", res.failToFind.Select(x => $"\"{x.ToLower()}\"")) : ""); - if (res.ambig.Count > 0) { - res.ambig.ForEach(x => { - sb.Append($"\nAmbiguous for \"{x.arg}\": {string.Join(", ", x.amb.Select(x => $"\"{x}\""))}"); - }); - } - - foreach (Client user in res.toActUpon) { - Task.Run(async () => { - await user.Send(new ChangeStagePacket { - Id = "$agogus/banned4lyfe", - Stage = "$ejected", - Scenario = 69, - SubScenarioType = 21 // invalid id - }); - IPEndPoint? endpoint = (IPEndPoint?) user.Socket?.RemoteEndPoint; - Settings.Instance.BanList.Players.Add(user.Id); - if (endpoint != null) Settings.Instance.BanList.IpAddresses.Add(endpoint.ToString()); - user.Dispose(); - }); - } - - Settings.SaveSettings(); - return sb.ToString(); -}); +CommandHandler.RegisterCommand("ban", args => { return BanLists.HandleBanCommand(args, (args) => MultiUserCommandHelper(args)); }); +CommandHandler.RegisterCommand("unban", args => { return BanLists.HandleUnbanCommand(args); }); CommandHandler.RegisterCommand("send", args => { const string optionUsage = "Usage: send "; diff --git a/Server/Server.cs b/Server/Server.cs index 2e8eff2..fea3b8a 100644 --- a/Server/Server.cs +++ b/Server/Server.cs @@ -29,6 +29,11 @@ public class Server { Socket socket = token.HasValue ? await serverSocket.AcceptAsync(token.Value) : await serverSocket.AcceptAsync(); socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, true); + if (BanLists.Enabled && BanLists.IsIPv4Banned(((IPEndPoint) socket.RemoteEndPoint!).Address!)) { + Logger.Warn($"Ignoring banned IPv4 address {socket.RemoteEndPoint}"); + continue; + } + Logger.Warn($"Accepted connection for client {socket.RemoteEndPoint}"); try { @@ -167,6 +172,11 @@ public class Server { break; } + if (client.Ignored) { + memory.Dispose(); + continue; + } + // connection initialization if (first) { first = false; @@ -174,8 +184,18 @@ public class Server { ConnectPacket connect = new ConnectPacket(); connect.Deserialize(memory.Memory.Span[packetRange]); + bool wasFirst = connect.ConnectionType == ConnectPacket.ConnectionTypes.FirstConnection; + if (BanLists.Enabled && BanLists.IsProfileBanned(header.Id)) { + client.Id = header.Id; + client.Name = connect.ClientName; + client.Ignored = true; + client.Logger.Warn($"Ignoring banned profile ID {header.Id}"); + memory.Dispose(); + continue; + } + lock (Clients) { if (Clients.Count(x => x.Connected) == Settings.Instance.Server.MaxPlayers) { client.Logger.Error($"Turned away as server is at max clients"); diff --git a/Server/Settings.cs b/Server/Settings.cs index fe075e5..22a7dc9 100644 --- a/Server/Settings.cs +++ b/Server/Settings.cs @@ -30,10 +30,10 @@ public class Settings { LoadHandler?.Invoke(); } - public static void SaveSettings() { + public static void SaveSettings(bool silent = false) { try { File.WriteAllText("settings.json", JsonConvert.SerializeObject(Instance, Formatting.Indented, new StringEnumConverter(new CamelCaseNamingStrategy()))); - Logger.Info("Saved settings to settings.json"); + if (!silent) { Logger.Info("Saved settings to settings.json"); } } catch (Exception e) { Logger.Error($"Failed to save settings.json {e}"); @@ -43,7 +43,7 @@ public class Settings { public ServerTable Server { get; set; } = new ServerTable(); public FlipTable Flip { get; set; } = new FlipTable(); public ScenarioTable Scenario { get; set; } = new ScenarioTable(); - public BannedPlayers BanList { get; set; } = new BannedPlayers(); + public BanListTable BanList { get; set; } = new BanListTable(); public DiscordTable Discord { get; set; } = new DiscordTable(); public ShineTable Shines { get; set; } = new ShineTable(); public PersistShinesTable PersistShines { get; set; } = new PersistShinesTable(); @@ -58,15 +58,16 @@ public class Settings { public bool MergeEnabled { get; set; } = false; } - public class BannedPlayers { + public class BanListTable { public bool Enabled { get; set; } = false; - public List Players { get; set; } = new List(); - public List IpAddresses { get; set; } = new List(); + public ISet Players { get; set; } = new SortedSet(); + public ISet IpAddresses { get; set; } = new SortedSet(); + public ISet Stages { get; set; } = new SortedSet(); } public class FlipTable { public bool Enabled { get; set; } = true; - public List Players { get; set; } = new List(); + public ISet Players { get; set; } = new SortedSet(); public FlipOptions Pov { get; set; } = FlipOptions.Both; } @@ -86,4 +87,4 @@ public class Settings { public bool Enabled { get; set; } = false; public string Filename { get; set; } = "./moons.json"; } -} \ No newline at end of file +} diff --git a/Shared/Stages.cs b/Shared/Stages.cs index 5eb03b4..04315d5 100644 --- a/Shared/Stages.cs +++ b/Shared/Stages.cs @@ -11,7 +11,7 @@ public static class Stages { return mapName; } // exact stage value - if (Stage2Alias.ContainsKey(input)) { + if (IsStage(input)) { return input; } // force input value with a ! @@ -29,6 +29,32 @@ public static class Stages { return result; } + public static bool IsAlias(string input) { + return Alias2Stage.ContainsKey(input); + } + + public static bool IsStage(string input) { + return Stage2Alias.ContainsKey(input); + } + + public static IEnumerable StagesByInput(string input) { + if (IsAlias(input)) { + var stages = Stage2Alias + .Where(e => e.Value == input) + .Select(e => e.Key) + ; + foreach (string stage in stages) { + yield return stage; + } + } + else { + string? stage = Input2Stage(input); + if (stage != null) { + yield return stage; + } + } + } + public static readonly Dictionary Alias2Stage = new Dictionary() { { "cap", "CapWorldHomeStage" }, { "cascade", "WaterfallWorldHomeStage" },