Fix up alarming security crisis with network handling that allowed for wrong code execution on the server, resulting in CRASHED SERVERS. AWFUL stuff.

Also fixed a silly log message from the login handler.

This change introduces a mechanism to direct certain packets to only process on one side or another. Invalid sidedness will result in the connection being terminated.

Signed-off-by: cpw <cpw+github@weeksfamily.ca>
This commit is contained in:
cpw 2020-02-23 22:15:34 -05:00
parent 30d4520c6b
commit aca45340bf
No known key found for this signature in database
GPG Key ID: 8EB3DF749553B1B7
5 changed files with 59 additions and 12 deletions

View File

@ -47,6 +47,8 @@ public class FMLLoginWrapper {
} }
private <T extends NetworkEvent> void wrapperReceived(final T packet) { private <T extends NetworkEvent> void wrapperReceived(final T packet) {
// we don't care about channel registration change events on this channel
if (packet instanceof NetworkEvent.ChannelRegistrationChangeEvent) return;
final NetworkEvent.Context wrappedContext = packet.getSource().get(); final NetworkEvent.Context wrappedContext = packet.getSource().get();
final PacketBuffer payload = packet.getPayload(); final PacketBuffer payload = packet.getPayload();
ResourceLocation targetNetworkReceiver = FMLNetworkConstants.FML_HANDSHAKE_RESOURCE; ResourceLocation targetNetworkReceiver = FMLNetworkConstants.FML_HANDSHAKE_RESOURCE;

View File

@ -20,6 +20,7 @@
package net.minecraftforge.fml.network; package net.minecraftforge.fml.network;
import java.util.Objects; import java.util.Objects;
import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Supplier; import java.util.function.Supplier;
@ -28,7 +29,9 @@ import java.util.stream.Collectors;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap; import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap; import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import net.minecraft.util.registry.Registry; import net.minecraft.util.registry.Registry;
import net.minecraft.util.text.StringTextComponent;
import net.minecraft.world.dimension.DimensionType; import net.minecraft.world.dimension.DimensionType;
import net.minecraftforge.fml.common.thread.EffectiveSide;
import net.minecraftforge.registries.ClearableRegistry; import net.minecraftforge.registries.ClearableRegistry;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
@ -71,9 +74,24 @@ public class NetworkHooks
public static boolean onCustomPayload(final ICustomPacket<?> packet, final NetworkManager manager) { public static boolean onCustomPayload(final ICustomPacket<?> packet, final NetworkManager manager) {
return NetworkRegistry.findTarget(packet.getName()). return NetworkRegistry.findTarget(packet.getName()).
filter(ni->validateSideForProcessing(packet, ni, manager)).
map(ni->ni.dispatch(packet.getDirection(), packet, manager)).orElse(Boolean.FALSE); map(ni->ni.dispatch(packet.getDirection(), packet, manager)).orElse(Boolean.FALSE);
} }
private static boolean validateSideForProcessing(final ICustomPacket<?> packet, final NetworkInstance ni, final NetworkManager manager) {
if (packet.getDirection().getReceptionSide() != EffectiveSide.get()) {
manager.closeChannel(new StringTextComponent("Illegal packet received, terminating connection"));
return false;
}
return true;
}
public static void validatePacketDirection(final NetworkDirection packetDirection, final Optional<NetworkDirection> expectedDirection, final NetworkManager connection) {
if (packetDirection != expectedDirection.orElse(packetDirection)) {
connection.closeChannel(new StringTextComponent("Illegal packet received, terminating connection"));
throw new IllegalStateException("Invalid packet received, aborting connection");
}
}
public static void registerServerLoginChannel(NetworkManager manager, CHandshakePacket packet) public static void registerServerLoginChannel(NetworkManager manager, CHandshakePacket packet)
{ {
manager.channel().attr(FMLNetworkConstants.FML_NETVERSION).set(packet.getFMLVersion()); manager.channel().attr(FMLNetworkConstants.FML_NETVERSION).set(packet.getFMLVersion());

View File

@ -37,14 +37,14 @@ class NetworkInitialization {
networkProtocolVersion(() -> FMLNetworkConstants.NETVERSION). networkProtocolVersion(() -> FMLNetworkConstants.NETVERSION).
simpleChannel(); simpleChannel();
handshakeChannel.messageBuilder(FMLHandshakeMessages.C2SAcknowledge.class, 99). handshakeChannel.messageBuilder(FMLHandshakeMessages.C2SAcknowledge.class, 99, NetworkDirection.LOGIN_TO_SERVER).
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex). loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
decoder(FMLHandshakeMessages.C2SAcknowledge::decode). decoder(FMLHandshakeMessages.C2SAcknowledge::decode).
encoder(FMLHandshakeMessages.C2SAcknowledge::encode). encoder(FMLHandshakeMessages.C2SAcknowledge::encode).
consumer(FMLHandshakeHandler.indexFirst(FMLHandshakeHandler::handleClientAck)). consumer(FMLHandshakeHandler.indexFirst(FMLHandshakeHandler::handleClientAck)).
add(); add();
handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CModList.class, 1). handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CModList.class, 1, NetworkDirection.LOGIN_TO_CLIENT).
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex). loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
decoder(FMLHandshakeMessages.S2CModList::decode). decoder(FMLHandshakeMessages.S2CModList::decode).
encoder(FMLHandshakeMessages.S2CModList::encode). encoder(FMLHandshakeMessages.S2CModList::encode).
@ -52,14 +52,14 @@ class NetworkInitialization {
consumer(FMLHandshakeHandler.biConsumerFor(FMLHandshakeHandler::handleServerModListOnClient)). consumer(FMLHandshakeHandler.biConsumerFor(FMLHandshakeHandler::handleServerModListOnClient)).
add(); add();
handshakeChannel.messageBuilder(FMLHandshakeMessages.C2SModListReply.class, 2). handshakeChannel.messageBuilder(FMLHandshakeMessages.C2SModListReply.class, 2, NetworkDirection.LOGIN_TO_SERVER).
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex). loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
decoder(FMLHandshakeMessages.C2SModListReply::decode). decoder(FMLHandshakeMessages.C2SModListReply::decode).
encoder(FMLHandshakeMessages.C2SModListReply::encode). encoder(FMLHandshakeMessages.C2SModListReply::encode).
consumer(FMLHandshakeHandler.indexFirst(FMLHandshakeHandler::handleClientModListOnServer)). consumer(FMLHandshakeHandler.indexFirst(FMLHandshakeHandler::handleClientModListOnServer)).
add(); add();
handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CRegistry.class, 3). handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CRegistry.class, 3, NetworkDirection.LOGIN_TO_CLIENT).
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex). loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
decoder(FMLHandshakeMessages.S2CRegistry::decode). decoder(FMLHandshakeMessages.S2CRegistry::decode).
encoder(FMLHandshakeMessages.S2CRegistry::encode). encoder(FMLHandshakeMessages.S2CRegistry::encode).
@ -67,7 +67,7 @@ class NetworkInitialization {
consumer(FMLHandshakeHandler.biConsumerFor(FMLHandshakeHandler::handleRegistryMessage)). consumer(FMLHandshakeHandler.biConsumerFor(FMLHandshakeHandler::handleRegistryMessage)).
add(); add();
handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CConfigData.class, 4). handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CConfigData.class, 4, NetworkDirection.LOGIN_TO_CLIENT).
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex). loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
decoder(FMLHandshakeMessages.S2CConfigData::decode). decoder(FMLHandshakeMessages.S2CConfigData::decode).
encoder(FMLHandshakeMessages.S2CConfigData::encode). encoder(FMLHandshakeMessages.S2CConfigData::encode).

View File

@ -22,7 +22,9 @@ package net.minecraftforge.fml.network.simple;
import it.unimi.dsi.fastutil.objects.Object2ObjectArrayMap; import it.unimi.dsi.fastutil.objects.Object2ObjectArrayMap;
import it.unimi.dsi.fastutil.shorts.Short2ObjectArrayMap; import it.unimi.dsi.fastutil.shorts.Short2ObjectArrayMap;
import net.minecraft.network.PacketBuffer; import net.minecraft.network.PacketBuffer;
import net.minecraftforge.fml.network.NetworkDirection;
import net.minecraftforge.fml.network.NetworkEvent; import net.minecraftforge.fml.network.NetworkEvent;
import net.minecraftforge.fml.network.NetworkHooks;
import net.minecraftforge.fml.network.NetworkInstance; import net.minecraftforge.fml.network.NetworkInstance;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
@ -68,16 +70,18 @@ public class IndexedMessageCodec
private final int index; private final int index;
private final BiConsumer<MSG,Supplier<NetworkEvent.Context>> messageConsumer; private final BiConsumer<MSG,Supplier<NetworkEvent.Context>> messageConsumer;
private final Class<MSG> messageType; private final Class<MSG> messageType;
private final Optional<NetworkDirection> networkDirection;
private Optional<BiConsumer<MSG, Integer>> loginIndexSetter; private Optional<BiConsumer<MSG, Integer>> loginIndexSetter;
private Optional<Function<MSG, Integer>> loginIndexGetter; private Optional<Function<MSG, Integer>> loginIndexGetter;
public MessageHandler(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer) public MessageHandler(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer, final Optional<NetworkDirection> networkDirection)
{ {
this.index = index; this.index = index;
this.messageType = messageType; this.messageType = messageType;
this.encoder = Optional.ofNullable(encoder); this.encoder = Optional.ofNullable(encoder);
this.decoder = Optional.ofNullable(decoder); this.decoder = Optional.ofNullable(decoder);
this.messageConsumer = messageConsumer; this.messageConsumer = messageConsumer;
this.networkDirection = networkDirection;
this.loginIndexGetter = Optional.empty(); this.loginIndexGetter = Optional.empty();
this.loginIndexSetter = Optional.empty(); this.loginIndexSetter = Optional.empty();
indicies.put((short)(index & 0xff), this); indicies.put((short)(index & 0xff), this);
@ -154,10 +158,11 @@ public class IndexedMessageCodec
LOGGER.error(SIMPLENET, "Received invalid discriminator byte {} on channel {}", discriminator, Optional.ofNullable(networkInstance).map(NetworkInstance::getChannelName).map(Objects::toString).orElse("MISSING CHANNEL")); LOGGER.error(SIMPLENET, "Received invalid discriminator byte {} on channel {}", discriminator, Optional.ofNullable(networkInstance).map(NetworkInstance::getChannelName).map(Objects::toString).orElse("MISSING CHANNEL"));
return; return;
} }
NetworkHooks.validatePacketDirection(context.get().getDirection(), messageHandler.networkDirection, context.get().getNetworkManager());
tryDecode(payload, context, payloadIndex, messageHandler); tryDecode(payload, context, payloadIndex, messageHandler);
} }
<MSG> MessageHandler<MSG> addCodecIndex(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer) { <MSG> MessageHandler<MSG> addCodecIndex(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer, final Optional<NetworkDirection> networkDirection) {
return new MessageHandler<>(index, messageType, encoder, decoder, messageConsumer); return new MessageHandler<>(index, messageType, encoder, decoder, messageConsumer, networkDirection);
} }
} }

View File

@ -24,6 +24,7 @@ import net.minecraft.client.Minecraft;
import net.minecraft.network.NetworkManager; import net.minecraft.network.NetworkManager;
import net.minecraft.network.IPacket; import net.minecraft.network.IPacket;
import net.minecraft.network.PacketBuffer; import net.minecraft.network.PacketBuffer;
import net.minecraft.util.text.StringTextComponent;
import net.minecraftforge.fml.network.*; import net.minecraftforge.fml.network.*;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
@ -83,8 +84,13 @@ public class SimpleChannel
public <MSG> int encodeMessage(MSG message, final PacketBuffer target) { public <MSG> int encodeMessage(MSG message, final PacketBuffer target) {
return this.indexedCodec.build(message, target); return this.indexedCodec.build(message, target);
} }
public <MSG> IndexedMessageCodec.MessageHandler<MSG> registerMessage(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer) { public <MSG> IndexedMessageCodec.MessageHandler<MSG> registerMessage(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer) {
return this.indexedCodec.addCodecIndex(index, messageType, encoder, decoder, messageConsumer); return registerMessage(index, messageType, encoder, decoder, messageConsumer, Optional.empty());
}
public <MSG> IndexedMessageCodec.MessageHandler<MSG> registerMessage(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer, final Optional<NetworkDirection> networkDirection) {
return this.indexedCodec.addCodecIndex(index, messageType, encoder, decoder, messageConsumer, networkDirection);
} }
private <MSG> Pair<PacketBuffer,Integer> toBuffer(MSG msg) { private <MSG> Pair<PacketBuffer,Integer> toBuffer(MSG msg) {
@ -137,7 +143,21 @@ public class SimpleChannel
* @return a MessageBuilder * @return a MessageBuilder
*/ */
public <M> MessageBuilder<M> messageBuilder(final Class<M> type, int id) { public <M> MessageBuilder<M> messageBuilder(final Class<M> type, int id) {
return MessageBuilder.forType(this, type, id); return MessageBuilder.forType(this, type, id, null);
}
/**
* Build a new MessageBuilder. The type should implement {@link java.util.function.IntSupplier} if it is a login
* packet.
* @param type Type of message
* @param id id in the indexed codec
* @param direction a network direction which will be asserted before any processing of this message occurs. Use to
* enforce strict sided handling to prevent spoofing.
* @param <M> Type of type
* @return a MessageBuilder
*/
public <M> MessageBuilder<M> messageBuilder(final Class<M> type, int id, NetworkDirection direction) {
return MessageBuilder.forType(this, type, id, direction);
} }
public static class MessageBuilder<MSG> { public static class MessageBuilder<MSG> {
@ -150,12 +170,14 @@ public class SimpleChannel
private Function<MSG, Integer> loginIndexGetter; private Function<MSG, Integer> loginIndexGetter;
private BiConsumer<MSG, Integer> loginIndexSetter; private BiConsumer<MSG, Integer> loginIndexSetter;
private Function<Boolean, List<Pair<String, MSG>>> loginPacketGenerators; private Function<Boolean, List<Pair<String, MSG>>> loginPacketGenerators;
private Optional<NetworkDirection> networkDirection;
private static <MSG> MessageBuilder<MSG> forType(final SimpleChannel channel, final Class<MSG> type, int id) { private static <MSG> MessageBuilder<MSG> forType(final SimpleChannel channel, final Class<MSG> type, int id, NetworkDirection networkDirection) {
MessageBuilder<MSG> builder = new MessageBuilder<>(); MessageBuilder<MSG> builder = new MessageBuilder<>();
builder.channel = channel; builder.channel = channel;
builder.id = id; builder.id = id;
builder.type = type; builder.type = type;
builder.networkDirection = Optional.ofNullable(networkDirection);
return builder; return builder;
} }
@ -215,7 +237,7 @@ public class SimpleChannel
} }
public void add() { public void add() {
final IndexedMessageCodec.MessageHandler<MSG> message = this.channel.registerMessage(this.id, this.type, this.encoder, this.decoder, this.consumer); final IndexedMessageCodec.MessageHandler<MSG> message = this.channel.registerMessage(this.id, this.type, this.encoder, this.decoder, this.consumer, this.networkDirection);
if (this.loginIndexSetter != null) { if (this.loginIndexSetter != null) {
message.setLoginIndexSetter(this.loginIndexSetter); message.setLoginIndexSetter(this.loginIndexSetter);
} }