From aca45340bfd1ae8de9857cb906a2a31fbf8fbcdb Mon Sep 17 00:00:00 2001 From: cpw Date: Sun, 23 Feb 2020 22:15:34 -0500 Subject: [PATCH] Fix up alarming security crisis with network handling that allowed for wrong code execution on the server, resulting in CRASHED SERVERS. AWFUL stuff. Also fixed a silly log message from the login handler. This change introduces a mechanism to direct certain packets to only process on one side or another. Invalid sidedness will result in the connection being terminated. Signed-off-by: cpw --- .../fml/network/FMLLoginWrapper.java | 2 ++ .../fml/network/NetworkHooks.java | 18 +++++++++++ .../fml/network/NetworkInitialization.java | 10 +++---- .../network/simple/IndexedMessageCodec.java | 11 +++++-- .../fml/network/simple/SimpleChannel.java | 30 ++++++++++++++++--- 5 files changed, 59 insertions(+), 12 deletions(-) diff --git a/src/main/java/net/minecraftforge/fml/network/FMLLoginWrapper.java b/src/main/java/net/minecraftforge/fml/network/FMLLoginWrapper.java index 682c10e70..5ab6e9c14 100644 --- a/src/main/java/net/minecraftforge/fml/network/FMLLoginWrapper.java +++ b/src/main/java/net/minecraftforge/fml/network/FMLLoginWrapper.java @@ -47,6 +47,8 @@ public class FMLLoginWrapper { } private void wrapperReceived(final T packet) { + // we don't care about channel registration change events on this channel + if (packet instanceof NetworkEvent.ChannelRegistrationChangeEvent) return; final NetworkEvent.Context wrappedContext = packet.getSource().get(); final PacketBuffer payload = packet.getPayload(); ResourceLocation targetNetworkReceiver = FMLNetworkConstants.FML_HANDSHAKE_RESOURCE; diff --git a/src/main/java/net/minecraftforge/fml/network/NetworkHooks.java b/src/main/java/net/minecraftforge/fml/network/NetworkHooks.java index 4ffecfed2..348efba2a 100644 --- a/src/main/java/net/minecraftforge/fml/network/NetworkHooks.java +++ b/src/main/java/net/minecraftforge/fml/network/NetworkHooks.java @@ -20,6 +20,7 @@ package net.minecraftforge.fml.network; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.function.Consumer; import java.util.function.Supplier; @@ -28,7 +29,9 @@ import java.util.stream.Collectors; import it.unimi.dsi.fastutil.ints.Int2ObjectMap; import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; import net.minecraft.util.registry.Registry; +import net.minecraft.util.text.StringTextComponent; import net.minecraft.world.dimension.DimensionType; +import net.minecraftforge.fml.common.thread.EffectiveSide; import net.minecraftforge.registries.ClearableRegistry; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -71,9 +74,24 @@ public class NetworkHooks public static boolean onCustomPayload(final ICustomPacket packet, final NetworkManager manager) { return NetworkRegistry.findTarget(packet.getName()). + filter(ni->validateSideForProcessing(packet, ni, manager)). map(ni->ni.dispatch(packet.getDirection(), packet, manager)).orElse(Boolean.FALSE); } + private static boolean validateSideForProcessing(final ICustomPacket packet, final NetworkInstance ni, final NetworkManager manager) { + if (packet.getDirection().getReceptionSide() != EffectiveSide.get()) { + manager.closeChannel(new StringTextComponent("Illegal packet received, terminating connection")); + return false; + } + return true; + } + + public static void validatePacketDirection(final NetworkDirection packetDirection, final Optional expectedDirection, final NetworkManager connection) { + if (packetDirection != expectedDirection.orElse(packetDirection)) { + connection.closeChannel(new StringTextComponent("Illegal packet received, terminating connection")); + throw new IllegalStateException("Invalid packet received, aborting connection"); + } + } public static void registerServerLoginChannel(NetworkManager manager, CHandshakePacket packet) { manager.channel().attr(FMLNetworkConstants.FML_NETVERSION).set(packet.getFMLVersion()); diff --git a/src/main/java/net/minecraftforge/fml/network/NetworkInitialization.java b/src/main/java/net/minecraftforge/fml/network/NetworkInitialization.java index 07065a34a..f93ae8c73 100644 --- a/src/main/java/net/minecraftforge/fml/network/NetworkInitialization.java +++ b/src/main/java/net/minecraftforge/fml/network/NetworkInitialization.java @@ -37,14 +37,14 @@ class NetworkInitialization { networkProtocolVersion(() -> FMLNetworkConstants.NETVERSION). simpleChannel(); - handshakeChannel.messageBuilder(FMLHandshakeMessages.C2SAcknowledge.class, 99). + handshakeChannel.messageBuilder(FMLHandshakeMessages.C2SAcknowledge.class, 99, NetworkDirection.LOGIN_TO_SERVER). loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex). decoder(FMLHandshakeMessages.C2SAcknowledge::decode). encoder(FMLHandshakeMessages.C2SAcknowledge::encode). consumer(FMLHandshakeHandler.indexFirst(FMLHandshakeHandler::handleClientAck)). add(); - handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CModList.class, 1). + handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CModList.class, 1, NetworkDirection.LOGIN_TO_CLIENT). loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex). decoder(FMLHandshakeMessages.S2CModList::decode). encoder(FMLHandshakeMessages.S2CModList::encode). @@ -52,14 +52,14 @@ class NetworkInitialization { consumer(FMLHandshakeHandler.biConsumerFor(FMLHandshakeHandler::handleServerModListOnClient)). add(); - handshakeChannel.messageBuilder(FMLHandshakeMessages.C2SModListReply.class, 2). + handshakeChannel.messageBuilder(FMLHandshakeMessages.C2SModListReply.class, 2, NetworkDirection.LOGIN_TO_SERVER). loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex). decoder(FMLHandshakeMessages.C2SModListReply::decode). encoder(FMLHandshakeMessages.C2SModListReply::encode). consumer(FMLHandshakeHandler.indexFirst(FMLHandshakeHandler::handleClientModListOnServer)). add(); - handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CRegistry.class, 3). + handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CRegistry.class, 3, NetworkDirection.LOGIN_TO_CLIENT). loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex). decoder(FMLHandshakeMessages.S2CRegistry::decode). encoder(FMLHandshakeMessages.S2CRegistry::encode). @@ -67,7 +67,7 @@ class NetworkInitialization { consumer(FMLHandshakeHandler.biConsumerFor(FMLHandshakeHandler::handleRegistryMessage)). add(); - handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CConfigData.class, 4). + handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CConfigData.class, 4, NetworkDirection.LOGIN_TO_CLIENT). loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex). decoder(FMLHandshakeMessages.S2CConfigData::decode). encoder(FMLHandshakeMessages.S2CConfigData::encode). diff --git a/src/main/java/net/minecraftforge/fml/network/simple/IndexedMessageCodec.java b/src/main/java/net/minecraftforge/fml/network/simple/IndexedMessageCodec.java index ab1619d73..344208909 100644 --- a/src/main/java/net/minecraftforge/fml/network/simple/IndexedMessageCodec.java +++ b/src/main/java/net/minecraftforge/fml/network/simple/IndexedMessageCodec.java @@ -22,7 +22,9 @@ package net.minecraftforge.fml.network.simple; import it.unimi.dsi.fastutil.objects.Object2ObjectArrayMap; import it.unimi.dsi.fastutil.shorts.Short2ObjectArrayMap; import net.minecraft.network.PacketBuffer; +import net.minecraftforge.fml.network.NetworkDirection; import net.minecraftforge.fml.network.NetworkEvent; +import net.minecraftforge.fml.network.NetworkHooks; import net.minecraftforge.fml.network.NetworkInstance; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -68,16 +70,18 @@ public class IndexedMessageCodec private final int index; private final BiConsumer> messageConsumer; private final Class messageType; + private final Optional networkDirection; private Optional> loginIndexSetter; private Optional> loginIndexGetter; - public MessageHandler(int index, Class messageType, BiConsumer encoder, Function decoder, BiConsumer> messageConsumer) + public MessageHandler(int index, Class messageType, BiConsumer encoder, Function decoder, BiConsumer> messageConsumer, final Optional networkDirection) { this.index = index; this.messageType = messageType; this.encoder = Optional.ofNullable(encoder); this.decoder = Optional.ofNullable(decoder); this.messageConsumer = messageConsumer; + this.networkDirection = networkDirection; this.loginIndexGetter = Optional.empty(); this.loginIndexSetter = Optional.empty(); indicies.put((short)(index & 0xff), this); @@ -154,10 +158,11 @@ public class IndexedMessageCodec LOGGER.error(SIMPLENET, "Received invalid discriminator byte {} on channel {}", discriminator, Optional.ofNullable(networkInstance).map(NetworkInstance::getChannelName).map(Objects::toString).orElse("MISSING CHANNEL")); return; } + NetworkHooks.validatePacketDirection(context.get().getDirection(), messageHandler.networkDirection, context.get().getNetworkManager()); tryDecode(payload, context, payloadIndex, messageHandler); } - MessageHandler addCodecIndex(int index, Class messageType, BiConsumer encoder, Function decoder, BiConsumer> messageConsumer) { - return new MessageHandler<>(index, messageType, encoder, decoder, messageConsumer); + MessageHandler addCodecIndex(int index, Class messageType, BiConsumer encoder, Function decoder, BiConsumer> messageConsumer, final Optional networkDirection) { + return new MessageHandler<>(index, messageType, encoder, decoder, messageConsumer, networkDirection); } } diff --git a/src/main/java/net/minecraftforge/fml/network/simple/SimpleChannel.java b/src/main/java/net/minecraftforge/fml/network/simple/SimpleChannel.java index 1f5148708..b4d4989a4 100644 --- a/src/main/java/net/minecraftforge/fml/network/simple/SimpleChannel.java +++ b/src/main/java/net/minecraftforge/fml/network/simple/SimpleChannel.java @@ -24,6 +24,7 @@ import net.minecraft.client.Minecraft; import net.minecraft.network.NetworkManager; import net.minecraft.network.IPacket; import net.minecraft.network.PacketBuffer; +import net.minecraft.util.text.StringTextComponent; import net.minecraftforge.fml.network.*; import org.apache.commons.lang3.tuple.Pair; @@ -83,8 +84,13 @@ public class SimpleChannel public int encodeMessage(MSG message, final PacketBuffer target) { return this.indexedCodec.build(message, target); } + public IndexedMessageCodec.MessageHandler registerMessage(int index, Class messageType, BiConsumer encoder, Function decoder, BiConsumer> messageConsumer) { - return this.indexedCodec.addCodecIndex(index, messageType, encoder, decoder, messageConsumer); + return registerMessage(index, messageType, encoder, decoder, messageConsumer, Optional.empty()); + } + + public IndexedMessageCodec.MessageHandler registerMessage(int index, Class messageType, BiConsumer encoder, Function decoder, BiConsumer> messageConsumer, final Optional networkDirection) { + return this.indexedCodec.addCodecIndex(index, messageType, encoder, decoder, messageConsumer, networkDirection); } private Pair toBuffer(MSG msg) { @@ -137,7 +143,21 @@ public class SimpleChannel * @return a MessageBuilder */ public MessageBuilder messageBuilder(final Class type, int id) { - return MessageBuilder.forType(this, type, id); + return MessageBuilder.forType(this, type, id, null); + } + + /** + * Build a new MessageBuilder. The type should implement {@link java.util.function.IntSupplier} if it is a login + * packet. + * @param type Type of message + * @param id id in the indexed codec + * @param direction a network direction which will be asserted before any processing of this message occurs. Use to + * enforce strict sided handling to prevent spoofing. + * @param Type of type + * @return a MessageBuilder + */ + public MessageBuilder messageBuilder(final Class type, int id, NetworkDirection direction) { + return MessageBuilder.forType(this, type, id, direction); } public static class MessageBuilder { @@ -150,12 +170,14 @@ public class SimpleChannel private Function loginIndexGetter; private BiConsumer loginIndexSetter; private Function>> loginPacketGenerators; + private Optional networkDirection; - private static MessageBuilder forType(final SimpleChannel channel, final Class type, int id) { + private static MessageBuilder forType(final SimpleChannel channel, final Class type, int id, NetworkDirection networkDirection) { MessageBuilder builder = new MessageBuilder<>(); builder.channel = channel; builder.id = id; builder.type = type; + builder.networkDirection = Optional.ofNullable(networkDirection); return builder; } @@ -215,7 +237,7 @@ public class SimpleChannel } public void add() { - final IndexedMessageCodec.MessageHandler message = this.channel.registerMessage(this.id, this.type, this.encoder, this.decoder, this.consumer); + final IndexedMessageCodec.MessageHandler message = this.channel.registerMessage(this.id, this.type, this.encoder, this.decoder, this.consumer, this.networkDirection); if (this.loginIndexSetter != null) { message.setLoginIndexSetter(this.loginIndexSetter); }