Net Handshake phase 1. Validates pairings on client and server.

This commit is contained in:
cpw 2018-09-15 15:44:48 -04:00
parent afead63381
commit 9f2c7c881c
10 changed files with 297 additions and 186 deletions

View file

@ -87,6 +87,7 @@ project(':forge') {
target: 'fmldevserver' target: 'fmldevserver'
] ]
} }
mcVersion = '1.13'
} }
applyPatches { applyPatches {
canonicalizeAccess true canonicalizeAccess true

View file

@ -60,7 +60,7 @@
float f7 = f2 * f4; float f7 = f2 * f4;
- double d3 = 5.0D; - double d3 = 5.0D;
- Vec3d vec3d1 = vec3d.add((double)f6 * 5.0D, (double)f5 * 5.0D, (double)f7 * 5.0D); - Vec3d vec3d1 = vec3d.add((double)f6 * 5.0D, (double)f5 * 5.0D, (double)f7 * 5.0D);
+ double d3 = playerIn.getEntityAttribute(EntityPlayer.REACH_DISTANCE).getAttributeValue(); + double d3 = 6;
+ Vec3d vec3d1 = vec3d.add((double)f6 * d3, (double)f5 * d3, (double)f7 * d3); + Vec3d vec3d1 = vec3d.add((double)f6 * d3, (double)f5 * d3, (double)f7 * d3);
return worldIn.func_200259_a(vec3d, vec3d1, useLiquids ? RayTraceFluidMode.SOURCE_ONLY : RayTraceFluidMode.NEVER, false, false); return worldIn.func_200259_a(vec3d, vec3d1, useLiquids ? RayTraceFluidMode.SOURCE_ONLY : RayTraceFluidMode.NEVER, false, false);
} }

View file

@ -0,0 +1,89 @@
package net.minecraftforge.fml.network;
import net.minecraft.nbt.INBTBase;
import net.minecraft.nbt.NBTTagCompound;
import net.minecraft.nbt.NBTTagList;
import net.minecraft.nbt.NBTTagString;
import net.minecraft.network.PacketBuffer;
import net.minecraftforge.fml.ModList;
import net.minecraftforge.fml.loading.moddiscovery.ModInfo;
import java.util.List;
import java.util.stream.Collectors;
class FMLHandshakeMessage
{
// Login index sequence number
private int index;
void setPacketIndexSequence(int i)
{
this.index = i;
}
int getPacketIndexSequence()
{
return index;
}
/**
* Server to client "list of mods". Always first handshake message.
*/
static class S2CModList extends FMLHandshakeMessage
{
private NBTTagList channels;
private List<String> modList;
S2CModList() {
this.modList = ModList.get().getMods().stream().map(ModInfo::getModId).collect(Collectors.toList());
}
S2CModList(NBTTagCompound nbtTagCompound)
{
this.modList = nbtTagCompound.getTagList("modlist", 8).stream().map(INBTBase::getString).collect(Collectors.toList());
this.channels = nbtTagCompound.getTagList("channels", 10);
}
static S2CModList decode(PacketBuffer packetBuffer)
{
final NBTTagCompound nbtTagCompound = packetBuffer.readCompoundTag();
return new S2CModList(nbtTagCompound);
}
void encode(PacketBuffer packetBuffer)
{
NBTTagCompound tag = new NBTTagCompound();
tag.setTag("modlist",modList.stream().map(NBTTagString::new).collect(Collectors.toCollection(NBTTagList::new)));
tag.setTag("channels", NetworkRegistry.buildChannelVersions());
packetBuffer.writeCompoundTag(tag);
}
String getModList() {
return String.join(",", modList);
}
NBTTagList getChannels() {
return this.channels;
}
}
static class C2SModListReply extends S2CModList
{
C2SModListReply() {
super();
}
C2SModListReply(final NBTTagCompound buffer) {
super(buffer);
}
static C2SModListReply decode(PacketBuffer buffer)
{
return new C2SModListReply(buffer.readCompoundTag());
}
public void encode(PacketBuffer buffer)
{
super.encode(buffer);
}
}
}

View file

@ -1,14 +1,9 @@
package net.minecraftforge.fml.network; package net.minecraftforge.fml.network;
import io.netty.util.AttributeKey; import io.netty.util.AttributeKey;
import net.minecraft.nbt.INBTBase;
import net.minecraft.nbt.NBTTagCompound;
import net.minecraft.nbt.NBTTagList;
import net.minecraft.nbt.NBTTagString;
import net.minecraft.network.NetworkManager; import net.minecraft.network.NetworkManager;
import net.minecraft.network.PacketBuffer; import net.minecraft.util.text.ITextComponent;
import net.minecraftforge.fml.ModList; import net.minecraft.util.text.TextComponentString;
import net.minecraftforge.fml.loading.moddiscovery.ModInfo;
import net.minecraftforge.fml.network.simple.SimpleChannel; import net.minecraftforge.fml.network.simple.SimpleChannel;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
@ -19,7 +14,6 @@ import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.stream.Collectors;
public class FMLNetworking public class FMLNetworking
{ {
@ -49,24 +43,24 @@ public class FMLNetworking
public static class FMLHandshake { public static class FMLHandshake {
private static SimpleChannel channel; private static SimpleChannel channel;
private static List<Supplier<HandshakeMessage>> messages = Arrays.asList(HandshakeMessage.S2CModList::new); private static List<Supplier<FMLHandshakeMessage>> messages = Arrays.asList(FMLHandshakeMessage.S2CModList::new);
private List<HandshakeMessage> sentMessages = new ArrayList<>(); private List<FMLHandshakeMessage> sentMessages = new ArrayList<>();
static { static {
channel = NetworkRegistry.ChannelBuilder.named(NetworkHooks.FMLHANDSHAKE). channel = NetworkRegistry.ChannelBuilder.named(NetworkHooks.FMLHANDSHAKE).
clientAcceptedVersions(a -> true). clientAcceptedVersions(a -> true).
serverAcceptedVersions(a -> true). serverAcceptedVersions(a -> true).
networkProtocolVersion(() -> NetworkHooks.NETVERSION). networkProtocolVersion(() -> NetworkHooks.NETVERSION).
simpleChannel(); simpleChannel();
channel.messageBuilder(HandshakeMessage.S2CModList.class, 1). channel.messageBuilder(FMLHandshakeMessage.S2CModList.class, 1).
decoder(HandshakeMessage.S2CModList::decode). decoder(FMLHandshakeMessage.S2CModList::decode).
encoder(HandshakeMessage.S2CModList::encode). encoder(FMLHandshakeMessage.S2CModList::encode).
loginIndex(HandshakeMessage.S2CModList::setPacketIndexSequence). loginIndex(FMLHandshakeMessage::getPacketIndexSequence, FMLHandshakeMessage::setPacketIndexSequence).
consumer((m,c)->getHandshake(c).handleServerModListOnClient(m, c)). consumer((m,c)->getHandshake(c).handleServerModListOnClient(m, c)).
add(); add();
channel.messageBuilder(HandshakeMessage.C2SModListReply.class, 2). channel.messageBuilder(FMLHandshakeMessage.C2SModListReply.class, 2).
loginIndex(HandshakeMessage::setPacketIndexSequence). loginIndex(FMLHandshakeMessage::getPacketIndexSequence, FMLHandshakeMessage::setPacketIndexSequence).
decoder(HandshakeMessage.C2SModListReply::decode). decoder(FMLHandshakeMessage.C2SModListReply::decode).
encoder(HandshakeMessage.C2SModListReply::encode). encoder(FMLHandshakeMessage.C2SModListReply::encode).
consumer((m,c) -> getHandshake(c).handleClientModListOnServer(m,c)). consumer((m,c) -> getHandshake(c).handleClientModListOnServer(m,c)).
add(); add();
} }
@ -82,30 +76,63 @@ public class FMLNetworking
this.manager = networkManager; this.manager = networkManager;
} }
public void handleServerModListOnClient(HandshakeMessage.S2CModList serverModList, Supplier<NetworkEvent.Context> c) private void handleServerModListOnClient(FMLHandshakeMessage.S2CModList serverModList, Supplier<NetworkEvent.Context> c)
{ {
LOGGER.debug(FMLHSMARKER, "Received S2CModList packet with index {}", serverModList.getPacketIndexSequence()); LOGGER.debug(FMLHSMARKER, "Logging into server with mod list [{}]", serverModList.getModList());
boolean accepted = NetworkRegistry.validateClientChannels(serverModList.getChannels());
c.get().setPacketHandled(true); c.get().setPacketHandled(true);
final HandshakeMessage.C2SModListReply reply = new HandshakeMessage.C2SModListReply(); if (!accepted) {
channel.sendLogin(reply, c.get().getNetworkManager(), c.get().getDirection().reply(), reply.getPacketIndexSequence()); LOGGER.error(FMLHSMARKER, "Terminating connection with server, mismatched mod list");
c.get().getNetworkManager().closeChannel(new TextComponentString("Connection closed - mismatched mod channel list"));
return;
}
final FMLHandshakeMessage.C2SModListReply reply = new FMLHandshakeMessage.C2SModListReply();
reply.setPacketIndexSequence(serverModList.getPacketIndexSequence());
channel.reply(reply, c.get());
LOGGER.debug(FMLHSMARKER, "Sent C2SModListReply packet with index {}", reply.getPacketIndexSequence()); LOGGER.debug(FMLHSMARKER, "Sent C2SModListReply packet with index {}", reply.getPacketIndexSequence());
} }
private void handleClientModListOnServer(HandshakeMessage.C2SModListReply m, Supplier<NetworkEvent.Context> c) private void handleClientModListOnServer(FMLHandshakeMessage.C2SModListReply clientModList, Supplier<NetworkEvent.Context> c)
{ {
LOGGER.debug(FMLHSMARKER, "Received C2SModListReply with index {}", m.getPacketIndexSequence()); LOGGER.debug(FMLHSMARKER, "Received client connection with modlist [{}]", clientModList.getModList());
final HandshakeMessage message = this.sentMessages.stream().filter(ob -> ob.getPacketIndexSequence() == m.getPacketIndexSequence()).findFirst().orElseThrow(() -> new RuntimeException("Unexpected reply from client")); final FMLHandshakeMessage message = this.sentMessages.stream().filter(ob -> ob.getPacketIndexSequence() == clientModList.getPacketIndexSequence()).findFirst().orElseThrow(() -> new RuntimeException("Unexpected reply from client"));
boolean removed = this.sentMessages.remove(message); boolean removed = this.sentMessages.remove(message);
boolean accepted = NetworkRegistry.validateServerChannels(clientModList.getChannels());
c.get().setPacketHandled(true); c.get().setPacketHandled(true);
if (!accepted) {
LOGGER.error(FMLHSMARKER, "Terminating connection with client, mismatched mod list");
c.get().getNetworkManager().closeChannel(new TextComponentString("Connection closed - mismatched mod channel list"));
return;
}
LOGGER.debug(FMLHSMARKER, "Cleared original message {}", removed); LOGGER.debug(FMLHSMARKER, "Cleared original message {}", removed);
} }
/**
* Design of handshake.
*
* After {@link net.minecraft.server.network.NetHandlerLoginServer} enters the {@link net.minecraft.server.network.NetHandlerLoginServer.LoginState#NEGOTIATING}
* state, this will be ticked once per server tick.
*
* FML will send packets, from Server to Client, from the messages queue until the queue is drained. Each message
* will be indexed, and placed into the "pending acknowledgement" queue.
*
* The client should send an acknowledgement for every packet that has a positive index, containing
* that index (and maybe other data as well).
*
* As indexed packets are received at the server, they will be removed from the "pending acknowledgement" queue.
*
* Once the pending queue is drained, this method returns true - indicating that login processing can proceed to
* the next step.
*
* @return true if there is no more need to tick this login connection.
*/
public boolean tickServer() public boolean tickServer()
{ {
if (packetPosition < messages.size()) { if (packetPosition < messages.size()) {
final HandshakeMessage message = messages.get(packetPosition).get(); final FMLHandshakeMessage message = messages.get(packetPosition).get();
message.setPacketIndexSequence(packetPosition);
LOGGER.debug(FMLHSMARKER, "Sending ticking packet {} index {}", message.getClass().getName(), message.getPacketIndexSequence()); LOGGER.debug(FMLHSMARKER, "Sending ticking packet {} index {}", message.getClass().getName(), message.getPacketIndexSequence());
channel.sendLogin(message, this.manager, this.direction, packetPosition); channel.sendTo(message, this.manager, this.direction);
sentMessages.add(message); sentMessages.add(message);
packetPosition++; packetPosition++;
} }
@ -121,57 +148,4 @@ public class FMLNetworking
} }
static class HandshakeMessage
{
private int index;
public void setPacketIndexSequence(int i)
{
this.index = i;
}
public int getPacketIndexSequence()
{
return index;
}
static class S2CModList extends HandshakeMessage
{
private List<String> modList;
S2CModList() {
this.modList = ModList.get().getMods().stream().map(ModInfo::getModId).collect(Collectors.toList());
}
S2CModList(NBTTagCompound nbtTagCompound)
{
this.modList = nbtTagCompound.getTagList("list", 8).stream().map(INBTBase::getString).collect(Collectors.toList());
}
public static S2CModList decode(PacketBuffer packetBuffer)
{
final NBTTagCompound nbtTagCompound = packetBuffer.readCompoundTag();
return new S2CModList(nbtTagCompound);
}
public void encode(PacketBuffer packetBuffer)
{
NBTTagCompound tag = new NBTTagCompound();
tag.setTag("list",modList.stream().map(NBTTagString::new).collect(Collectors.toCollection(NBTTagList::new)));
packetBuffer.writeCompoundTag(tag);
}
}
static class C2SModListReply extends HandshakeMessage
{
public static C2SModListReply decode(PacketBuffer buffer)
{
return new C2SModListReply();
}
public void encode(PacketBuffer buffer)
{
}
}
}
} }

View file

@ -29,6 +29,7 @@ import net.minecraft.network.play.server.SPacketCustomPayload;
import net.minecraft.util.ResourceLocation; import net.minecraft.util.ResourceLocation;
import net.minecraftforge.fml.LogicalSide; import net.minecraftforge.fml.LogicalSide;
import net.minecraftforge.fml.UnsafeHacks; import net.minecraftforge.fml.UnsafeHacks;
import org.apache.commons.lang3.tuple.Pair;
import java.util.function.BiFunction; import java.util.function.BiFunction;
import java.util.function.Function; import java.util.function.Function;
@ -84,12 +85,12 @@ public enum NetworkDirection
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public <T extends Packet<?>> ICustomPacket<T> buildPacket(PacketBuffer packetBuffer, ResourceLocation channelName, int index) public <T extends Packet<?>> ICustomPacket<T> buildPacket(Pair<PacketBuffer,Integer> packetData, ResourceLocation channelName)
{ {
ICustomPacket<T> packet = (ICustomPacket<T>)UnsafeHacks.newInstance(getPacketClass()); ICustomPacket<T> packet = (ICustomPacket<T>)UnsafeHacks.newInstance(getPacketClass());
packet.setName(channelName); packet.setName(channelName);
packet.setData(packetBuffer); packet.setData(packetData.getLeft());
packet.setIndex(index); packet.setIndex(packetData.getRight());
return packet; return packet;
} }
} }

View file

@ -32,10 +32,13 @@ public class NetworkEvent extends Event
{ {
private final PacketBuffer payload; private final PacketBuffer payload;
private final Supplier<Context> source; private final Supplier<Context> source;
private NetworkEvent(ICustomPacket<?> payload, Supplier<Context> source) private final int loginIndex;
private NetworkEvent(final ICustomPacket<?> payload, final Supplier<Context> source)
{ {
this.payload = payload.getData(); this.payload = payload.getData();
this.source = source; this.source = source;
this.loginIndex = payload.getIndex();
} }
public PacketBuffer getPayload() public PacketBuffer getPayload()
@ -48,9 +51,13 @@ public class NetworkEvent extends Event
return source; return source;
} }
public int getLoginIndex()
{
return loginIndex;
}
public static class ServerCustomPayloadEvent extends NetworkEvent public static class ServerCustomPayloadEvent extends NetworkEvent
{ {
ServerCustomPayloadEvent(final ICustomPacket<?> payload, final Supplier<Context> source) { ServerCustomPayloadEvent(final ICustomPacket<?> payload, final Supplier<Context> source) {
super(payload, source); super(payload, source);
} }
@ -61,39 +68,20 @@ public class NetworkEvent extends Event
super(payload, source); super(payload, source);
} }
} }
public static class ServerCustomPayloadLoginEvent extends ServerCustomPayloadEvent implements ILoginIndex { public static class ServerCustomPayloadLoginEvent extends ServerCustomPayloadEvent {
private final int index;
ServerCustomPayloadLoginEvent(ICustomPacket<?> payload, Supplier<Context> source) ServerCustomPayloadLoginEvent(ICustomPacket<?> payload, Supplier<Context> source)
{ {
super(payload, source); super(payload, source);
this.index = payload.getIndex();
}
public int getIndex()
{
return index;
} }
} }
public static class ClientCustomPayloadLoginEvent extends ClientCustomPayloadEvent implements ILoginIndex { public static class ClientCustomPayloadLoginEvent extends ClientCustomPayloadEvent {
private final int index;
ClientCustomPayloadLoginEvent(ICustomPacket<?> payload, Supplier<Context> source) ClientCustomPayloadLoginEvent(ICustomPacket<?> payload, Supplier<Context> source)
{ {
super(payload, source); super(payload, source);
this.index = payload.getIndex();
}
public int getIndex()
{
return index;
} }
} }
public interface ILoginIndex {
int getIndex();
}
/** /**
* Context for {@link NetworkEvent} * Context for {@link NetworkEvent}
*/ */

View file

@ -19,6 +19,7 @@
package net.minecraftforge.fml.network; package net.minecraftforge.fml.network;
import net.minecraft.nbt.INBTBase;
import net.minecraft.network.NetworkManager; import net.minecraft.network.NetworkManager;
import net.minecraft.network.PacketBuffer; import net.minecraft.network.PacketBuffer;
import net.minecraft.util.ResourceLocation; import net.minecraft.util.ResourceLocation;
@ -38,7 +39,7 @@ public class NetworkInstance
} }
private final ResourceLocation channelName; private final ResourceLocation channelName;
private final Supplier<String> networkProtocolVersion; private final String networkProtocolVersion;
private final Predicate<String> clientAcceptedVersions; private final Predicate<String> clientAcceptedVersions;
private final Predicate<String> serverAcceptedVersions; private final Predicate<String> serverAcceptedVersions;
private final IEventBus networkEventBus; private final IEventBus networkEventBus;
@ -46,7 +47,7 @@ public class NetworkInstance
NetworkInstance(ResourceLocation channelName, Supplier<String> networkProtocolVersion, Predicate<String> clientAcceptedVersions, Predicate<String> serverAcceptedVersions) NetworkInstance(ResourceLocation channelName, Supplier<String> networkProtocolVersion, Predicate<String> clientAcceptedVersions, Predicate<String> serverAcceptedVersions)
{ {
this.channelName = channelName; this.channelName = channelName;
this.networkProtocolVersion = networkProtocolVersion; this.networkProtocolVersion = networkProtocolVersion.get();
this.clientAcceptedVersions = clientAcceptedVersions; this.clientAcceptedVersions = clientAcceptedVersions;
this.serverAcceptedVersions = serverAcceptedVersions; this.serverAcceptedVersions = serverAcceptedVersions;
this.networkEventBus = IEventBus.create(this::handleError); this.networkEventBus = IEventBus.create(this::handleError);
@ -78,4 +79,15 @@ public class NetworkInstance
} }
String getNetworkProtocolVersion() {
return networkProtocolVersion;
}
boolean tryServerVersionOnClient(final String serverVersion) {
return this.clientAcceptedVersions.test(serverVersion);
}
boolean tryClientVersionOnServer(final String clientVersion) {
return this.serverAcceptedVersions.test(clientVersion);
}
} }

View file

@ -19,10 +19,12 @@
package net.minecraftforge.fml.network; package net.minecraftforge.fml.network;
import io.netty.util.AttributeKey; import net.minecraft.nbt.NBTTagCompound;
import net.minecraft.nbt.NBTTagList;
import net.minecraft.util.ResourceLocation; import net.minecraft.util.ResourceLocation;
import net.minecraftforge.fml.network.event.EventNetworkChannel; import net.minecraftforge.fml.network.event.EventNetworkChannel;
import net.minecraftforge.fml.network.simple.SimpleChannel; import net.minecraftforge.fml.network.simple.SimpleChannel;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.Marker; import org.apache.logging.log4j.Marker;
@ -35,6 +37,7 @@ import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.function.Supplier; import java.util.function.Supplier;
import java.util.stream.Collectors;
public class NetworkRegistry public class NetworkRegistry
{ {
@ -71,6 +74,53 @@ public class NetworkRegistry
return Optional.ofNullable(instances.get(resourceLocation)); return Optional.ofNullable(instances.get(resourceLocation));
} }
static NBTTagList buildChannelVersions() {
return instances.entrySet().stream().map(e-> {
final NBTTagCompound tag = new NBTTagCompound();
tag.setString("name", e.getKey().toString());
tag.setString("version", e.getValue().getNetworkProtocolVersion());
return tag;
}).collect(Collectors.toCollection(NBTTagList::new));
}
static boolean validateClientChannels(final NBTTagList channels) {
final List<Pair<ResourceLocation, Boolean>> results = channels.stream().map(t -> {
NBTTagCompound tag = (NBTTagCompound) t;
final ResourceLocation rl = new ResourceLocation(tag.getString("name"));
final String serverVersion = tag.getString("version");
boolean test = instances.get(rl).tryServerVersionOnClient(serverVersion);
LOGGER.debug(NETREGISTRY, "Channel {} : Client version test of ''{}'' from server : {}", rl, serverVersion, test);
return Pair.of(rl, test);
}).filter(p->!p.getRight()).collect(Collectors.toList());
if (!results.isEmpty()) {
LOGGER.error(NETREGISTRY, "Channels [{}] rejected their server side version number",
results.stream().map(Pair::getLeft).map(Object::toString).collect(Collectors.joining(",")));
return false;
}
LOGGER.debug(NETREGISTRY, "Accepting channel list from server");
return true;
}
static boolean validateServerChannels(final NBTTagList channels) {
final List<Pair<ResourceLocation, Boolean>> results = channels.stream().map(t -> {
NBTTagCompound tag = (NBTTagCompound) t;
final ResourceLocation rl = new ResourceLocation(tag.getString("name"));
final String clientVersion = tag.getString("version");
boolean test = instances.get(rl).tryClientVersionOnServer(clientVersion);
LOGGER.debug(NETREGISTRY, "Channel {} : Server version test of ''{}'' from client : {}", rl, clientVersion, test);
return Pair.of(rl, test);
}).filter(p->!p.getRight()).collect(Collectors.toList());
if (!results.isEmpty()) {
LOGGER.error(NETREGISTRY, "Channels [{}] rejected their client side version number",
results.stream().map(Pair::getLeft).map(Object::toString).collect(Collectors.joining(",")));
return false;
}
LOGGER.debug(NETREGISTRY, "Accepting channel list from client");
return true;
}
public static class ChannelBuilder { public static class ChannelBuilder {
private ResourceLocation channelName; private ResourceLocation channelName;
private Supplier<String> networkProtocolVersion; private Supplier<String> networkProtocolVersion;

View file

@ -21,9 +21,7 @@ package net.minecraftforge.fml.network.simple;
import it.unimi.dsi.fastutil.objects.Object2ObjectArrayMap; import it.unimi.dsi.fastutil.objects.Object2ObjectArrayMap;
import it.unimi.dsi.fastutil.shorts.Short2ObjectArrayMap; import it.unimi.dsi.fastutil.shorts.Short2ObjectArrayMap;
import net.minecraft.network.Packet;
import net.minecraft.network.PacketBuffer; import net.minecraft.network.PacketBuffer;
import net.minecraftforge.fml.network.ICustomPacket;
import net.minecraftforge.fml.network.NetworkEvent; import net.minecraftforge.fml.network.NetworkEvent;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Logger;
@ -32,7 +30,6 @@ import org.apache.logging.log4j.MarkerManager;
import java.util.Optional; import java.util.Optional;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Function; import java.util.function.Function;
import java.util.function.Supplier; import java.util.function.Supplier;
@ -40,69 +37,86 @@ public class IndexedMessageCodec
{ {
private static final Logger LOGGER = LogManager.getLogger(); private static final Logger LOGGER = LogManager.getLogger();
private static final Marker SIMPLENET = MarkerManager.getMarker("SIMPLENET"); private static final Marker SIMPLENET = MarkerManager.getMarker("SIMPLENET");
private final Short2ObjectArrayMap<CodecIndex<?>> indicies = new Short2ObjectArrayMap<>(); private final Short2ObjectArrayMap<MessageHandler<?>> indicies = new Short2ObjectArrayMap<>();
private final Object2ObjectArrayMap<Class<?>, CodecIndex<?>> types = new Object2ObjectArrayMap<>(); private final Object2ObjectArrayMap<Class<?>, MessageHandler<?>> types = new Object2ObjectArrayMap<>();
@SuppressWarnings("unchecked")
public <MSG> MessageHandler<MSG> findMessageType(final MSG msgToReply) {
return (MessageHandler<MSG>) types.get(msgToReply.getClass());
}
@SuppressWarnings("OptionalUsedAsFieldOrParameterType") @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class CodecIndex<MSG> class MessageHandler<MSG>
{ {
private final Optional<BiConsumer<MSG, PacketBuffer>> encoder; private final Optional<BiConsumer<MSG, PacketBuffer>> encoder;
private final Optional<Function<PacketBuffer, MSG>> decoder; private final Optional<Function<PacketBuffer, MSG>> decoder;
private final int index; private final int index;
private final BiConsumer<MSG,Supplier<NetworkEvent.Context>> messageConsumer; private final BiConsumer<MSG,Supplier<NetworkEvent.Context>> messageConsumer;
private final Class<MSG> messageType; private final Class<MSG> messageType;
private Optional<BiConsumer<MSG, Integer>> loginIndexFunction; private Optional<BiConsumer<MSG, Integer>> loginIndexSetter;
private Optional<Function<MSG, Integer>> loginIndexGetter;
public CodecIndex(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer) public MessageHandler(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer)
{ {
this.index = index; this.index = index;
this.messageType = messageType; this.messageType = messageType;
this.encoder = Optional.ofNullable(encoder); this.encoder = Optional.ofNullable(encoder);
this.decoder = Optional.ofNullable(decoder); this.decoder = Optional.ofNullable(decoder);
this.messageConsumer = messageConsumer; this.messageConsumer = messageConsumer;
this.loginIndexFunction = Optional.empty(); this.loginIndexGetter = Optional.empty();
this.loginIndexSetter = Optional.empty();
indicies.put((short)(index & 0xff), this); indicies.put((short)(index & 0xff), this);
types.put(messageType, this); types.put(messageType, this);
} }
public void setLoginIndexFunction(BiConsumer<MSG, Integer> loginIndexFunction) void setLoginIndexSetter(BiConsumer<MSG, Integer> loginIndexSetter)
{ {
this.loginIndexFunction = Optional.of(loginIndexFunction); this.loginIndexSetter = Optional.of(loginIndexSetter);
} }
public Optional<BiConsumer<MSG, Integer>> getLoginIndexFunction() { Optional<BiConsumer<MSG, Integer>> getLoginIndexSetter() {
return this.loginIndexFunction; return this.loginIndexSetter;
}
void setLoginIndexGetter(Function<MSG, Integer> loginIndexGetter) {
this.loginIndexGetter = Optional.of(loginIndexGetter);
}
public Optional<Function<MSG, Integer>> getLoginIndexGetter() {
return this.loginIndexGetter;
} }
} }
private static <M> void tryDecode(PacketBuffer payload, Supplier<NetworkEvent.Context> context, CodecIndex<M> codec)
{
codec.decoder.map(d->d.apply(payload)).ifPresent(m->codec.messageConsumer.accept(m, context));
}
private static <M> void tryDecode(PacketBuffer payload, Supplier<NetworkEvent.Context> context, int payloadIndex, CodecIndex<M> codec) private static <M> void tryDecode(PacketBuffer payload, Supplier<NetworkEvent.Context> context, int payloadIndex, MessageHandler<M> codec)
{ {
codec.decoder.map(d->d.apply(payload)). codec.decoder.map(d->d.apply(payload)).
map(p->{ codec.getLoginIndexFunction().ifPresent(f-> f.accept(p, payloadIndex)); return p; }). map(p->{
ifPresent(m->codec.messageConsumer.accept(m, context)); // Only run the loginIndex function for payloadIndexed packets (login)
if (payloadIndex != Integer.MIN_VALUE)
{
codec.getLoginIndexSetter().ifPresent(f-> f.accept(p, payloadIndex));
}
return p;
}).ifPresent(m->codec.messageConsumer.accept(m, context));
} }
private static <M> void tryEncode(PacketBuffer target, M message, CodecIndex<M> codec) { private static <M> int tryEncode(PacketBuffer target, M message, MessageHandler<M> codec) {
codec.encoder.ifPresent(encoder->{ codec.encoder.ifPresent(encoder->{
target.writeByte(codec.index & 0xff); target.writeByte(codec.index & 0xff);
encoder.accept(message, target); encoder.accept(message, target);
}); });
return codec.loginIndexGetter.orElse(m -> Integer.MIN_VALUE).apply(message);
} }
public <MSG> void build(MSG message, PacketBuffer target) public <MSG> int build(MSG message, PacketBuffer target)
{ {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
CodecIndex<MSG> codecIndex = (CodecIndex<MSG>)types.get(message.getClass()); MessageHandler<MSG> messageHandler = (MessageHandler<MSG>)types.get(message.getClass());
if (codecIndex == null) { if (messageHandler == null) {
LOGGER.error(SIMPLENET, "Received invalid message {}", message.getClass().getName()); LOGGER.error(SIMPLENET, "Received invalid message {}", message.getClass().getName());
throw new IllegalArgumentException("Invalid message "+message.getClass().getName()); throw new IllegalArgumentException("Invalid message "+message.getClass().getName());
} }
tryEncode(target, message, codecIndex); return tryEncode(target, message, messageHandler);
} }
void consume(PacketBuffer payload, int payloadIndex, Supplier<NetworkEvent.Context> context) { void consume(PacketBuffer payload, int payloadIndex, Supplier<NetworkEvent.Context> context) {
@ -111,30 +125,15 @@ public class IndexedMessageCodec
return; return;
} }
short discriminator = payload.readUnsignedByte(); short discriminator = payload.readUnsignedByte();
final CodecIndex<?> codecIndex = indicies.get(discriminator); final MessageHandler<?> messageHandler = indicies.get(discriminator);
if (codecIndex == null) { if (messageHandler == null) {
LOGGER.error(SIMPLENET, "Received invalid discriminator byte {}", discriminator); LOGGER.error(SIMPLENET, "Received invalid discriminator byte {}", discriminator);
return; return;
} }
tryDecode(payload, context, payloadIndex, codecIndex); tryDecode(payload, context, payloadIndex, messageHandler);
} }
void consume(PacketBuffer payload, Supplier<NetworkEvent.Context> context) { <MSG> MessageHandler<MSG> addCodecIndex(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer) {
// no data in empty payload return new MessageHandler<>(index, messageType, encoder, decoder, messageConsumer);
if (payload == null) {
LOGGER.error(SIMPLENET, "Received empty payload");
return;
}
short discriminator = payload.readUnsignedByte();
final CodecIndex<?> codecIndex = indicies.get(discriminator);
if (codecIndex == null) {
LOGGER.error(SIMPLENET, "Received invalid discriminator byte {}", discriminator);
return;
}
tryDecode(payload, context, codecIndex);
}
<MSG> CodecIndex<MSG> addCodecIndex(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer) {
return new CodecIndex<>(index, messageType, encoder, decoder, messageConsumer);
} }
} }

View file

@ -24,18 +24,14 @@ import net.minecraft.client.Minecraft;
import net.minecraft.network.NetworkManager; import net.minecraft.network.NetworkManager;
import net.minecraft.network.Packet; import net.minecraft.network.Packet;
import net.minecraft.network.PacketBuffer; import net.minecraft.network.PacketBuffer;
import net.minecraft.network.play.client.CPacketCustomPayload;
import net.minecraftforge.fml.network.ICustomPacket; import net.minecraftforge.fml.network.ICustomPacket;
import net.minecraftforge.fml.network.NetworkDirection; import net.minecraftforge.fml.network.NetworkDirection;
import net.minecraftforge.fml.network.NetworkEvent; import net.minecraftforge.fml.network.NetworkEvent;
import net.minecraftforge.fml.network.NetworkInstance; import net.minecraftforge.fml.network.NetworkInstance;
import org.apache.commons.lang3.tuple.Pair;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function; import java.util.function.Function;
import java.util.function.IntBinaryOperator;
import java.util.function.IntConsumer;
import java.util.function.IntFunction;
import java.util.function.Supplier; import java.util.function.Supplier;
public class SimpleChannel public class SimpleChannel
@ -52,42 +48,38 @@ public class SimpleChannel
private void networkEventListener(final NetworkEvent networkEvent) private void networkEventListener(final NetworkEvent networkEvent)
{ {
if (networkEvent instanceof NetworkEvent.ILoginIndex) this.indexedCodec.consume(networkEvent.getPayload(), networkEvent.getLoginIndex(), networkEvent.getSource());
{
this.indexedCodec.consume(networkEvent.getPayload(), ((NetworkEvent.ILoginIndex)networkEvent).getIndex(), networkEvent.getSource());
}
else
{
this.indexedCodec.consume(networkEvent.getPayload(), networkEvent.getSource());
}
} }
public <MSG> void encodeMessage(MSG message, final PacketBuffer target) { public <MSG> int encodeMessage(MSG message, final PacketBuffer target) {
this.indexedCodec.build(message, target); return this.indexedCodec.build(message, target);
} }
public <MSG> IndexedMessageCodec.CodecIndex<MSG> registerMessage(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer) { public <MSG> IndexedMessageCodec.MessageHandler<MSG> registerMessage(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer) {
return this.indexedCodec.addCodecIndex(index, messageType, encoder, decoder, messageConsumer); return this.indexedCodec.addCodecIndex(index, messageType, encoder, decoder, messageConsumer);
} }
private <MSG> PacketBuffer toBuffer(MSG msg) { private <MSG> Pair<PacketBuffer,Integer> toBuffer(MSG msg) {
final PacketBuffer bufIn = new PacketBuffer(Unpooled.buffer()); final PacketBuffer bufIn = new PacketBuffer(Unpooled.buffer());
encodeMessage(msg, bufIn); int index = encodeMessage(msg, bufIn);
return bufIn; return Pair.of(bufIn, index);
} }
public <MSG> void sendToServer(MSG message) public <MSG> void sendToServer(MSG message)
{ {
sendTo(message, Minecraft.getMinecraft().getConnection().getNetworkManager(), NetworkDirection.PLAY_TO_SERVER); sendTo(message, Minecraft.getMinecraft().getConnection().getNetworkManager(), NetworkDirection.PLAY_TO_SERVER);
} }
public <MSG> void sendTo(MSG message, NetworkManager manager, NetworkDirection direction) { public <MSG> void sendTo(MSG message, NetworkManager manager, NetworkDirection direction)
ICustomPacket<Packet<?>> payload = direction.buildPacket(toBuffer(message), instance.getChannelName(), -1); {
ICustomPacket<Packet<?>> payload = direction.buildPacket(toBuffer(message), instance.getChannelName());
manager.sendPacket(payload.getThis()); manager.sendPacket(payload.getThis());
} }
public <MSG> void sendLogin(MSG message, NetworkManager manager, NetworkDirection direction, int packetIndex) { public <MSG> void reply(MSG msgToReply, NetworkEvent.Context context)
ICustomPacket<Packet<?>> payload = direction.buildPacket(toBuffer(message), instance.getChannelName(), packetIndex); {
manager.sendPacket(payload.getThis()); sendTo(msgToReply, context.getNetworkManager(), context.getDirection().reply());
} }
public <M> MessageBuilder<M> messageBuilder(final Class<M> type, int id) { public <M> MessageBuilder<M> messageBuilder(final Class<M> type, int id) {
return MessageBuilder.forType(this, type, id); return MessageBuilder.forType(this, type, id);
} }
@ -99,7 +91,8 @@ public class SimpleChannel
private BiConsumer<MSG, PacketBuffer> encoder; private BiConsumer<MSG, PacketBuffer> encoder;
private Function<PacketBuffer, MSG> decoder; private Function<PacketBuffer, MSG> decoder;
private BiConsumer<MSG, Supplier<NetworkEvent.Context>> consumer; private BiConsumer<MSG, Supplier<NetworkEvent.Context>> consumer;
private BiConsumer<MSG, Integer> loginIndexFunction; private Function<MSG, Integer> loginIndexGetter;
private BiConsumer<MSG, Integer> loginIndexSetter;
private static <MSG> MessageBuilder<MSG> forType(final SimpleChannel channel, final Class<MSG> type, int id) { private static <MSG> MessageBuilder<MSG> forType(final SimpleChannel channel, final Class<MSG> type, int id) {
MessageBuilder<MSG> builder = new MessageBuilder<>(); MessageBuilder<MSG> builder = new MessageBuilder<>();
@ -119,8 +112,9 @@ public class SimpleChannel
return this; return this;
} }
public MessageBuilder<MSG> loginIndex(BiConsumer<MSG, Integer> loginIndexFunction) { public MessageBuilder<MSG> loginIndex(Function<MSG, Integer> loginIndexGetter, BiConsumer<MSG, Integer> loginIndexSetter) {
this.loginIndexFunction = loginIndexFunction; this.loginIndexGetter = loginIndexGetter;
this.loginIndexSetter = loginIndexSetter;
return this; return this;
} }
public MessageBuilder<MSG> consumer(BiConsumer<MSG, Supplier<NetworkEvent.Context>> consumer) { public MessageBuilder<MSG> consumer(BiConsumer<MSG, Supplier<NetworkEvent.Context>> consumer) {
@ -129,9 +123,12 @@ public class SimpleChannel
} }
public void add() { public void add() {
final IndexedMessageCodec.CodecIndex<MSG> message = this.channel.registerMessage(this.id, this.type, this.encoder, this.decoder, this.consumer); final IndexedMessageCodec.MessageHandler<MSG> message = this.channel.registerMessage(this.id, this.type, this.encoder, this.decoder, this.consumer);
if (this.loginIndexFunction != null) { if (this.loginIndexSetter != null) {
message.setLoginIndexFunction(this.loginIndexFunction); message.setLoginIndexSetter(this.loginIndexSetter);
}
if (this.loginIndexGetter != null) {
message.setLoginIndexGetter(this.loginIndexGetter);
} }
} }
} }