Fix up alarming security crisis with network handling that allowed for wrong code execution on the server, resulting in CRASHED SERVERS. AWFUL stuff.
Also fixed a silly log message from the login handler. This change introduces a mechanism to direct certain packets to only process on one side or another. Invalid sidedness will result in the connection being terminated. Signed-off-by: cpw <cpw+github@weeksfamily.ca>
This commit is contained in:
parent
30d4520c6b
commit
aca45340bf
|
@ -47,6 +47,8 @@ public class FMLLoginWrapper {
|
||||||
}
|
}
|
||||||
|
|
||||||
private <T extends NetworkEvent> void wrapperReceived(final T packet) {
|
private <T extends NetworkEvent> void wrapperReceived(final T packet) {
|
||||||
|
// we don't care about channel registration change events on this channel
|
||||||
|
if (packet instanceof NetworkEvent.ChannelRegistrationChangeEvent) return;
|
||||||
final NetworkEvent.Context wrappedContext = packet.getSource().get();
|
final NetworkEvent.Context wrappedContext = packet.getSource().get();
|
||||||
final PacketBuffer payload = packet.getPayload();
|
final PacketBuffer payload = packet.getPayload();
|
||||||
ResourceLocation targetNetworkReceiver = FMLNetworkConstants.FML_HANDSHAKE_RESOURCE;
|
ResourceLocation targetNetworkReceiver = FMLNetworkConstants.FML_HANDSHAKE_RESOURCE;
|
||||||
|
|
|
@ -20,6 +20,7 @@
|
||||||
package net.minecraftforge.fml.network;
|
package net.minecraftforge.fml.network;
|
||||||
|
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
import java.util.Optional;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.function.Consumer;
|
import java.util.function.Consumer;
|
||||||
import java.util.function.Supplier;
|
import java.util.function.Supplier;
|
||||||
|
@ -28,7 +29,9 @@ import java.util.stream.Collectors;
|
||||||
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
|
import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
|
||||||
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
|
import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
|
||||||
import net.minecraft.util.registry.Registry;
|
import net.minecraft.util.registry.Registry;
|
||||||
|
import net.minecraft.util.text.StringTextComponent;
|
||||||
import net.minecraft.world.dimension.DimensionType;
|
import net.minecraft.world.dimension.DimensionType;
|
||||||
|
import net.minecraftforge.fml.common.thread.EffectiveSide;
|
||||||
import net.minecraftforge.registries.ClearableRegistry;
|
import net.minecraftforge.registries.ClearableRegistry;
|
||||||
import org.apache.logging.log4j.LogManager;
|
import org.apache.logging.log4j.LogManager;
|
||||||
import org.apache.logging.log4j.Logger;
|
import org.apache.logging.log4j.Logger;
|
||||||
|
@ -71,9 +74,24 @@ public class NetworkHooks
|
||||||
|
|
||||||
public static boolean onCustomPayload(final ICustomPacket<?> packet, final NetworkManager manager) {
|
public static boolean onCustomPayload(final ICustomPacket<?> packet, final NetworkManager manager) {
|
||||||
return NetworkRegistry.findTarget(packet.getName()).
|
return NetworkRegistry.findTarget(packet.getName()).
|
||||||
|
filter(ni->validateSideForProcessing(packet, ni, manager)).
|
||||||
map(ni->ni.dispatch(packet.getDirection(), packet, manager)).orElse(Boolean.FALSE);
|
map(ni->ni.dispatch(packet.getDirection(), packet, manager)).orElse(Boolean.FALSE);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static boolean validateSideForProcessing(final ICustomPacket<?> packet, final NetworkInstance ni, final NetworkManager manager) {
|
||||||
|
if (packet.getDirection().getReceptionSide() != EffectiveSide.get()) {
|
||||||
|
manager.closeChannel(new StringTextComponent("Illegal packet received, terminating connection"));
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void validatePacketDirection(final NetworkDirection packetDirection, final Optional<NetworkDirection> expectedDirection, final NetworkManager connection) {
|
||||||
|
if (packetDirection != expectedDirection.orElse(packetDirection)) {
|
||||||
|
connection.closeChannel(new StringTextComponent("Illegal packet received, terminating connection"));
|
||||||
|
throw new IllegalStateException("Invalid packet received, aborting connection");
|
||||||
|
}
|
||||||
|
}
|
||||||
public static void registerServerLoginChannel(NetworkManager manager, CHandshakePacket packet)
|
public static void registerServerLoginChannel(NetworkManager manager, CHandshakePacket packet)
|
||||||
{
|
{
|
||||||
manager.channel().attr(FMLNetworkConstants.FML_NETVERSION).set(packet.getFMLVersion());
|
manager.channel().attr(FMLNetworkConstants.FML_NETVERSION).set(packet.getFMLVersion());
|
||||||
|
|
|
@ -37,14 +37,14 @@ class NetworkInitialization {
|
||||||
networkProtocolVersion(() -> FMLNetworkConstants.NETVERSION).
|
networkProtocolVersion(() -> FMLNetworkConstants.NETVERSION).
|
||||||
simpleChannel();
|
simpleChannel();
|
||||||
|
|
||||||
handshakeChannel.messageBuilder(FMLHandshakeMessages.C2SAcknowledge.class, 99).
|
handshakeChannel.messageBuilder(FMLHandshakeMessages.C2SAcknowledge.class, 99, NetworkDirection.LOGIN_TO_SERVER).
|
||||||
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
|
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
|
||||||
decoder(FMLHandshakeMessages.C2SAcknowledge::decode).
|
decoder(FMLHandshakeMessages.C2SAcknowledge::decode).
|
||||||
encoder(FMLHandshakeMessages.C2SAcknowledge::encode).
|
encoder(FMLHandshakeMessages.C2SAcknowledge::encode).
|
||||||
consumer(FMLHandshakeHandler.indexFirst(FMLHandshakeHandler::handleClientAck)).
|
consumer(FMLHandshakeHandler.indexFirst(FMLHandshakeHandler::handleClientAck)).
|
||||||
add();
|
add();
|
||||||
|
|
||||||
handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CModList.class, 1).
|
handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CModList.class, 1, NetworkDirection.LOGIN_TO_CLIENT).
|
||||||
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
|
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
|
||||||
decoder(FMLHandshakeMessages.S2CModList::decode).
|
decoder(FMLHandshakeMessages.S2CModList::decode).
|
||||||
encoder(FMLHandshakeMessages.S2CModList::encode).
|
encoder(FMLHandshakeMessages.S2CModList::encode).
|
||||||
|
@ -52,14 +52,14 @@ class NetworkInitialization {
|
||||||
consumer(FMLHandshakeHandler.biConsumerFor(FMLHandshakeHandler::handleServerModListOnClient)).
|
consumer(FMLHandshakeHandler.biConsumerFor(FMLHandshakeHandler::handleServerModListOnClient)).
|
||||||
add();
|
add();
|
||||||
|
|
||||||
handshakeChannel.messageBuilder(FMLHandshakeMessages.C2SModListReply.class, 2).
|
handshakeChannel.messageBuilder(FMLHandshakeMessages.C2SModListReply.class, 2, NetworkDirection.LOGIN_TO_SERVER).
|
||||||
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
|
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
|
||||||
decoder(FMLHandshakeMessages.C2SModListReply::decode).
|
decoder(FMLHandshakeMessages.C2SModListReply::decode).
|
||||||
encoder(FMLHandshakeMessages.C2SModListReply::encode).
|
encoder(FMLHandshakeMessages.C2SModListReply::encode).
|
||||||
consumer(FMLHandshakeHandler.indexFirst(FMLHandshakeHandler::handleClientModListOnServer)).
|
consumer(FMLHandshakeHandler.indexFirst(FMLHandshakeHandler::handleClientModListOnServer)).
|
||||||
add();
|
add();
|
||||||
|
|
||||||
handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CRegistry.class, 3).
|
handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CRegistry.class, 3, NetworkDirection.LOGIN_TO_CLIENT).
|
||||||
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
|
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
|
||||||
decoder(FMLHandshakeMessages.S2CRegistry::decode).
|
decoder(FMLHandshakeMessages.S2CRegistry::decode).
|
||||||
encoder(FMLHandshakeMessages.S2CRegistry::encode).
|
encoder(FMLHandshakeMessages.S2CRegistry::encode).
|
||||||
|
@ -67,7 +67,7 @@ class NetworkInitialization {
|
||||||
consumer(FMLHandshakeHandler.biConsumerFor(FMLHandshakeHandler::handleRegistryMessage)).
|
consumer(FMLHandshakeHandler.biConsumerFor(FMLHandshakeHandler::handleRegistryMessage)).
|
||||||
add();
|
add();
|
||||||
|
|
||||||
handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CConfigData.class, 4).
|
handshakeChannel.messageBuilder(FMLHandshakeMessages.S2CConfigData.class, 4, NetworkDirection.LOGIN_TO_CLIENT).
|
||||||
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
|
loginIndex(FMLHandshakeMessages.LoginIndexedMessage::getLoginIndex, FMLHandshakeMessages.LoginIndexedMessage::setLoginIndex).
|
||||||
decoder(FMLHandshakeMessages.S2CConfigData::decode).
|
decoder(FMLHandshakeMessages.S2CConfigData::decode).
|
||||||
encoder(FMLHandshakeMessages.S2CConfigData::encode).
|
encoder(FMLHandshakeMessages.S2CConfigData::encode).
|
||||||
|
|
|
@ -22,7 +22,9 @@ package net.minecraftforge.fml.network.simple;
|
||||||
import it.unimi.dsi.fastutil.objects.Object2ObjectArrayMap;
|
import it.unimi.dsi.fastutil.objects.Object2ObjectArrayMap;
|
||||||
import it.unimi.dsi.fastutil.shorts.Short2ObjectArrayMap;
|
import it.unimi.dsi.fastutil.shorts.Short2ObjectArrayMap;
|
||||||
import net.minecraft.network.PacketBuffer;
|
import net.minecraft.network.PacketBuffer;
|
||||||
|
import net.minecraftforge.fml.network.NetworkDirection;
|
||||||
import net.minecraftforge.fml.network.NetworkEvent;
|
import net.minecraftforge.fml.network.NetworkEvent;
|
||||||
|
import net.minecraftforge.fml.network.NetworkHooks;
|
||||||
import net.minecraftforge.fml.network.NetworkInstance;
|
import net.minecraftforge.fml.network.NetworkInstance;
|
||||||
import org.apache.logging.log4j.LogManager;
|
import org.apache.logging.log4j.LogManager;
|
||||||
import org.apache.logging.log4j.Logger;
|
import org.apache.logging.log4j.Logger;
|
||||||
|
@ -68,16 +70,18 @@ public class IndexedMessageCodec
|
||||||
private final int index;
|
private final int index;
|
||||||
private final BiConsumer<MSG,Supplier<NetworkEvent.Context>> messageConsumer;
|
private final BiConsumer<MSG,Supplier<NetworkEvent.Context>> messageConsumer;
|
||||||
private final Class<MSG> messageType;
|
private final Class<MSG> messageType;
|
||||||
|
private final Optional<NetworkDirection> networkDirection;
|
||||||
private Optional<BiConsumer<MSG, Integer>> loginIndexSetter;
|
private Optional<BiConsumer<MSG, Integer>> loginIndexSetter;
|
||||||
private Optional<Function<MSG, Integer>> loginIndexGetter;
|
private Optional<Function<MSG, Integer>> loginIndexGetter;
|
||||||
|
|
||||||
public MessageHandler(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer)
|
public MessageHandler(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer, final Optional<NetworkDirection> networkDirection)
|
||||||
{
|
{
|
||||||
this.index = index;
|
this.index = index;
|
||||||
this.messageType = messageType;
|
this.messageType = messageType;
|
||||||
this.encoder = Optional.ofNullable(encoder);
|
this.encoder = Optional.ofNullable(encoder);
|
||||||
this.decoder = Optional.ofNullable(decoder);
|
this.decoder = Optional.ofNullable(decoder);
|
||||||
this.messageConsumer = messageConsumer;
|
this.messageConsumer = messageConsumer;
|
||||||
|
this.networkDirection = networkDirection;
|
||||||
this.loginIndexGetter = Optional.empty();
|
this.loginIndexGetter = Optional.empty();
|
||||||
this.loginIndexSetter = Optional.empty();
|
this.loginIndexSetter = Optional.empty();
|
||||||
indicies.put((short)(index & 0xff), this);
|
indicies.put((short)(index & 0xff), this);
|
||||||
|
@ -154,10 +158,11 @@ public class IndexedMessageCodec
|
||||||
LOGGER.error(SIMPLENET, "Received invalid discriminator byte {} on channel {}", discriminator, Optional.ofNullable(networkInstance).map(NetworkInstance::getChannelName).map(Objects::toString).orElse("MISSING CHANNEL"));
|
LOGGER.error(SIMPLENET, "Received invalid discriminator byte {} on channel {}", discriminator, Optional.ofNullable(networkInstance).map(NetworkInstance::getChannelName).map(Objects::toString).orElse("MISSING CHANNEL"));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
NetworkHooks.validatePacketDirection(context.get().getDirection(), messageHandler.networkDirection, context.get().getNetworkManager());
|
||||||
tryDecode(payload, context, payloadIndex, messageHandler);
|
tryDecode(payload, context, payloadIndex, messageHandler);
|
||||||
}
|
}
|
||||||
|
|
||||||
<MSG> MessageHandler<MSG> addCodecIndex(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer) {
|
<MSG> MessageHandler<MSG> addCodecIndex(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer, final Optional<NetworkDirection> networkDirection) {
|
||||||
return new MessageHandler<>(index, messageType, encoder, decoder, messageConsumer);
|
return new MessageHandler<>(index, messageType, encoder, decoder, messageConsumer, networkDirection);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,7 @@ import net.minecraft.client.Minecraft;
|
||||||
import net.minecraft.network.NetworkManager;
|
import net.minecraft.network.NetworkManager;
|
||||||
import net.minecraft.network.IPacket;
|
import net.minecraft.network.IPacket;
|
||||||
import net.minecraft.network.PacketBuffer;
|
import net.minecraft.network.PacketBuffer;
|
||||||
|
import net.minecraft.util.text.StringTextComponent;
|
||||||
import net.minecraftforge.fml.network.*;
|
import net.minecraftforge.fml.network.*;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
|
|
||||||
|
@ -83,8 +84,13 @@ public class SimpleChannel
|
||||||
public <MSG> int encodeMessage(MSG message, final PacketBuffer target) {
|
public <MSG> int encodeMessage(MSG message, final PacketBuffer target) {
|
||||||
return this.indexedCodec.build(message, target);
|
return this.indexedCodec.build(message, target);
|
||||||
}
|
}
|
||||||
|
|
||||||
public <MSG> IndexedMessageCodec.MessageHandler<MSG> registerMessage(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer) {
|
public <MSG> IndexedMessageCodec.MessageHandler<MSG> registerMessage(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer) {
|
||||||
return this.indexedCodec.addCodecIndex(index, messageType, encoder, decoder, messageConsumer);
|
return registerMessage(index, messageType, encoder, decoder, messageConsumer, Optional.empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
public <MSG> IndexedMessageCodec.MessageHandler<MSG> registerMessage(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer, final Optional<NetworkDirection> networkDirection) {
|
||||||
|
return this.indexedCodec.addCodecIndex(index, messageType, encoder, decoder, messageConsumer, networkDirection);
|
||||||
}
|
}
|
||||||
|
|
||||||
private <MSG> Pair<PacketBuffer,Integer> toBuffer(MSG msg) {
|
private <MSG> Pair<PacketBuffer,Integer> toBuffer(MSG msg) {
|
||||||
|
@ -137,7 +143,21 @@ public class SimpleChannel
|
||||||
* @return a MessageBuilder
|
* @return a MessageBuilder
|
||||||
*/
|
*/
|
||||||
public <M> MessageBuilder<M> messageBuilder(final Class<M> type, int id) {
|
public <M> MessageBuilder<M> messageBuilder(final Class<M> type, int id) {
|
||||||
return MessageBuilder.forType(this, type, id);
|
return MessageBuilder.forType(this, type, id, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build a new MessageBuilder. The type should implement {@link java.util.function.IntSupplier} if it is a login
|
||||||
|
* packet.
|
||||||
|
* @param type Type of message
|
||||||
|
* @param id id in the indexed codec
|
||||||
|
* @param direction a network direction which will be asserted before any processing of this message occurs. Use to
|
||||||
|
* enforce strict sided handling to prevent spoofing.
|
||||||
|
* @param <M> Type of type
|
||||||
|
* @return a MessageBuilder
|
||||||
|
*/
|
||||||
|
public <M> MessageBuilder<M> messageBuilder(final Class<M> type, int id, NetworkDirection direction) {
|
||||||
|
return MessageBuilder.forType(this, type, id, direction);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static class MessageBuilder<MSG> {
|
public static class MessageBuilder<MSG> {
|
||||||
|
@ -150,12 +170,14 @@ public class SimpleChannel
|
||||||
private Function<MSG, Integer> loginIndexGetter;
|
private Function<MSG, Integer> loginIndexGetter;
|
||||||
private BiConsumer<MSG, Integer> loginIndexSetter;
|
private BiConsumer<MSG, Integer> loginIndexSetter;
|
||||||
private Function<Boolean, List<Pair<String, MSG>>> loginPacketGenerators;
|
private Function<Boolean, List<Pair<String, MSG>>> loginPacketGenerators;
|
||||||
|
private Optional<NetworkDirection> networkDirection;
|
||||||
|
|
||||||
private static <MSG> MessageBuilder<MSG> forType(final SimpleChannel channel, final Class<MSG> type, int id) {
|
private static <MSG> MessageBuilder<MSG> forType(final SimpleChannel channel, final Class<MSG> type, int id, NetworkDirection networkDirection) {
|
||||||
MessageBuilder<MSG> builder = new MessageBuilder<>();
|
MessageBuilder<MSG> builder = new MessageBuilder<>();
|
||||||
builder.channel = channel;
|
builder.channel = channel;
|
||||||
builder.id = id;
|
builder.id = id;
|
||||||
builder.type = type;
|
builder.type = type;
|
||||||
|
builder.networkDirection = Optional.ofNullable(networkDirection);
|
||||||
return builder;
|
return builder;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -215,7 +237,7 @@ public class SimpleChannel
|
||||||
}
|
}
|
||||||
|
|
||||||
public void add() {
|
public void add() {
|
||||||
final IndexedMessageCodec.MessageHandler<MSG> message = this.channel.registerMessage(this.id, this.type, this.encoder, this.decoder, this.consumer);
|
final IndexedMessageCodec.MessageHandler<MSG> message = this.channel.registerMessage(this.id, this.type, this.encoder, this.decoder, this.consumer, this.networkDirection);
|
||||||
if (this.loginIndexSetter != null) {
|
if (this.loginIndexSetter != null) {
|
||||||
message.setLoginIndexSetter(this.loginIndexSetter);
|
message.setLoginIndexSetter(this.loginIndexSetter);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue