diff --git a/build.gradle b/build.gradle index e250d69d4..00f06c131 100644 --- a/build.gradle +++ b/build.gradle @@ -87,6 +87,7 @@ project(':forge') { target: 'fmldevserver' ] } + mcVersion = '1.13' } applyPatches { canonicalizeAccess true diff --git a/patches/minecraft/net/minecraft/item/Item.java.patch b/patches/minecraft/net/minecraft/item/Item.java.patch index 736b9ea1e..775a1258c 100644 --- a/patches/minecraft/net/minecraft/item/Item.java.patch +++ b/patches/minecraft/net/minecraft/item/Item.java.patch @@ -60,7 +60,7 @@ float f7 = f2 * f4; - double d3 = 5.0D; - Vec3d vec3d1 = vec3d.add((double)f6 * 5.0D, (double)f5 * 5.0D, (double)f7 * 5.0D); -+ double d3 = playerIn.getEntityAttribute(EntityPlayer.REACH_DISTANCE).getAttributeValue(); ++ double d3 = 6; + Vec3d vec3d1 = vec3d.add((double)f6 * d3, (double)f5 * d3, (double)f7 * d3); return worldIn.func_200259_a(vec3d, vec3d1, useLiquids ? RayTraceFluidMode.SOURCE_ONLY : RayTraceFluidMode.NEVER, false, false); } diff --git a/src/main/java/net/minecraftforge/fml/network/FMLHandshakeMessage.java b/src/main/java/net/minecraftforge/fml/network/FMLHandshakeMessage.java new file mode 100644 index 000000000..41f14c08a --- /dev/null +++ b/src/main/java/net/minecraftforge/fml/network/FMLHandshakeMessage.java @@ -0,0 +1,89 @@ +package net.minecraftforge.fml.network; + +import net.minecraft.nbt.INBTBase; +import net.minecraft.nbt.NBTTagCompound; +import net.minecraft.nbt.NBTTagList; +import net.minecraft.nbt.NBTTagString; +import net.minecraft.network.PacketBuffer; +import net.minecraftforge.fml.ModList; +import net.minecraftforge.fml.loading.moddiscovery.ModInfo; + +import java.util.List; +import java.util.stream.Collectors; + +class FMLHandshakeMessage +{ + // Login index sequence number + private int index; + void setPacketIndexSequence(int i) + { + this.index = i; + } + + int getPacketIndexSequence() + { + return index; + } + + /** + * Server to client "list of mods". Always first handshake message. + */ + static class S2CModList extends FMLHandshakeMessage + { + private NBTTagList channels; + private List modList; + + S2CModList() { + this.modList = ModList.get().getMods().stream().map(ModInfo::getModId).collect(Collectors.toList()); + } + + S2CModList(NBTTagCompound nbtTagCompound) + { + this.modList = nbtTagCompound.getTagList("modlist", 8).stream().map(INBTBase::getString).collect(Collectors.toList()); + this.channels = nbtTagCompound.getTagList("channels", 10); + } + + static S2CModList decode(PacketBuffer packetBuffer) + { + final NBTTagCompound nbtTagCompound = packetBuffer.readCompoundTag(); + return new S2CModList(nbtTagCompound); + } + + void encode(PacketBuffer packetBuffer) + { + NBTTagCompound tag = new NBTTagCompound(); + tag.setTag("modlist",modList.stream().map(NBTTagString::new).collect(Collectors.toCollection(NBTTagList::new))); + tag.setTag("channels", NetworkRegistry.buildChannelVersions()); + packetBuffer.writeCompoundTag(tag); + } + + String getModList() { + return String.join(",", modList); + } + + NBTTagList getChannels() { + return this.channels; + } + } + + static class C2SModListReply extends S2CModList + { + C2SModListReply() { + super(); + } + + C2SModListReply(final NBTTagCompound buffer) { + super(buffer); + } + + static C2SModListReply decode(PacketBuffer buffer) + { + return new C2SModListReply(buffer.readCompoundTag()); + } + + public void encode(PacketBuffer buffer) + { + super.encode(buffer); + } + } +} diff --git a/src/main/java/net/minecraftforge/fml/network/FMLNetworking.java b/src/main/java/net/minecraftforge/fml/network/FMLNetworking.java index 4c5a50918..90d86d956 100644 --- a/src/main/java/net/minecraftforge/fml/network/FMLNetworking.java +++ b/src/main/java/net/minecraftforge/fml/network/FMLNetworking.java @@ -1,14 +1,9 @@ package net.minecraftforge.fml.network; import io.netty.util.AttributeKey; -import net.minecraft.nbt.INBTBase; -import net.minecraft.nbt.NBTTagCompound; -import net.minecraft.nbt.NBTTagList; -import net.minecraft.nbt.NBTTagString; import net.minecraft.network.NetworkManager; -import net.minecraft.network.PacketBuffer; -import net.minecraftforge.fml.ModList; -import net.minecraftforge.fml.loading.moddiscovery.ModInfo; +import net.minecraft.util.text.ITextComponent; +import net.minecraft.util.text.TextComponentString; import net.minecraftforge.fml.network.simple.SimpleChannel; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -19,7 +14,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.function.Supplier; -import java.util.stream.Collectors; public class FMLNetworking { @@ -49,24 +43,24 @@ public class FMLNetworking public static class FMLHandshake { private static SimpleChannel channel; - private static List> messages = Arrays.asList(HandshakeMessage.S2CModList::new); - private List sentMessages = new ArrayList<>(); + private static List> messages = Arrays.asList(FMLHandshakeMessage.S2CModList::new); + private List sentMessages = new ArrayList<>(); static { channel = NetworkRegistry.ChannelBuilder.named(NetworkHooks.FMLHANDSHAKE). clientAcceptedVersions(a -> true). serverAcceptedVersions(a -> true). networkProtocolVersion(() -> NetworkHooks.NETVERSION). simpleChannel(); - channel.messageBuilder(HandshakeMessage.S2CModList.class, 1). - decoder(HandshakeMessage.S2CModList::decode). - encoder(HandshakeMessage.S2CModList::encode). - loginIndex(HandshakeMessage.S2CModList::setPacketIndexSequence). + channel.messageBuilder(FMLHandshakeMessage.S2CModList.class, 1). + decoder(FMLHandshakeMessage.S2CModList::decode). + encoder(FMLHandshakeMessage.S2CModList::encode). + loginIndex(FMLHandshakeMessage::getPacketIndexSequence, FMLHandshakeMessage::setPacketIndexSequence). consumer((m,c)->getHandshake(c).handleServerModListOnClient(m, c)). add(); - channel.messageBuilder(HandshakeMessage.C2SModListReply.class, 2). - loginIndex(HandshakeMessage::setPacketIndexSequence). - decoder(HandshakeMessage.C2SModListReply::decode). - encoder(HandshakeMessage.C2SModListReply::encode). + channel.messageBuilder(FMLHandshakeMessage.C2SModListReply.class, 2). + loginIndex(FMLHandshakeMessage::getPacketIndexSequence, FMLHandshakeMessage::setPacketIndexSequence). + decoder(FMLHandshakeMessage.C2SModListReply::decode). + encoder(FMLHandshakeMessage.C2SModListReply::encode). consumer((m,c) -> getHandshake(c).handleClientModListOnServer(m,c)). add(); } @@ -82,30 +76,63 @@ public class FMLNetworking this.manager = networkManager; } - public void handleServerModListOnClient(HandshakeMessage.S2CModList serverModList, Supplier c) + private void handleServerModListOnClient(FMLHandshakeMessage.S2CModList serverModList, Supplier c) { - LOGGER.debug(FMLHSMARKER, "Received S2CModList packet with index {}", serverModList.getPacketIndexSequence()); + LOGGER.debug(FMLHSMARKER, "Logging into server with mod list [{}]", serverModList.getModList()); + boolean accepted = NetworkRegistry.validateClientChannels(serverModList.getChannels()); c.get().setPacketHandled(true); - final HandshakeMessage.C2SModListReply reply = new HandshakeMessage.C2SModListReply(); - channel.sendLogin(reply, c.get().getNetworkManager(), c.get().getDirection().reply(), reply.getPacketIndexSequence()); + if (!accepted) { + LOGGER.error(FMLHSMARKER, "Terminating connection with server, mismatched mod list"); + c.get().getNetworkManager().closeChannel(new TextComponentString("Connection closed - mismatched mod channel list")); + return; + } + final FMLHandshakeMessage.C2SModListReply reply = new FMLHandshakeMessage.C2SModListReply(); + reply.setPacketIndexSequence(serverModList.getPacketIndexSequence()); + channel.reply(reply, c.get()); LOGGER.debug(FMLHSMARKER, "Sent C2SModListReply packet with index {}", reply.getPacketIndexSequence()); } - private void handleClientModListOnServer(HandshakeMessage.C2SModListReply m, Supplier c) + private void handleClientModListOnServer(FMLHandshakeMessage.C2SModListReply clientModList, Supplier c) { - LOGGER.debug(FMLHSMARKER, "Received C2SModListReply with index {}", m.getPacketIndexSequence()); - final HandshakeMessage message = this.sentMessages.stream().filter(ob -> ob.getPacketIndexSequence() == m.getPacketIndexSequence()).findFirst().orElseThrow(() -> new RuntimeException("Unexpected reply from client")); + LOGGER.debug(FMLHSMARKER, "Received client connection with modlist [{}]", clientModList.getModList()); + final FMLHandshakeMessage message = this.sentMessages.stream().filter(ob -> ob.getPacketIndexSequence() == clientModList.getPacketIndexSequence()).findFirst().orElseThrow(() -> new RuntimeException("Unexpected reply from client")); boolean removed = this.sentMessages.remove(message); + boolean accepted = NetworkRegistry.validateServerChannels(clientModList.getChannels()); c.get().setPacketHandled(true); + if (!accepted) { + LOGGER.error(FMLHSMARKER, "Terminating connection with client, mismatched mod list"); + c.get().getNetworkManager().closeChannel(new TextComponentString("Connection closed - mismatched mod channel list")); + return; + } LOGGER.debug(FMLHSMARKER, "Cleared original message {}", removed); } + /** + * Design of handshake. + * + * After {@link net.minecraft.server.network.NetHandlerLoginServer} enters the {@link net.minecraft.server.network.NetHandlerLoginServer.LoginState#NEGOTIATING} + * state, this will be ticked once per server tick. + * + * FML will send packets, from Server to Client, from the messages queue until the queue is drained. Each message + * will be indexed, and placed into the "pending acknowledgement" queue. + * + * The client should send an acknowledgement for every packet that has a positive index, containing + * that index (and maybe other data as well). + * + * As indexed packets are received at the server, they will be removed from the "pending acknowledgement" queue. + * + * Once the pending queue is drained, this method returns true - indicating that login processing can proceed to + * the next step. + * + * @return true if there is no more need to tick this login connection. + */ public boolean tickServer() { if (packetPosition < messages.size()) { - final HandshakeMessage message = messages.get(packetPosition).get(); + final FMLHandshakeMessage message = messages.get(packetPosition).get(); + message.setPacketIndexSequence(packetPosition); LOGGER.debug(FMLHSMARKER, "Sending ticking packet {} index {}", message.getClass().getName(), message.getPacketIndexSequence()); - channel.sendLogin(message, this.manager, this.direction, packetPosition); + channel.sendTo(message, this.manager, this.direction); sentMessages.add(message); packetPosition++; } @@ -121,57 +148,4 @@ public class FMLNetworking } - static class HandshakeMessage - { - private int index; - public void setPacketIndexSequence(int i) - { - this.index = i; - } - - public int getPacketIndexSequence() - { - return index; - } - - static class S2CModList extends HandshakeMessage - { - private List modList; - - S2CModList() { - this.modList = ModList.get().getMods().stream().map(ModInfo::getModId).collect(Collectors.toList()); - } - - S2CModList(NBTTagCompound nbtTagCompound) - { - this.modList = nbtTagCompound.getTagList("list", 8).stream().map(INBTBase::getString).collect(Collectors.toList()); - } - - public static S2CModList decode(PacketBuffer packetBuffer) - { - final NBTTagCompound nbtTagCompound = packetBuffer.readCompoundTag(); - return new S2CModList(nbtTagCompound); - } - - public void encode(PacketBuffer packetBuffer) - { - NBTTagCompound tag = new NBTTagCompound(); - tag.setTag("list",modList.stream().map(NBTTagString::new).collect(Collectors.toCollection(NBTTagList::new))); - packetBuffer.writeCompoundTag(tag); - } - } - - static class C2SModListReply extends HandshakeMessage - { - public static C2SModListReply decode(PacketBuffer buffer) - { - return new C2SModListReply(); - } - - public void encode(PacketBuffer buffer) - { - - } - } - } } diff --git a/src/main/java/net/minecraftforge/fml/network/NetworkDirection.java b/src/main/java/net/minecraftforge/fml/network/NetworkDirection.java index bf55c7590..20a7a65e7 100644 --- a/src/main/java/net/minecraftforge/fml/network/NetworkDirection.java +++ b/src/main/java/net/minecraftforge/fml/network/NetworkDirection.java @@ -29,6 +29,7 @@ import net.minecraft.network.play.server.SPacketCustomPayload; import net.minecraft.util.ResourceLocation; import net.minecraftforge.fml.LogicalSide; import net.minecraftforge.fml.UnsafeHacks; +import org.apache.commons.lang3.tuple.Pair; import java.util.function.BiFunction; import java.util.function.Function; @@ -84,12 +85,12 @@ public enum NetworkDirection } @SuppressWarnings("unchecked") - public > ICustomPacket buildPacket(PacketBuffer packetBuffer, ResourceLocation channelName, int index) + public > ICustomPacket buildPacket(Pair packetData, ResourceLocation channelName) { ICustomPacket packet = (ICustomPacket)UnsafeHacks.newInstance(getPacketClass()); packet.setName(channelName); - packet.setData(packetBuffer); - packet.setIndex(index); + packet.setData(packetData.getLeft()); + packet.setIndex(packetData.getRight()); return packet; } } diff --git a/src/main/java/net/minecraftforge/fml/network/NetworkEvent.java b/src/main/java/net/minecraftforge/fml/network/NetworkEvent.java index 5c129be8f..ae48e4258 100644 --- a/src/main/java/net/minecraftforge/fml/network/NetworkEvent.java +++ b/src/main/java/net/minecraftforge/fml/network/NetworkEvent.java @@ -32,10 +32,13 @@ public class NetworkEvent extends Event { private final PacketBuffer payload; private final Supplier source; - private NetworkEvent(ICustomPacket payload, Supplier source) + private final int loginIndex; + + private NetworkEvent(final ICustomPacket payload, final Supplier source) { this.payload = payload.getData(); this.source = source; + this.loginIndex = payload.getIndex(); } public PacketBuffer getPayload() @@ -48,9 +51,13 @@ public class NetworkEvent extends Event return source; } + public int getLoginIndex() + { + return loginIndex; + } + public static class ServerCustomPayloadEvent extends NetworkEvent { - ServerCustomPayloadEvent(final ICustomPacket payload, final Supplier source) { super(payload, source); } @@ -61,39 +68,20 @@ public class NetworkEvent extends Event super(payload, source); } } - public static class ServerCustomPayloadLoginEvent extends ServerCustomPayloadEvent implements ILoginIndex { - private final int index; - + public static class ServerCustomPayloadLoginEvent extends ServerCustomPayloadEvent { ServerCustomPayloadLoginEvent(ICustomPacket payload, Supplier source) { super(payload, source); - this.index = payload.getIndex(); - } - - public int getIndex() - { - return index; } } - public static class ClientCustomPayloadLoginEvent extends ClientCustomPayloadEvent implements ILoginIndex { - private final int index; - + public static class ClientCustomPayloadLoginEvent extends ClientCustomPayloadEvent { ClientCustomPayloadLoginEvent(ICustomPacket payload, Supplier source) { super(payload, source); - this.index = payload.getIndex(); - } - - public int getIndex() - { - return index; } } - public interface ILoginIndex { - int getIndex(); - } /** * Context for {@link NetworkEvent} */ diff --git a/src/main/java/net/minecraftforge/fml/network/NetworkInstance.java b/src/main/java/net/minecraftforge/fml/network/NetworkInstance.java index f71cc23aa..b8efef78a 100644 --- a/src/main/java/net/minecraftforge/fml/network/NetworkInstance.java +++ b/src/main/java/net/minecraftforge/fml/network/NetworkInstance.java @@ -19,6 +19,7 @@ package net.minecraftforge.fml.network; +import net.minecraft.nbt.INBTBase; import net.minecraft.network.NetworkManager; import net.minecraft.network.PacketBuffer; import net.minecraft.util.ResourceLocation; @@ -38,7 +39,7 @@ public class NetworkInstance } private final ResourceLocation channelName; - private final Supplier networkProtocolVersion; + private final String networkProtocolVersion; private final Predicate clientAcceptedVersions; private final Predicate serverAcceptedVersions; private final IEventBus networkEventBus; @@ -46,7 +47,7 @@ public class NetworkInstance NetworkInstance(ResourceLocation channelName, Supplier networkProtocolVersion, Predicate clientAcceptedVersions, Predicate serverAcceptedVersions) { this.channelName = channelName; - this.networkProtocolVersion = networkProtocolVersion; + this.networkProtocolVersion = networkProtocolVersion.get(); this.clientAcceptedVersions = clientAcceptedVersions; this.serverAcceptedVersions = serverAcceptedVersions; this.networkEventBus = IEventBus.create(this::handleError); @@ -78,4 +79,15 @@ public class NetworkInstance } + String getNetworkProtocolVersion() { + return networkProtocolVersion; + } + + boolean tryServerVersionOnClient(final String serverVersion) { + return this.clientAcceptedVersions.test(serverVersion); + } + + boolean tryClientVersionOnServer(final String clientVersion) { + return this.serverAcceptedVersions.test(clientVersion); + } } diff --git a/src/main/java/net/minecraftforge/fml/network/NetworkRegistry.java b/src/main/java/net/minecraftforge/fml/network/NetworkRegistry.java index 4629fba41..743911979 100644 --- a/src/main/java/net/minecraftforge/fml/network/NetworkRegistry.java +++ b/src/main/java/net/minecraftforge/fml/network/NetworkRegistry.java @@ -19,10 +19,12 @@ package net.minecraftforge.fml.network; -import io.netty.util.AttributeKey; +import net.minecraft.nbt.NBTTagCompound; +import net.minecraft.nbt.NBTTagList; import net.minecraft.util.ResourceLocation; import net.minecraftforge.fml.network.event.EventNetworkChannel; import net.minecraftforge.fml.network.simple.SimpleChannel; +import org.apache.commons.lang3.tuple.Pair; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.Marker; @@ -35,6 +37,7 @@ import java.util.Map; import java.util.Optional; import java.util.function.Predicate; import java.util.function.Supplier; +import java.util.stream.Collectors; public class NetworkRegistry { @@ -71,6 +74,53 @@ public class NetworkRegistry return Optional.ofNullable(instances.get(resourceLocation)); } + static NBTTagList buildChannelVersions() { + return instances.entrySet().stream().map(e-> { + final NBTTagCompound tag = new NBTTagCompound(); + tag.setString("name", e.getKey().toString()); + tag.setString("version", e.getValue().getNetworkProtocolVersion()); + return tag; + }).collect(Collectors.toCollection(NBTTagList::new)); + } + + static boolean validateClientChannels(final NBTTagList channels) { + final List> results = channels.stream().map(t -> { + NBTTagCompound tag = (NBTTagCompound) t; + final ResourceLocation rl = new ResourceLocation(tag.getString("name")); + final String serverVersion = tag.getString("version"); + boolean test = instances.get(rl).tryServerVersionOnClient(serverVersion); + LOGGER.debug(NETREGISTRY, "Channel {} : Client version test of ''{}'' from server : {}", rl, serverVersion, test); + return Pair.of(rl, test); + }).filter(p->!p.getRight()).collect(Collectors.toList()); + + if (!results.isEmpty()) { + LOGGER.error(NETREGISTRY, "Channels [{}] rejected their server side version number", + results.stream().map(Pair::getLeft).map(Object::toString).collect(Collectors.joining(","))); + return false; + } + LOGGER.debug(NETREGISTRY, "Accepting channel list from server"); + return true; + } + + static boolean validateServerChannels(final NBTTagList channels) { + final List> results = channels.stream().map(t -> { + NBTTagCompound tag = (NBTTagCompound) t; + final ResourceLocation rl = new ResourceLocation(tag.getString("name")); + final String clientVersion = tag.getString("version"); + boolean test = instances.get(rl).tryClientVersionOnServer(clientVersion); + LOGGER.debug(NETREGISTRY, "Channel {} : Server version test of ''{}'' from client : {}", rl, clientVersion, test); + return Pair.of(rl, test); + }).filter(p->!p.getRight()).collect(Collectors.toList()); + + if (!results.isEmpty()) { + LOGGER.error(NETREGISTRY, "Channels [{}] rejected their client side version number", + results.stream().map(Pair::getLeft).map(Object::toString).collect(Collectors.joining(","))); + return false; + } + LOGGER.debug(NETREGISTRY, "Accepting channel list from client"); + return true; + } + public static class ChannelBuilder { private ResourceLocation channelName; private Supplier networkProtocolVersion; diff --git a/src/main/java/net/minecraftforge/fml/network/simple/IndexedMessageCodec.java b/src/main/java/net/minecraftforge/fml/network/simple/IndexedMessageCodec.java index 1a7854387..9bcfa9d84 100644 --- a/src/main/java/net/minecraftforge/fml/network/simple/IndexedMessageCodec.java +++ b/src/main/java/net/minecraftforge/fml/network/simple/IndexedMessageCodec.java @@ -21,9 +21,7 @@ package net.minecraftforge.fml.network.simple; import it.unimi.dsi.fastutil.objects.Object2ObjectArrayMap; import it.unimi.dsi.fastutil.shorts.Short2ObjectArrayMap; -import net.minecraft.network.Packet; import net.minecraft.network.PacketBuffer; -import net.minecraftforge.fml.network.ICustomPacket; import net.minecraftforge.fml.network.NetworkEvent; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -32,7 +30,6 @@ import org.apache.logging.log4j.MarkerManager; import java.util.Optional; import java.util.function.BiConsumer; -import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; @@ -40,69 +37,86 @@ public class IndexedMessageCodec { private static final Logger LOGGER = LogManager.getLogger(); private static final Marker SIMPLENET = MarkerManager.getMarker("SIMPLENET"); - private final Short2ObjectArrayMap> indicies = new Short2ObjectArrayMap<>(); - private final Object2ObjectArrayMap, CodecIndex> types = new Object2ObjectArrayMap<>(); + private final Short2ObjectArrayMap> indicies = new Short2ObjectArrayMap<>(); + private final Object2ObjectArrayMap, MessageHandler> types = new Object2ObjectArrayMap<>(); + + @SuppressWarnings("unchecked") + public MessageHandler findMessageType(final MSG msgToReply) { + return (MessageHandler) types.get(msgToReply.getClass()); + } @SuppressWarnings("OptionalUsedAsFieldOrParameterType") - public class CodecIndex + class MessageHandler { - private final Optional> encoder; private final Optional> decoder; private final int index; private final BiConsumer> messageConsumer; private final Class messageType; - private Optional> loginIndexFunction; + private Optional> loginIndexSetter; + private Optional> loginIndexGetter; - public CodecIndex(int index, Class messageType, BiConsumer encoder, Function decoder, BiConsumer> messageConsumer) + public MessageHandler(int index, Class messageType, BiConsumer encoder, Function decoder, BiConsumer> messageConsumer) { this.index = index; this.messageType = messageType; this.encoder = Optional.ofNullable(encoder); this.decoder = Optional.ofNullable(decoder); this.messageConsumer = messageConsumer; - this.loginIndexFunction = Optional.empty(); + this.loginIndexGetter = Optional.empty(); + this.loginIndexSetter = Optional.empty(); indicies.put((short)(index & 0xff), this); types.put(messageType, this); } - public void setLoginIndexFunction(BiConsumer loginIndexFunction) + void setLoginIndexSetter(BiConsumer loginIndexSetter) { - this.loginIndexFunction = Optional.of(loginIndexFunction); + this.loginIndexSetter = Optional.of(loginIndexSetter); } - public Optional> getLoginIndexFunction() { - return this.loginIndexFunction; + Optional> getLoginIndexSetter() { + return this.loginIndexSetter; + } + + void setLoginIndexGetter(Function loginIndexGetter) { + this.loginIndexGetter = Optional.of(loginIndexGetter); + } + + public Optional> getLoginIndexGetter() { + return this.loginIndexGetter; } } - private static void tryDecode(PacketBuffer payload, Supplier context, CodecIndex codec) - { - codec.decoder.map(d->d.apply(payload)).ifPresent(m->codec.messageConsumer.accept(m, context)); - } - private static void tryDecode(PacketBuffer payload, Supplier context, int payloadIndex, CodecIndex codec) + private static void tryDecode(PacketBuffer payload, Supplier context, int payloadIndex, MessageHandler codec) { codec.decoder.map(d->d.apply(payload)). - map(p->{ codec.getLoginIndexFunction().ifPresent(f-> f.accept(p, payloadIndex)); return p; }). - ifPresent(m->codec.messageConsumer.accept(m, context)); + map(p->{ + // Only run the loginIndex function for payloadIndexed packets (login) + if (payloadIndex != Integer.MIN_VALUE) + { + codec.getLoginIndexSetter().ifPresent(f-> f.accept(p, payloadIndex)); + } + return p; + }).ifPresent(m->codec.messageConsumer.accept(m, context)); } - private static void tryEncode(PacketBuffer target, M message, CodecIndex codec) { + private static int tryEncode(PacketBuffer target, M message, MessageHandler codec) { codec.encoder.ifPresent(encoder->{ target.writeByte(codec.index & 0xff); encoder.accept(message, target); }); + return codec.loginIndexGetter.orElse(m -> Integer.MIN_VALUE).apply(message); } - public void build(MSG message, PacketBuffer target) + public int build(MSG message, PacketBuffer target) { @SuppressWarnings("unchecked") - CodecIndex codecIndex = (CodecIndex)types.get(message.getClass()); - if (codecIndex == null) { + MessageHandler messageHandler = (MessageHandler)types.get(message.getClass()); + if (messageHandler == null) { LOGGER.error(SIMPLENET, "Received invalid message {}", message.getClass().getName()); throw new IllegalArgumentException("Invalid message "+message.getClass().getName()); } - tryEncode(target, message, codecIndex); + return tryEncode(target, message, messageHandler); } void consume(PacketBuffer payload, int payloadIndex, Supplier context) { @@ -111,30 +125,15 @@ public class IndexedMessageCodec return; } short discriminator = payload.readUnsignedByte(); - final CodecIndex codecIndex = indicies.get(discriminator); - if (codecIndex == null) { + final MessageHandler messageHandler = indicies.get(discriminator); + if (messageHandler == null) { LOGGER.error(SIMPLENET, "Received invalid discriminator byte {}", discriminator); return; } - tryDecode(payload, context, payloadIndex, codecIndex); + tryDecode(payload, context, payloadIndex, messageHandler); } - void consume(PacketBuffer payload, Supplier context) { - // no data in empty payload - if (payload == null) { - LOGGER.error(SIMPLENET, "Received empty payload"); - return; - } - short discriminator = payload.readUnsignedByte(); - final CodecIndex codecIndex = indicies.get(discriminator); - if (codecIndex == null) { - LOGGER.error(SIMPLENET, "Received invalid discriminator byte {}", discriminator); - return; - } - tryDecode(payload, context, codecIndex); - } - - CodecIndex addCodecIndex(int index, Class messageType, BiConsumer encoder, Function decoder, BiConsumer> messageConsumer) { - return new CodecIndex<>(index, messageType, encoder, decoder, messageConsumer); + MessageHandler addCodecIndex(int index, Class messageType, BiConsumer encoder, Function decoder, BiConsumer> messageConsumer) { + return new MessageHandler<>(index, messageType, encoder, decoder, messageConsumer); } } diff --git a/src/main/java/net/minecraftforge/fml/network/simple/SimpleChannel.java b/src/main/java/net/minecraftforge/fml/network/simple/SimpleChannel.java index 984e8b79f..f413dabe3 100644 --- a/src/main/java/net/minecraftforge/fml/network/simple/SimpleChannel.java +++ b/src/main/java/net/minecraftforge/fml/network/simple/SimpleChannel.java @@ -24,18 +24,14 @@ import net.minecraft.client.Minecraft; import net.minecraft.network.NetworkManager; import net.minecraft.network.Packet; import net.minecraft.network.PacketBuffer; -import net.minecraft.network.play.client.CPacketCustomPayload; import net.minecraftforge.fml.network.ICustomPacket; import net.minecraftforge.fml.network.NetworkDirection; import net.minecraftforge.fml.network.NetworkEvent; import net.minecraftforge.fml.network.NetworkInstance; +import org.apache.commons.lang3.tuple.Pair; import java.util.function.BiConsumer; -import java.util.function.BinaryOperator; import java.util.function.Function; -import java.util.function.IntBinaryOperator; -import java.util.function.IntConsumer; -import java.util.function.IntFunction; import java.util.function.Supplier; public class SimpleChannel @@ -52,42 +48,38 @@ public class SimpleChannel private void networkEventListener(final NetworkEvent networkEvent) { - if (networkEvent instanceof NetworkEvent.ILoginIndex) - { - this.indexedCodec.consume(networkEvent.getPayload(), ((NetworkEvent.ILoginIndex)networkEvent).getIndex(), networkEvent.getSource()); - } - else - { - this.indexedCodec.consume(networkEvent.getPayload(), networkEvent.getSource()); - } + this.indexedCodec.consume(networkEvent.getPayload(), networkEvent.getLoginIndex(), networkEvent.getSource()); } - public void encodeMessage(MSG message, final PacketBuffer target) { - this.indexedCodec.build(message, target); + public int encodeMessage(MSG message, final PacketBuffer target) { + return this.indexedCodec.build(message, target); } - public IndexedMessageCodec.CodecIndex registerMessage(int index, Class messageType, BiConsumer encoder, Function decoder, BiConsumer> messageConsumer) { + public IndexedMessageCodec.MessageHandler registerMessage(int index, Class messageType, BiConsumer encoder, Function decoder, BiConsumer> messageConsumer) { return this.indexedCodec.addCodecIndex(index, messageType, encoder, decoder, messageConsumer); } - private PacketBuffer toBuffer(MSG msg) { + private Pair toBuffer(MSG msg) { final PacketBuffer bufIn = new PacketBuffer(Unpooled.buffer()); - encodeMessage(msg, bufIn); - return bufIn; + int index = encodeMessage(msg, bufIn); + return Pair.of(bufIn, index); } + public void sendToServer(MSG message) { sendTo(message, Minecraft.getMinecraft().getConnection().getNetworkManager(), NetworkDirection.PLAY_TO_SERVER); } - public void sendTo(MSG message, NetworkManager manager, NetworkDirection direction) { - ICustomPacket> payload = direction.buildPacket(toBuffer(message), instance.getChannelName(), -1); + public void sendTo(MSG message, NetworkManager manager, NetworkDirection direction) + { + ICustomPacket> payload = direction.buildPacket(toBuffer(message), instance.getChannelName()); manager.sendPacket(payload.getThis()); } - public void sendLogin(MSG message, NetworkManager manager, NetworkDirection direction, int packetIndex) { - ICustomPacket> payload = direction.buildPacket(toBuffer(message), instance.getChannelName(), packetIndex); - manager.sendPacket(payload.getThis()); + public void reply(MSG msgToReply, NetworkEvent.Context context) + { + sendTo(msgToReply, context.getNetworkManager(), context.getDirection().reply()); } + public MessageBuilder messageBuilder(final Class type, int id) { return MessageBuilder.forType(this, type, id); } @@ -99,7 +91,8 @@ public class SimpleChannel private BiConsumer encoder; private Function decoder; private BiConsumer> consumer; - private BiConsumer loginIndexFunction; + private Function loginIndexGetter; + private BiConsumer loginIndexSetter; private static MessageBuilder forType(final SimpleChannel channel, final Class type, int id) { MessageBuilder builder = new MessageBuilder<>(); @@ -119,8 +112,9 @@ public class SimpleChannel return this; } - public MessageBuilder loginIndex(BiConsumer loginIndexFunction) { - this.loginIndexFunction = loginIndexFunction; + public MessageBuilder loginIndex(Function loginIndexGetter, BiConsumer loginIndexSetter) { + this.loginIndexGetter = loginIndexGetter; + this.loginIndexSetter = loginIndexSetter; return this; } public MessageBuilder consumer(BiConsumer> consumer) { @@ -129,9 +123,12 @@ public class SimpleChannel } public void add() { - final IndexedMessageCodec.CodecIndex message = this.channel.registerMessage(this.id, this.type, this.encoder, this.decoder, this.consumer); - if (this.loginIndexFunction != null) { - message.setLoginIndexFunction(this.loginIndexFunction); + final IndexedMessageCodec.MessageHandler message = this.channel.registerMessage(this.id, this.type, this.encoder, this.decoder, this.consumer); + if (this.loginIndexSetter != null) { + message.setLoginIndexSetter(this.loginIndexSetter); + } + if (this.loginIndexGetter != null) { + message.setLoginIndexGetter(this.loginIndexGetter); } } }