Fix up alarming security crisis with network handling that allowed for wrong code execution on the server, resulting in CRASHED SERVERS. AWFUL stuff.
Also fixed a silly log message from the login handler. This change introduces a mechanism to direct certain packets to only process on one side or another. Invalid sidedness will result in the connection being terminated. Signed-off-by: cpw <cpw+github@weeksfamily.ca>
This commit is contained in:
parent
30d4520c6b
commit
aca45340bf
5 changed files with 59 additions and 12 deletions
|
@ -47,6 +47,8 @@ public class FMLLoginWrapper {
|
|||
}
|
||||
|
||||
private <T extends NetworkEvent> void wrapperReceived(final T packet) {
|
||||
// we don't care about channel registration change events on this channel
|
||||
if (packet instanceof NetworkEvent.ChannelRegistrationChangeEvent) return;
|
||||
final NetworkEvent.Context wrappedContext = packet.getSource().get();
|
||||
final PacketBuffer payload = packet.getPayload();
|
||||
ResourceLocation targetNetworkReceiver = FMLNetworkConstants.FML_HANDSHAKE_RESOURCE;
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
package net.minecraftforge.fml.network;
|
||||
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.function.Supplier;
|
||||
|
@ -28,7 +29,9 @@ import java.util.stream.Collectors;
|
|||
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
|
||||
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
|
||||
import net.minecraft.util.registry.Registry;
|
||||
import net.minecraft.util.text.StringTextComponent;
|
||||
import net.minecraft.world.dimension.DimensionType;
|
||||
import net.minecraftforge.fml.common.thread.EffectiveSide;
|
||||
import net.minecraftforge.registries.ClearableRegistry;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
@ -71,9 +74,24 @@ public class NetworkHooks
|
|||
|
||||
public static boolean onCustomPayload(final ICustomPacket<?> packet, final NetworkManager manager) {
|
||||
return NetworkRegistry.findTarget(packet.getName()).
|
||||
filter(ni->validateSideForProcessing(packet, ni, manager)).
|
||||
map(ni->ni.dispatch(packet.getDirection(), packet, manager)).orElse(Boolean.FALSE);
|
||||
}
|
||||
|
||||
private static boolean validateSideForProcessing(final ICustomPacket<?> packet, final NetworkInstance ni, final NetworkManager manager) {
|
||||
if (packet.getDirection().getReceptionSide() != EffectiveSide.get()) {
|
||||
manager.closeChannel(new StringTextComponent("Illegal packet received, terminating connection"));
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
public static void validatePacketDirection(final NetworkDirection packetDirection, final Optional<NetworkDirection> expectedDirection, final NetworkManager connection) {
|
||||
if (packetDirection != expectedDirection.orElse(packetDirection)) {
|
||||
connection.closeChannel(new StringTextComponent("Illegal packet received, terminating connection"));
|
||||
throw new IllegalStateException("Invalid packet received, aborting connection");
|
||||
}
|
||||
}
|
||||
public static void registerServerLoginChannel(NetworkManager manager, CHandshakePacket packet)
|
||||
{
|
||||
manager.channel().attr(FMLNetworkConstants.FML_NETVERSION).set(packet.getFMLVersion());
|
||||
|
|
|
@ -37,14 +37,14 @@ class NetworkInitialization {
|
|||
networkProtocolVersion(() -> FMLNetworkConstants.NETVERSION).
|
||||
simpleChannel();
|
||||
|
||||
handshakeChannel.messageBuilder(FMLHandshakeMessages.C2SAcknowledge.class, 99).
|
||||
handshakeChannel.messageBuilder(FMLHandshakeMessages.C2SAcknowledge.class, 99, NetworkDirection.LOGIN_TO_SERVER).
|
||||
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
|
||||
decoder(FMLHandshakeMessages.C2SAcknowledge::decode).
|
||||
encoder(FMLHandshakeMessages.C2SAcknowledge::encode).
|
||||
consumer(FMLHandshakeHandler.indexFirst(FMLHandshakeHandler::handleClientAck)).
|
||||
add();
|
||||
|
||||
handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CModList.class, 1).
|
||||
handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CModList.class, 1, NetworkDirection.LOGIN_TO_CLIENT).
|
||||
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
|
||||
decoder(FMLHandshakeMessages.S2CModList::decode).
|
||||
encoder(FMLHandshakeMessages.S2CModList::encode).
|
||||
|
@ -52,14 +52,14 @@ class NetworkInitialization {
|
|||
consumer(FMLHandshakeHandler.biConsumerFor(FMLHandshakeHandler::handleServerModListOnClient)).
|
||||
add();
|
||||
|
||||
handshakeChannel.messageBuilder(FMLHandshakeMessages.C2SModListReply.class, 2).
|
||||
handshakeChannel.messageBuilder(FMLHandshakeMessages.C2SModListReply.class, 2, NetworkDirection.LOGIN_TO_SERVER).
|
||||
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
|
||||
decoder(FMLHandshakeMessages.C2SModListReply::decode).
|
||||
encoder(FMLHandshakeMessages.C2SModListReply::encode).
|
||||
consumer(FMLHandshakeHandler.indexFirst(FMLHandshakeHandler::handleClientModListOnServer)).
|
||||
add();
|
||||
|
||||
handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CRegistry.class, 3).
|
||||
handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CRegistry.class, 3, NetworkDirection.LOGIN_TO_CLIENT).
|
||||
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
|
||||
decoder(FMLHandshakeMessages.S2CRegistry::decode).
|
||||
encoder(FMLHandshakeMessages.S2CRegistry::encode).
|
||||
|
@ -67,7 +67,7 @@ class NetworkInitialization {
|
|||
consumer(FMLHandshakeHandler.biConsumerFor(FMLHandshakeHandler::handleRegistryMessage)).
|
||||
add();
|
||||
|
||||
handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CConfigData.class, 4).
|
||||
handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CConfigData.class, 4, NetworkDirection.LOGIN_TO_CLIENT).
|
||||
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
|
||||
decoder(FMLHandshakeMessages.S2CConfigData::decode).
|
||||
encoder(FMLHandshakeMessages.S2CConfigData::encode).
|
||||
|
|
|
@ -22,7 +22,9 @@ package net.minecraftforge.fml.network.simple;
|
|||
import it.unimi.dsi.fastutil.objects.Object2ObjectArrayMap;
|
||||
import it.unimi.dsi.fastutil.shorts.Short2ObjectArrayMap;
|
||||
import net.minecraft.network.PacketBuffer;
|
||||
import net.minecraftforge.fml.network.NetworkDirection;
|
||||
import net.minecraftforge.fml.network.NetworkEvent;
|
||||
import net.minecraftforge.fml.network.NetworkHooks;
|
||||
import net.minecraftforge.fml.network.NetworkInstance;
|
||||
import org.apache.logging.log4j.LogManager;
|
||||
import org.apache.logging.log4j.Logger;
|
||||
|
@ -68,16 +70,18 @@ public class IndexedMessageCodec
|
|||
private final int index;
|
||||
private final BiConsumer<MSG,Supplier<NetworkEvent.Context>> messageConsumer;
|
||||
private final Class<MSG> messageType;
|
||||
private final Optional<NetworkDirection> networkDirection;
|
||||
private Optional<BiConsumer<MSG, Integer>> loginIndexSetter;
|
||||
private Optional<Function<MSG, Integer>> loginIndexGetter;
|
||||
|
||||
public MessageHandler(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer)
|
||||
public MessageHandler(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer, final Optional<NetworkDirection> networkDirection)
|
||||
{
|
||||
this.index = index;
|
||||
this.messageType = messageType;
|
||||
this.encoder = Optional.ofNullable(encoder);
|
||||
this.decoder = Optional.ofNullable(decoder);
|
||||
this.messageConsumer = messageConsumer;
|
||||
this.networkDirection = networkDirection;
|
||||
this.loginIndexGetter = Optional.empty();
|
||||
this.loginIndexSetter = Optional.empty();
|
||||
indicies.put((short)(index & 0xff), this);
|
||||
|
@ -154,10 +158,11 @@ public class IndexedMessageCodec
|
|||
LOGGER.error(SIMPLENET, "Received invalid discriminator byte {} on channel {}", discriminator, Optional.ofNullable(networkInstance).map(NetworkInstance::getChannelName).map(Objects::toString).orElse("MISSING CHANNEL"));
|
||||
return;
|
||||
}
|
||||
NetworkHooks.validatePacketDirection(context.get().getDirection(), messageHandler.networkDirection, context.get().getNetworkManager());
|
||||
tryDecode(payload, context, payloadIndex, messageHandler);
|
||||
}
|
||||
|
||||
<MSG> MessageHandler<MSG> addCodecIndex(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer) {
|
||||
return new MessageHandler<>(index, messageType, encoder, decoder, messageConsumer);
|
||||
<MSG> MessageHandler<MSG> addCodecIndex(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer, final Optional<NetworkDirection> networkDirection) {
|
||||
return new MessageHandler<>(index, messageType, encoder, decoder, messageConsumer, networkDirection);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ import net.minecraft.client.Minecraft;
|
|||
import net.minecraft.network.NetworkManager;
|
||||
import net.minecraft.network.IPacket;
|
||||
import net.minecraft.network.PacketBuffer;
|
||||
import net.minecraft.util.text.StringTextComponent;
|
||||
import net.minecraftforge.fml.network.*;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
|
@ -83,8 +84,13 @@ public class SimpleChannel
|
|||
public <MSG> int encodeMessage(MSG message, final PacketBuffer target) {
|
||||
return this.indexedCodec.build(message, target);
|
||||
}
|
||||
|
||||
public <MSG> IndexedMessageCodec.MessageHandler<MSG> registerMessage(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer) {
|
||||
return this.indexedCodec.addCodecIndex(index, messageType, encoder, decoder, messageConsumer);
|
||||
return registerMessage(index, messageType, encoder, decoder, messageConsumer, Optional.empty());
|
||||
}
|
||||
|
||||
public <MSG> IndexedMessageCodec.MessageHandler<MSG> registerMessage(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer, final Optional<NetworkDirection> networkDirection) {
|
||||
return this.indexedCodec.addCodecIndex(index, messageType, encoder, decoder, messageConsumer, networkDirection);
|
||||
}
|
||||
|
||||
private <MSG> Pair<PacketBuffer,Integer> toBuffer(MSG msg) {
|
||||
|
@ -137,7 +143,21 @@ public class SimpleChannel
|
|||
* @return a MessageBuilder
|
||||
*/
|
||||
public <M> MessageBuilder<M> messageBuilder(final Class<M> type, int id) {
|
||||
return MessageBuilder.forType(this, type, id);
|
||||
return MessageBuilder.forType(this, type, id, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a new MessageBuilder. The type should implement {@link java.util.function.IntSupplier} if it is a login
|
||||
* packet.
|
||||
* @param type Type of message
|
||||
* @param id id in the indexed codec
|
||||
* @param direction a network direction which will be asserted before any processing of this message occurs. Use to
|
||||
* enforce strict sided handling to prevent spoofing.
|
||||
* @param <M> Type of type
|
||||
* @return a MessageBuilder
|
||||
*/
|
||||
public <M> MessageBuilder<M> messageBuilder(final Class<M> type, int id, NetworkDirection direction) {
|
||||
return MessageBuilder.forType(this, type, id, direction);
|
||||
}
|
||||
|
||||
public static class MessageBuilder<MSG> {
|
||||
|
@ -150,12 +170,14 @@ public class SimpleChannel
|
|||
private Function<MSG, Integer> loginIndexGetter;
|
||||
private BiConsumer<MSG, Integer> loginIndexSetter;
|
||||
private Function<Boolean, List<Pair<String, MSG>>> loginPacketGenerators;
|
||||
private Optional<NetworkDirection> networkDirection;
|
||||
|
||||
private static <MSG> MessageBuilder<MSG> forType(final SimpleChannel channel, final Class<MSG> type, int id) {
|
||||
private static <MSG> MessageBuilder<MSG> forType(final SimpleChannel channel, final Class<MSG> type, int id, NetworkDirection networkDirection) {
|
||||
MessageBuilder<MSG> builder = new MessageBuilder<>();
|
||||
builder.channel = channel;
|
||||
builder.id = id;
|
||||
builder.type = type;
|
||||
builder.networkDirection = Optional.ofNullable(networkDirection);
|
||||
return builder;
|
||||
}
|
||||
|
||||
|
@ -215,7 +237,7 @@ public class SimpleChannel
|
|||
}
|
||||
|
||||
public void add() {
|
||||
final IndexedMessageCodec.MessageHandler<MSG> message = this.channel.registerMessage(this.id, this.type, this.encoder, this.decoder, this.consumer);
|
||||
final IndexedMessageCodec.MessageHandler<MSG> message = this.channel.registerMessage(this.id, this.type, this.encoder, this.decoder, this.consumer, this.networkDirection);
|
||||
if (this.loginIndexSetter != null) {
|
||||
message.setLoginIndexSetter(this.loginIndexSetter);
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue