Fix up alarming security crisis with network handling that allowed for wrong code execution on the server, resulting in CRASHED SERVERS. AWFUL stuff.

Also fixed a silly log message from the login handler.

This change introduces a mechanism to direct certain packets to only process on one side or another. Invalid sidedness will result in the connection being terminated.

Signed-off-by: cpw <cpw+github@weeksfamily.ca>
This commit is contained in:
cpw 2020-02-23 22:15:34 -05:00
parent 30d4520c6b
commit aca45340bf
No known key found for this signature in database
GPG Key ID: 8EB3DF749553B1B7
5 changed files with 59 additions and 12 deletions

View File

@ -47,6 +47,8 @@ public class FMLLoginWrapper {
}
private <T extends NetworkEvent> void wrapperReceived(final T packet) {
// we don't care about channel registration change events on this channel
if (packet instanceof NetworkEvent.ChannelRegistrationChangeEvent) return;
final NetworkEvent.Context wrappedContext = packet.getSource().get();
final PacketBuffer payload = packet.getPayload();
ResourceLocation targetNetworkReceiver = FMLNetworkConstants.FML_HANDSHAKE_RESOURCE;

View File

@ -20,6 +20,7 @@
package net.minecraftforge.fml.network;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Supplier;
@ -28,7 +29,9 @@ import java.util.stream.Collectors;
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
import net.minecraft.util.registry.Registry;
import net.minecraft.util.text.StringTextComponent;
import net.minecraft.world.dimension.DimensionType;
import net.minecraftforge.fml.common.thread.EffectiveSide;
import net.minecraftforge.registries.ClearableRegistry;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
@ -71,9 +74,24 @@ public class NetworkHooks
public static boolean onCustomPayload(final ICustomPacket<?> packet, final NetworkManager manager) {
return NetworkRegistry.findTarget(packet.getName()).
filter(ni->validateSideForProcessing(packet, ni, manager)).
map(ni->ni.dispatch(packet.getDirection(), packet, manager)).orElse(Boolean.FALSE);
}
private static boolean validateSideForProcessing(final ICustomPacket<?> packet, final NetworkInstance ni, final NetworkManager manager) {
if (packet.getDirection().getReceptionSide() != EffectiveSide.get()) {
manager.closeChannel(new StringTextComponent("Illegal packet received, terminating connection"));
return false;
}
return true;
}
public static void validatePacketDirection(final NetworkDirection packetDirection, final Optional<NetworkDirection> expectedDirection, final NetworkManager connection) {
if (packetDirection != expectedDirection.orElse(packetDirection)) {
connection.closeChannel(new StringTextComponent("Illegal packet received, terminating connection"));
throw new IllegalStateException("Invalid packet received, aborting connection");
}
}
public static void registerServerLoginChannel(NetworkManager manager, CHandshakePacket packet)
{
manager.channel().attr(FMLNetworkConstants.FML_NETVERSION).set(packet.getFMLVersion());

View File

@ -37,14 +37,14 @@ class NetworkInitialization {
networkProtocolVersion(() -> FMLNetworkConstants.NETVERSION).
simpleChannel();
handshakeChannel.messageBuilder(FMLHandshakeMessages.C2SAcknowledge.class, 99).
handshakeChannel.messageBuilder(FMLHandshakeMessages.C2SAcknowledge.class, 99, NetworkDirection.LOGIN_TO_SERVER).
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
decoder(FMLHandshakeMessages.C2SAcknowledge::decode).
encoder(FMLHandshakeMessages.C2SAcknowledge::encode).
consumer(FMLHandshakeHandler.indexFirst(FMLHandshakeHandler::handleClientAck)).
add();
handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CModList.class, 1).
handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CModList.class, 1, NetworkDirection.LOGIN_TO_CLIENT).
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
decoder(FMLHandshakeMessages.S2CModList::decode).
encoder(FMLHandshakeMessages.S2CModList::encode).
@ -52,14 +52,14 @@ class NetworkInitialization {
consumer(FMLHandshakeHandler.biConsumerFor(FMLHandshakeHandler::handleServerModListOnClient)).
add();
handshakeChannel.messageBuilder(FMLHandshakeMessages.C2SModListReply.class, 2).
handshakeChannel.messageBuilder(FMLHandshakeMessages.C2SModListReply.class, 2, NetworkDirection.LOGIN_TO_SERVER).
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
decoder(FMLHandshakeMessages.C2SModListReply::decode).
encoder(FMLHandshakeMessages.C2SModListReply::encode).
consumer(FMLHandshakeHandler.indexFirst(FMLHandshakeHandler::handleClientModListOnServer)).
add();
handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CRegistry.class, 3).
handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CRegistry.class, 3, NetworkDirection.LOGIN_TO_CLIENT).
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
decoder(FMLHandshakeMessages.S2CRegistry::decode).
encoder(FMLHandshakeMessages.S2CRegistry::encode).
@ -67,7 +67,7 @@ class NetworkInitialization {
consumer(FMLHandshakeHandler.biConsumerFor(FMLHandshakeHandler::handleRegistryMessage)).
add();
handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CConfigData.class, 4).
handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CConfigData.class, 4, NetworkDirection.LOGIN_TO_CLIENT).
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
decoder(FMLHandshakeMessages.S2CConfigData::decode).
encoder(FMLHandshakeMessages.S2CConfigData::encode).

View File

@ -22,7 +22,9 @@ package net.minecraftforge.fml.network.simple;
import it.unimi.dsi.fastutil.objects.Object2ObjectArrayMap;
import it.unimi.dsi.fastutil.shorts.Short2ObjectArrayMap;
import net.minecraft.network.PacketBuffer;
import net.minecraftforge.fml.network.NetworkDirection;
import net.minecraftforge.fml.network.NetworkEvent;
import net.minecraftforge.fml.network.NetworkHooks;
import net.minecraftforge.fml.network.NetworkInstance;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
@ -68,16 +70,18 @@ public class IndexedMessageCodec
private final int index;
private final BiConsumer<MSG,Supplier<NetworkEvent.Context>> messageConsumer;
private final Class<MSG> messageType;
private final Optional<NetworkDirection> networkDirection;
private Optional<BiConsumer<MSG, Integer>> loginIndexSetter;
private Optional<Function<MSG, Integer>> loginIndexGetter;
public MessageHandler(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer)
public MessageHandler(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer, final Optional<NetworkDirection> networkDirection)
{
this.index = index;
this.messageType = messageType;
this.encoder = Optional.ofNullable(encoder);
this.decoder = Optional.ofNullable(decoder);
this.messageConsumer = messageConsumer;
this.networkDirection = networkDirection;
this.loginIndexGetter = Optional.empty();
this.loginIndexSetter = Optional.empty();
indicies.put((short)(index & 0xff), this);
@ -154,10 +158,11 @@ public class IndexedMessageCodec
LOGGER.error(SIMPLENET, "Received invalid discriminator byte {} on channel {}", discriminator, Optional.ofNullable(networkInstance).map(NetworkInstance::getChannelName).map(Objects::toString).orElse("MISSING CHANNEL"));
return;
}
NetworkHooks.validatePacketDirection(context.get().getDirection(), messageHandler.networkDirection, context.get().getNetworkManager());
tryDecode(payload, context, payloadIndex, messageHandler);
}
<MSG> MessageHandler<MSG> addCodecIndex(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer) {
return new MessageHandler<>(index, messageType, encoder, decoder, messageConsumer);
<MSG> MessageHandler<MSG> addCodecIndex(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer, final Optional<NetworkDirection> networkDirection) {
return new MessageHandler<>(index, messageType, encoder, decoder, messageConsumer, networkDirection);
}
}

View File

@ -24,6 +24,7 @@ import net.minecraft.client.Minecraft;
import net.minecraft.network.NetworkManager;
import net.minecraft.network.IPacket;
import net.minecraft.network.PacketBuffer;
import net.minecraft.util.text.StringTextComponent;
import net.minecraftforge.fml.network.*;
import org.apache.commons.lang3.tuple.Pair;
@ -83,8 +84,13 @@ public class SimpleChannel
public <MSG> int encodeMessage(MSG message, final PacketBuffer target) {
return this.indexedCodec.build(message, target);
}
public <MSG> IndexedMessageCodec.MessageHandler<MSG> registerMessage(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer) {
return this.indexedCodec.addCodecIndex(index, messageType, encoder, decoder, messageConsumer);
return registerMessage(index, messageType, encoder, decoder, messageConsumer, Optional.empty());
}
public <MSG> IndexedMessageCodec.MessageHandler<MSG> registerMessage(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer, final Optional<NetworkDirection> networkDirection) {
return this.indexedCodec.addCodecIndex(index, messageType, encoder, decoder, messageConsumer, networkDirection);
}
private <MSG> Pair<PacketBuffer,Integer> toBuffer(MSG msg) {
@ -137,7 +143,21 @@ public class SimpleChannel
* @return a MessageBuilder
*/
public <M> MessageBuilder<M> messageBuilder(final Class<M> type, int id) {
return MessageBuilder.forType(this, type, id);
return MessageBuilder.forType(this, type, id, null);
}
/**
* Build a new MessageBuilder. The type should implement {@link java.util.function.IntSupplier} if it is a login
* packet.
* @param type Type of message
* @param id id in the indexed codec
* @param direction a network direction which will be asserted before any processing of this message occurs. Use to
* enforce strict sided handling to prevent spoofing.
* @param <M> Type of type
* @return a MessageBuilder
*/
public <M> MessageBuilder<M> messageBuilder(final Class<M> type, int id, NetworkDirection direction) {
return MessageBuilder.forType(this, type, id, direction);
}
public static class MessageBuilder<MSG> {
@ -150,12 +170,14 @@ public class SimpleChannel
private Function<MSG, Integer> loginIndexGetter;
private BiConsumer<MSG, Integer> loginIndexSetter;
private Function<Boolean, List<Pair<String, MSG>>> loginPacketGenerators;
private Optional<NetworkDirection> networkDirection;
private static <MSG> MessageBuilder<MSG> forType(final SimpleChannel channel, final Class<MSG> type, int id) {
private static <MSG> MessageBuilder<MSG> forType(final SimpleChannel channel, final Class<MSG> type, int id, NetworkDirection networkDirection) {
MessageBuilder<MSG> builder = new MessageBuilder<>();
builder.channel = channel;
builder.id = id;
builder.type = type;
builder.networkDirection = Optional.ofNullable(networkDirection);
return builder;
}
@ -215,7 +237,7 @@ public class SimpleChannel
}
public void add() {
final IndexedMessageCodec.MessageHandler<MSG> message = this.channel.registerMessage(this.id, this.type, this.encoder, this.decoder, this.consumer);
final IndexedMessageCodec.MessageHandler<MSG> message = this.channel.registerMessage(this.id, this.type, this.encoder, this.decoder, this.consumer, this.networkDirection);
if (this.loginIndexSetter != null) {
message.setLoginIndexSetter(this.loginIndexSetter);
}