Net Handshake phase 1. Validates pairings on client and server.

This commit is contained in:
cpw 2018-09-15 15:44:48 -04:00
parent afead63381
commit 9f2c7c881c
10 changed files with 297 additions and 186 deletions

View File

@ -87,6 +87,7 @@ project(':forge') {
target: 'fmldevserver'
]
}
mcVersion = '1.13'
}
applyPatches {
canonicalizeAccess true

View File

@ -60,7 +60,7 @@
float f7 = f2 * f4;
- double d3 = 5.0D;
- Vec3d vec3d1 = vec3d.add((double)f6 * 5.0D, (double)f5 * 5.0D, (double)f7 * 5.0D);
+ double d3 = playerIn.getEntityAttribute(EntityPlayer.REACH_DISTANCE).getAttributeValue();
+ double d3 = 6;
+ Vec3d vec3d1 = vec3d.add((double)f6 * d3, (double)f5 * d3, (double)f7 * d3);
return worldIn.func_200259_a(vec3d, vec3d1, useLiquids ? RayTraceFluidMode.SOURCE_ONLY : RayTraceFluidMode.NEVER, false, false);
}

View File

@ -0,0 +1,89 @@
package net.minecraftforge.fml.network;
import net.minecraft.nbt.INBTBase;
import net.minecraft.nbt.NBTTagCompound;
import net.minecraft.nbt.NBTTagList;
import net.minecraft.nbt.NBTTagString;
import net.minecraft.network.PacketBuffer;
import net.minecraftforge.fml.ModList;
import net.minecraftforge.fml.loading.moddiscovery.ModInfo;
import java.util.List;
import java.util.stream.Collectors;
class FMLHandshakeMessage
{
// Login index sequence number
private int index;
void setPacketIndexSequence(int i)
{
this.index = i;
}
int getPacketIndexSequence()
{
return index;
}
/**
* Server to client "list of mods". Always first handshake message.
*/
static class S2CModList extends FMLHandshakeMessage
{
private NBTTagList channels;
private List<String> modList;
S2CModList() {
this.modList = ModList.get().getMods().stream().map(ModInfo::getModId).collect(Collectors.toList());
}
S2CModList(NBTTagCompound nbtTagCompound)
{
this.modList = nbtTagCompound.getTagList("modlist", 8).stream().map(INBTBase::getString).collect(Collectors.toList());
this.channels = nbtTagCompound.getTagList("channels", 10);
}
static S2CModList decode(PacketBuffer packetBuffer)
{
final NBTTagCompound nbtTagCompound = packetBuffer.readCompoundTag();
return new S2CModList(nbtTagCompound);
}
void encode(PacketBuffer packetBuffer)
{
NBTTagCompound tag = new NBTTagCompound();
tag.setTag("modlist",modList.stream().map(NBTTagString::new).collect(Collectors.toCollection(NBTTagList::new)));
tag.setTag("channels", NetworkRegistry.buildChannelVersions());
packetBuffer.writeCompoundTag(tag);
}
String getModList() {
return String.join(",", modList);
}
NBTTagList getChannels() {
return this.channels;
}
}
static class C2SModListReply extends S2CModList
{
C2SModListReply() {
super();
}
C2SModListReply(final NBTTagCompound buffer) {
super(buffer);
}
static C2SModListReply decode(PacketBuffer buffer)
{
return new C2SModListReply(buffer.readCompoundTag());
}
public void encode(PacketBuffer buffer)
{
super.encode(buffer);
}
}
}

View File

@ -1,14 +1,9 @@
package net.minecraftforge.fml.network;
import io.netty.util.AttributeKey;
import net.minecraft.nbt.INBTBase;
import net.minecraft.nbt.NBTTagCompound;
import net.minecraft.nbt.NBTTagList;
import net.minecraft.nbt.NBTTagString;
import net.minecraft.network.NetworkManager;
import net.minecraft.network.PacketBuffer;
import net.minecraftforge.fml.ModList;
import net.minecraftforge.fml.loading.moddiscovery.ModInfo;
import net.minecraft.util.text.ITextComponent;
import net.minecraft.util.text.TextComponentString;
import net.minecraftforge.fml.network.simple.SimpleChannel;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
@ -19,7 +14,6 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Supplier;
import java.util.stream.Collectors;
public class FMLNetworking
{
@ -49,24 +43,24 @@ public class FMLNetworking
public static class FMLHandshake {
private static SimpleChannel channel;
private static List<Supplier<HandshakeMessage>> messages = Arrays.asList(HandshakeMessage.S2CModList::new);
private List<HandshakeMessage> sentMessages = new ArrayList<>();
private static List<Supplier<FMLHandshakeMessage>> messages = Arrays.asList(FMLHandshakeMessage.S2CModList::new);
private List<FMLHandshakeMessage> sentMessages = new ArrayList<>();
static {
channel = NetworkRegistry.ChannelBuilder.named(NetworkHooks.FMLHANDSHAKE).
clientAcceptedVersions(a -> true).
serverAcceptedVersions(a -> true).
networkProtocolVersion(() -> NetworkHooks.NETVERSION).
simpleChannel();
channel.messageBuilder(HandshakeMessage.S2CModList.class, 1).
decoder(HandshakeMessage.S2CModList::decode).
encoder(HandshakeMessage.S2CModList::encode).
loginIndex(HandshakeMessage.S2CModList::setPacketIndexSequence).
channel.messageBuilder(FMLHandshakeMessage.S2CModList.class, 1).
decoder(FMLHandshakeMessage.S2CModList::decode).
encoder(FMLHandshakeMessage.S2CModList::encode).
loginIndex(FMLHandshakeMessage::getPacketIndexSequence, FMLHandshakeMessage::setPacketIndexSequence).
consumer((m,c)->getHandshake(c).handleServerModListOnClient(m, c)).
add();
channel.messageBuilder(HandshakeMessage.C2SModListReply.class, 2).
loginIndex(HandshakeMessage::setPacketIndexSequence).
decoder(HandshakeMessage.C2SModListReply::decode).
encoder(HandshakeMessage.C2SModListReply::encode).
channel.messageBuilder(FMLHandshakeMessage.C2SModListReply.class, 2).
loginIndex(FMLHandshakeMessage::getPacketIndexSequence, FMLHandshakeMessage::setPacketIndexSequence).
decoder(FMLHandshakeMessage.C2SModListReply::decode).
encoder(FMLHandshakeMessage.C2SModListReply::encode).
consumer((m,c) -> getHandshake(c).handleClientModListOnServer(m,c)).
add();
}
@ -82,30 +76,63 @@ public class FMLNetworking
this.manager = networkManager;
}
public void handleServerModListOnClient(HandshakeMessage.S2CModList serverModList, Supplier<NetworkEvent.Context> c)
private void handleServerModListOnClient(FMLHandshakeMessage.S2CModList serverModList, Supplier<NetworkEvent.Context> c)
{
LOGGER.debug(FMLHSMARKER, "Received S2CModList packet with index {}", serverModList.getPacketIndexSequence());
LOGGER.debug(FMLHSMARKER, "Logging into server with mod list [{}]", serverModList.getModList());
boolean accepted = NetworkRegistry.validateClientChannels(serverModList.getChannels());
c.get().setPacketHandled(true);
final HandshakeMessage.C2SModListReply reply = new HandshakeMessage.C2SModListReply();
channel.sendLogin(reply, c.get().getNetworkManager(), c.get().getDirection().reply(), reply.getPacketIndexSequence());
if (!accepted) {
LOGGER.error(FMLHSMARKER, "Terminating connection with server, mismatched mod list");
c.get().getNetworkManager().closeChannel(new TextComponentString("Connection closed - mismatched mod channel list"));
return;
}
final FMLHandshakeMessage.C2SModListReply reply = new FMLHandshakeMessage.C2SModListReply();
reply.setPacketIndexSequence(serverModList.getPacketIndexSequence());
channel.reply(reply, c.get());
LOGGER.debug(FMLHSMARKER, "Sent C2SModListReply packet with index {}", reply.getPacketIndexSequence());
}
private void handleClientModListOnServer(HandshakeMessage.C2SModListReply m, Supplier<NetworkEvent.Context> c)
private void handleClientModListOnServer(FMLHandshakeMessage.C2SModListReply clientModList, Supplier<NetworkEvent.Context> c)
{
LOGGER.debug(FMLHSMARKER, "Received C2SModListReply with index {}", m.getPacketIndexSequence());
final HandshakeMessage message = this.sentMessages.stream().filter(ob -> ob.getPacketIndexSequence() == m.getPacketIndexSequence()).findFirst().orElseThrow(() -> new RuntimeException("Unexpected reply from client"));
LOGGER.debug(FMLHSMARKER, "Received client connection with modlist [{}]", clientModList.getModList());
final FMLHandshakeMessage message = this.sentMessages.stream().filter(ob -> ob.getPacketIndexSequence() == clientModList.getPacketIndexSequence()).findFirst().orElseThrow(() -> new RuntimeException("Unexpected reply from client"));
boolean removed = this.sentMessages.remove(message);
boolean accepted = NetworkRegistry.validateServerChannels(clientModList.getChannels());
c.get().setPacketHandled(true);
if (!accepted) {
LOGGER.error(FMLHSMARKER, "Terminating connection with client, mismatched mod list");
c.get().getNetworkManager().closeChannel(new TextComponentString("Connection closed - mismatched mod channel list"));
return;
}
LOGGER.debug(FMLHSMARKER, "Cleared original message {}", removed);
}
/**
* Design of handshake.
*
* After {@link net.minecraft.server.network.NetHandlerLoginServer} enters the {@link net.minecraft.server.network.NetHandlerLoginServer.LoginState#NEGOTIATING}
* state, this will be ticked once per server tick.
*
* FML will send packets, from Server to Client, from the messages queue until the queue is drained. Each message
* will be indexed, and placed into the "pending acknowledgement" queue.
*
* The client should send an acknowledgement for every packet that has a positive index, containing
* that index (and maybe other data as well).
*
* As indexed packets are received at the server, they will be removed from the "pending acknowledgement" queue.
*
* Once the pending queue is drained, this method returns true - indicating that login processing can proceed to
* the next step.
*
* @return true if there is no more need to tick this login connection.
*/
public boolean tickServer()
{
if (packetPosition < messages.size()) {
final HandshakeMessage message = messages.get(packetPosition).get();
final FMLHandshakeMessage message = messages.get(packetPosition).get();
message.setPacketIndexSequence(packetPosition);
LOGGER.debug(FMLHSMARKER, "Sending ticking packet {} index {}", message.getClass().getName(), message.getPacketIndexSequence());
channel.sendLogin(message, this.manager, this.direction, packetPosition);
channel.sendTo(message, this.manager, this.direction);
sentMessages.add(message);
packetPosition++;
}
@ -121,57 +148,4 @@ public class FMLNetworking
}
static class HandshakeMessage
{
private int index;
public void setPacketIndexSequence(int i)
{
this.index = i;
}
public int getPacketIndexSequence()
{
return index;
}
static class S2CModList extends HandshakeMessage
{
private List<String> modList;
S2CModList() {
this.modList = ModList.get().getMods().stream().map(ModInfo::getModId).collect(Collectors.toList());
}
S2CModList(NBTTagCompound nbtTagCompound)
{
this.modList = nbtTagCompound.getTagList("list", 8).stream().map(INBTBase::getString).collect(Collectors.toList());
}
public static S2CModList decode(PacketBuffer packetBuffer)
{
final NBTTagCompound nbtTagCompound = packetBuffer.readCompoundTag();
return new S2CModList(nbtTagCompound);
}
public void encode(PacketBuffer packetBuffer)
{
NBTTagCompound tag = new NBTTagCompound();
tag.setTag("list",modList.stream().map(NBTTagString::new).collect(Collectors.toCollection(NBTTagList::new)));
packetBuffer.writeCompoundTag(tag);
}
}
static class C2SModListReply extends HandshakeMessage
{
public static C2SModListReply decode(PacketBuffer buffer)
{
return new C2SModListReply();
}
public void encode(PacketBuffer buffer)
{
}
}
}
}

View File

@ -29,6 +29,7 @@ import net.minecraft.network.play.server.SPacketCustomPayload;
import net.minecraft.util.ResourceLocation;
import net.minecraftforge.fml.LogicalSide;
import net.minecraftforge.fml.UnsafeHacks;
import org.apache.commons.lang3.tuple.Pair;
import java.util.function.BiFunction;
import java.util.function.Function;
@ -84,12 +85,12 @@ public enum NetworkDirection
}
@SuppressWarnings("unchecked")
public <T extends Packet<?>> ICustomPacket<T> buildPacket(PacketBuffer packetBuffer, ResourceLocation channelName, int index)
public <T extends Packet<?>> ICustomPacket<T> buildPacket(Pair<PacketBuffer,Integer> packetData, ResourceLocation channelName)
{
ICustomPacket<T> packet = (ICustomPacket<T>)UnsafeHacks.newInstance(getPacketClass());
packet.setName(channelName);
packet.setData(packetBuffer);
packet.setIndex(index);
packet.setData(packetData.getLeft());
packet.setIndex(packetData.getRight());
return packet;
}
}

View File

@ -32,10 +32,13 @@ public class NetworkEvent extends Event
{
private final PacketBuffer payload;
private final Supplier<Context> source;
private NetworkEvent(ICustomPacket<?> payload, Supplier<Context> source)
private final int loginIndex;
private NetworkEvent(final ICustomPacket<?> payload, final Supplier<Context> source)
{
this.payload = payload.getData();
this.source = source;
this.loginIndex = payload.getIndex();
}
public PacketBuffer getPayload()
@ -48,9 +51,13 @@ public class NetworkEvent extends Event
return source;
}
public int getLoginIndex()
{
return loginIndex;
}
public static class ServerCustomPayloadEvent extends NetworkEvent
{
ServerCustomPayloadEvent(final ICustomPacket<?> payload, final Supplier<Context> source) {
super(payload, source);
}
@ -61,39 +68,20 @@ public class NetworkEvent extends Event
super(payload, source);
}
}
public static class ServerCustomPayloadLoginEvent extends ServerCustomPayloadEvent implements ILoginIndex {
private final int index;
public static class ServerCustomPayloadLoginEvent extends ServerCustomPayloadEvent {
ServerCustomPayloadLoginEvent(ICustomPacket<?> payload, Supplier<Context> source)
{
super(payload, source);
this.index = payload.getIndex();
}
public int getIndex()
{
return index;
}
}
public static class ClientCustomPayloadLoginEvent extends ClientCustomPayloadEvent implements ILoginIndex {
private final int index;
public static class ClientCustomPayloadLoginEvent extends ClientCustomPayloadEvent {
ClientCustomPayloadLoginEvent(ICustomPacket<?> payload, Supplier<Context> source)
{
super(payload, source);
this.index = payload.getIndex();
}
public int getIndex()
{
return index;
}
}
public interface ILoginIndex {
int getIndex();
}
/**
* Context for {@link NetworkEvent}
*/

View File

@ -19,6 +19,7 @@
package net.minecraftforge.fml.network;
import net.minecraft.nbt.INBTBase;
import net.minecraft.network.NetworkManager;
import net.minecraft.network.PacketBuffer;
import net.minecraft.util.ResourceLocation;
@ -38,7 +39,7 @@ public class NetworkInstance
}
private final ResourceLocation channelName;
private final Supplier<String> networkProtocolVersion;
private final String networkProtocolVersion;
private final Predicate<String> clientAcceptedVersions;
private final Predicate<String> serverAcceptedVersions;
private final IEventBus networkEventBus;
@ -46,7 +47,7 @@ public class NetworkInstance
NetworkInstance(ResourceLocation channelName, Supplier<String> networkProtocolVersion, Predicate<String> clientAcceptedVersions, Predicate<String> serverAcceptedVersions)
{
this.channelName = channelName;
this.networkProtocolVersion = networkProtocolVersion;
this.networkProtocolVersion = networkProtocolVersion.get();
this.clientAcceptedVersions = clientAcceptedVersions;
this.serverAcceptedVersions = serverAcceptedVersions;
this.networkEventBus = IEventBus.create(this::handleError);
@ -78,4 +79,15 @@ public class NetworkInstance
}
String getNetworkProtocolVersion() {
return networkProtocolVersion;
}
boolean tryServerVersionOnClient(final String serverVersion) {
return this.clientAcceptedVersions.test(serverVersion);
}
boolean tryClientVersionOnServer(final String clientVersion) {
return this.serverAcceptedVersions.test(clientVersion);
}
}

View File

@ -19,10 +19,12 @@
package net.minecraftforge.fml.network;
import io.netty.util.AttributeKey;
import net.minecraft.nbt.NBTTagCompound;
import net.minecraft.nbt.NBTTagList;
import net.minecraft.util.ResourceLocation;
import net.minecraftforge.fml.network.event.EventNetworkChannel;
import net.minecraftforge.fml.network.simple.SimpleChannel;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.Marker;
@ -35,6 +37,7 @@ import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;
public class NetworkRegistry
{
@ -71,6 +74,53 @@ public class NetworkRegistry
return Optional.ofNullable(instances.get(resourceLocation));
}
static NBTTagList buildChannelVersions() {
return instances.entrySet().stream().map(e-> {
final NBTTagCompound tag = new NBTTagCompound();
tag.setString("name", e.getKey().toString());
tag.setString("version", e.getValue().getNetworkProtocolVersion());
return tag;
}).collect(Collectors.toCollection(NBTTagList::new));
}
static boolean validateClientChannels(final NBTTagList channels) {
final List<Pair<ResourceLocation, Boolean>> results = channels.stream().map(t -> {
NBTTagCompound tag = (NBTTagCompound) t;
final ResourceLocation rl = new ResourceLocation(tag.getString("name"));
final String serverVersion = tag.getString("version");
boolean test = instances.get(rl).tryServerVersionOnClient(serverVersion);
LOGGER.debug(NETREGISTRY, "Channel {} : Client version test of ''{}'' from server : {}", rl, serverVersion, test);
return Pair.of(rl, test);
}).filter(p->!p.getRight()).collect(Collectors.toList());
if (!results.isEmpty()) {
LOGGER.error(NETREGISTRY, "Channels [{}] rejected their server side version number",
results.stream().map(Pair::getLeft).map(Object::toString).collect(Collectors.joining(",")));
return false;
}
LOGGER.debug(NETREGISTRY, "Accepting channel list from server");
return true;
}
static boolean validateServerChannels(final NBTTagList channels) {
final List<Pair<ResourceLocation, Boolean>> results = channels.stream().map(t -> {
NBTTagCompound tag = (NBTTagCompound) t;
final ResourceLocation rl = new ResourceLocation(tag.getString("name"));
final String clientVersion = tag.getString("version");
boolean test = instances.get(rl).tryClientVersionOnServer(clientVersion);
LOGGER.debug(NETREGISTRY, "Channel {} : Server version test of ''{}'' from client : {}", rl, clientVersion, test);
return Pair.of(rl, test);
}).filter(p->!p.getRight()).collect(Collectors.toList());
if (!results.isEmpty()) {
LOGGER.error(NETREGISTRY, "Channels [{}] rejected their client side version number",
results.stream().map(Pair::getLeft).map(Object::toString).collect(Collectors.joining(",")));
return false;
}
LOGGER.debug(NETREGISTRY, "Accepting channel list from client");
return true;
}
public static class ChannelBuilder {
private ResourceLocation channelName;
private Supplier<String> networkProtocolVersion;

View File

@ -21,9 +21,7 @@ package net.minecraftforge.fml.network.simple;
import it.unimi.dsi.fastutil.objects.Object2ObjectArrayMap;
import it.unimi.dsi.fastutil.shorts.Short2ObjectArrayMap;
import net.minecraft.network.Packet;
import net.minecraft.network.PacketBuffer;
import net.minecraftforge.fml.network.ICustomPacket;
import net.minecraftforge.fml.network.NetworkEvent;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
@ -32,7 +30,6 @@ import org.apache.logging.log4j.MarkerManager;
import java.util.Optional;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;
@ -40,69 +37,86 @@ public class IndexedMessageCodec
{
private static final Logger LOGGER = LogManager.getLogger();
private static final Marker SIMPLENET = MarkerManager.getMarker("SIMPLENET");
private final Short2ObjectArrayMap<CodecIndex<?>> indicies = new Short2ObjectArrayMap<>();
private final Object2ObjectArrayMap<Class<?>, CodecIndex<?>> types = new Object2ObjectArrayMap<>();
private final Short2ObjectArrayMap<MessageHandler<?>> indicies = new Short2ObjectArrayMap<>();
private final Object2ObjectArrayMap<Class<?>, MessageHandler<?>> types = new Object2ObjectArrayMap<>();
@SuppressWarnings("unchecked")
public <MSG> MessageHandler<MSG> findMessageType(final MSG msgToReply) {
return (MessageHandler<MSG>) types.get(msgToReply.getClass());
}
@SuppressWarnings("OptionalUsedAsFieldOrParameterType")
public class CodecIndex<MSG>
class MessageHandler<MSG>
{
private final Optional<BiConsumer<MSG, PacketBuffer>> encoder;
private final Optional<Function<PacketBuffer, MSG>> decoder;
private final int index;
private final BiConsumer<MSG,Supplier<NetworkEvent.Context>> messageConsumer;
private final Class<MSG> messageType;
private Optional<BiConsumer<MSG, Integer>> loginIndexFunction;
private Optional<BiConsumer<MSG, Integer>> loginIndexSetter;
private Optional<Function<MSG, Integer>> loginIndexGetter;
public CodecIndex(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer)
public MessageHandler(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer)
{
this.index = index;
this.messageType = messageType;
this.encoder = Optional.ofNullable(encoder);
this.decoder = Optional.ofNullable(decoder);
this.messageConsumer = messageConsumer;
this.loginIndexFunction = Optional.empty();
this.loginIndexGetter = Optional.empty();
this.loginIndexSetter = Optional.empty();
indicies.put((short)(index & 0xff), this);
types.put(messageType, this);
}
public void setLoginIndexFunction(BiConsumer<MSG, Integer> loginIndexFunction)
void setLoginIndexSetter(BiConsumer<MSG, Integer> loginIndexSetter)
{
this.loginIndexFunction = Optional.of(loginIndexFunction);
this.loginIndexSetter = Optional.of(loginIndexSetter);
}
public Optional<BiConsumer<MSG, Integer>> getLoginIndexFunction() {
return this.loginIndexFunction;
Optional<BiConsumer<MSG, Integer>> getLoginIndexSetter() {
return this.loginIndexSetter;
}
void setLoginIndexGetter(Function<MSG, Integer> loginIndexGetter) {
this.loginIndexGetter = Optional.of(loginIndexGetter);
}
public Optional<Function<MSG, Integer>> getLoginIndexGetter() {
return this.loginIndexGetter;
}
}
private static <M> void tryDecode(PacketBuffer payload, Supplier<NetworkEvent.Context> context, CodecIndex<M> codec)
{
codec.decoder.map(d->d.apply(payload)).ifPresent(m->codec.messageConsumer.accept(m, context));
}
private static <M> void tryDecode(PacketBuffer payload, Supplier<NetworkEvent.Context> context, int payloadIndex, CodecIndex<M> codec)
private static <M> void tryDecode(PacketBuffer payload, Supplier<NetworkEvent.Context> context, int payloadIndex, MessageHandler<M> codec)
{
codec.decoder.map(d->d.apply(payload)).
map(p->{ codec.getLoginIndexFunction().ifPresent(f-> f.accept(p, payloadIndex)); return p; }).
ifPresent(m->codec.messageConsumer.accept(m, context));
map(p->{
// Only run the loginIndex function for payloadIndexed packets (login)
if (payloadIndex != Integer.MIN_VALUE)
{
codec.getLoginIndexSetter().ifPresent(f-> f.accept(p, payloadIndex));
}
return p;
}).ifPresent(m->codec.messageConsumer.accept(m, context));
}
private static <M> void tryEncode(PacketBuffer target, M message, CodecIndex<M> codec) {
private static <M> int tryEncode(PacketBuffer target, M message, MessageHandler<M> codec) {
codec.encoder.ifPresent(encoder->{
target.writeByte(codec.index & 0xff);
encoder.accept(message, target);
});
return codec.loginIndexGetter.orElse(m -> Integer.MIN_VALUE).apply(message);
}
public <MSG> void build(MSG message, PacketBuffer target)
public <MSG> int build(MSG message, PacketBuffer target)
{
@SuppressWarnings("unchecked")
CodecIndex<MSG> codecIndex = (CodecIndex<MSG>)types.get(message.getClass());
if (codecIndex == null) {
MessageHandler<MSG> messageHandler = (MessageHandler<MSG>)types.get(message.getClass());
if (messageHandler == null) {
LOGGER.error(SIMPLENET, "Received invalid message {}", message.getClass().getName());
throw new IllegalArgumentException("Invalid message "+message.getClass().getName());
}
tryEncode(target, message, codecIndex);
return tryEncode(target, message, messageHandler);
}
void consume(PacketBuffer payload, int payloadIndex, Supplier<NetworkEvent.Context> context) {
@ -111,30 +125,15 @@ public class IndexedMessageCodec
return;
}
short discriminator = payload.readUnsignedByte();
final CodecIndex<?> codecIndex = indicies.get(discriminator);
if (codecIndex == null) {
final MessageHandler<?> messageHandler = indicies.get(discriminator);
if (messageHandler == null) {
LOGGER.error(SIMPLENET, "Received invalid discriminator byte {}", discriminator);
return;
}
tryDecode(payload, context, payloadIndex, codecIndex);
tryDecode(payload, context, payloadIndex, messageHandler);
}
void consume(PacketBuffer payload, Supplier<NetworkEvent.Context> context) {
// no data in empty payload
if (payload == null) {
LOGGER.error(SIMPLENET, "Received empty payload");
return;
}
short discriminator = payload.readUnsignedByte();
final CodecIndex<?> codecIndex = indicies.get(discriminator);
if (codecIndex == null) {
LOGGER.error(SIMPLENET, "Received invalid discriminator byte {}", discriminator);
return;
}
tryDecode(payload, context, codecIndex);
}
<MSG> CodecIndex<MSG> addCodecIndex(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer) {
return new CodecIndex<>(index, messageType, encoder, decoder, messageConsumer);
<MSG> MessageHandler<MSG> addCodecIndex(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer) {
return new MessageHandler<>(index, messageType, encoder, decoder, messageConsumer);
}
}

View File

@ -24,18 +24,14 @@ import net.minecraft.client.Minecraft;
import net.minecraft.network.NetworkManager;
import net.minecraft.network.Packet;
import net.minecraft.network.PacketBuffer;
import net.minecraft.network.play.client.CPacketCustomPayload;
import net.minecraftforge.fml.network.ICustomPacket;
import net.minecraftforge.fml.network.NetworkDirection;
import net.minecraftforge.fml.network.NetworkEvent;
import net.minecraftforge.fml.network.NetworkInstance;
import org.apache.commons.lang3.tuple.Pair;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.IntBinaryOperator;
import java.util.function.IntConsumer;
import java.util.function.IntFunction;
import java.util.function.Supplier;
public class SimpleChannel
@ -52,42 +48,38 @@ public class SimpleChannel
private void networkEventListener(final NetworkEvent networkEvent)
{
if (networkEvent instanceof NetworkEvent.ILoginIndex)
{
this.indexedCodec.consume(networkEvent.getPayload(), ((NetworkEvent.ILoginIndex)networkEvent).getIndex(), networkEvent.getSource());
}
else
{
this.indexedCodec.consume(networkEvent.getPayload(), networkEvent.getSource());
}
this.indexedCodec.consume(networkEvent.getPayload(), networkEvent.getLoginIndex(), networkEvent.getSource());
}
public <MSG> void encodeMessage(MSG message, final PacketBuffer target) {
this.indexedCodec.build(message, target);
public <MSG> int encodeMessage(MSG message, final PacketBuffer target) {
return this.indexedCodec.build(message, target);
}
public <MSG> IndexedMessageCodec.CodecIndex<MSG> registerMessage(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer) {
public <MSG> IndexedMessageCodec.MessageHandler<MSG> registerMessage(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer) {
return this.indexedCodec.addCodecIndex(index, messageType, encoder, decoder, messageConsumer);
}
private <MSG> PacketBuffer toBuffer(MSG msg) {
private <MSG> Pair<PacketBuffer,Integer> toBuffer(MSG msg) {
final PacketBuffer bufIn = new PacketBuffer(Unpooled.buffer());
encodeMessage(msg, bufIn);
return bufIn;
int index = encodeMessage(msg, bufIn);
return Pair.of(bufIn, index);
}
public <MSG> void sendToServer(MSG message)
{
sendTo(message, Minecraft.getMinecraft().getConnection().getNetworkManager(), NetworkDirection.PLAY_TO_SERVER);
}
public <MSG> void sendTo(MSG message, NetworkManager manager, NetworkDirection direction) {
ICustomPacket<Packet<?>> payload = direction.buildPacket(toBuffer(message), instance.getChannelName(), -1);
public <MSG> void sendTo(MSG message, NetworkManager manager, NetworkDirection direction)
{
ICustomPacket<Packet<?>> payload = direction.buildPacket(toBuffer(message), instance.getChannelName());
manager.sendPacket(payload.getThis());
}
public <MSG> void sendLogin(MSG message, NetworkManager manager, NetworkDirection direction, int packetIndex) {
ICustomPacket<Packet<?>> payload = direction.buildPacket(toBuffer(message), instance.getChannelName(), packetIndex);
manager.sendPacket(payload.getThis());
public <MSG> void reply(MSG msgToReply, NetworkEvent.Context context)
{
sendTo(msgToReply, context.getNetworkManager(), context.getDirection().reply());
}
public <M> MessageBuilder<M> messageBuilder(final Class<M> type, int id) {
return MessageBuilder.forType(this, type, id);
}
@ -99,7 +91,8 @@ public class SimpleChannel
private BiConsumer<MSG, PacketBuffer> encoder;
private Function<PacketBuffer, MSG> decoder;
private BiConsumer<MSG, Supplier<NetworkEvent.Context>> consumer;
private BiConsumer<MSG, Integer> loginIndexFunction;
private Function<MSG, Integer> loginIndexGetter;
private BiConsumer<MSG, Integer> loginIndexSetter;
private static <MSG> MessageBuilder<MSG> forType(final SimpleChannel channel, final Class<MSG> type, int id) {
MessageBuilder<MSG> builder = new MessageBuilder<>();
@ -119,8 +112,9 @@ public class SimpleChannel
return this;
}
public MessageBuilder<MSG> loginIndex(BiConsumer<MSG, Integer> loginIndexFunction) {
this.loginIndexFunction = loginIndexFunction;
public MessageBuilder<MSG> loginIndex(Function<MSG, Integer> loginIndexGetter, BiConsumer<MSG, Integer> loginIndexSetter) {
this.loginIndexGetter = loginIndexGetter;
this.loginIndexSetter = loginIndexSetter;
return this;
}
public MessageBuilder<MSG> consumer(BiConsumer<MSG, Supplier<NetworkEvent.Context>> consumer) {
@ -129,9 +123,12 @@ public class SimpleChannel
}
public void add() {
final IndexedMessageCodec.CodecIndex<MSG> message = this.channel.registerMessage(this.id, this.type, this.encoder, this.decoder, this.consumer);
if (this.loginIndexFunction != null) {
message.setLoginIndexFunction(this.loginIndexFunction);
final IndexedMessageCodec.MessageHandler<MSG> message = this.channel.registerMessage(this.id, this.type, this.encoder, this.decoder, this.consumer);
if (this.loginIndexSetter != null) {
message.setLoginIndexSetter(this.loginIndexSetter);
}
if (this.loginIndexGetter != null) {
message.setLoginIndexGetter(this.loginIndexGetter);
}
}
}