ForgePatch/src/main/java/net/minecraftforge/fml/network/simple/IndexedMessageCodec.java

154 lines
6.4 KiB
Java

/*
* Minecraft Forge
* Copyright (c) 2018.
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation version 2.1
* of the License.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
package net.minecraftforge.fml.network.simple;
import it.unimi.dsi.fastutil.objects.Object2ObjectArrayMap;
import it.unimi.dsi.fastutil.shorts.Short2ObjectArrayMap;
import net.minecraft.network.PacketBuffer;
import net.minecraftforge.fml.network.NetworkEvent;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.Marker;
import org.apache.logging.log4j.MarkerManager;
import java.util.Optional;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.Supplier;
/**
 * Registry of {@link MessageHandler}s for a simple network channel.
 *
 * <p>Each handler is stored twice: under the low byte of its registration index
 * (the discriminator written as the first byte of every encoded message) and
 * under the message class it was registered for. {@link #build} uses the class
 * lookup to encode an outgoing message; {@link #consume} uses the discriminator
 * lookup to decode and dispatch an incoming payload.
 */
public class IndexedMessageCodec
{
    private static final Logger LOGGER = LogManager.getLogger();
    private static final Marker SIMPLENET = MarkerManager.getMarker("SIMPLENET");

    /** Handlers keyed by the low byte of their registration index. */
    private final Short2ObjectArrayMap<MessageHandler<?>> indicies = new Short2ObjectArrayMap<>();
    /** Handlers keyed by the message class they were registered for. */
    private final Object2ObjectArrayMap<Class<?>, MessageHandler<?>> types = new Object2ObjectArrayMap<>();

    /**
     * Looks up the handler registered for the runtime class of the given message.
     *
     * @param msgToReply message whose exact class is used as the lookup key
     * @return the handler registered for that class, or {@code null} if there is none
     */
    @SuppressWarnings("unchecked")
    public <MSG> MessageHandler<MSG> findMessageType(final MSG msgToReply) {
        return (MessageHandler<MSG>) types.get(msgToReply.getClass());
    }

    /**
     * Looks up the handler stored under the given discriminator.
     *
     * @param i discriminator key
     * @return the handler stored under that key, or {@code null} if there is none
     */
    @SuppressWarnings("unchecked")
    <MSG> MessageHandler<MSG> findIndex(final short i) {
        return (MessageHandler<MSG>) indicies.get(i);
    }

    /**
     * Encoder, decoder and consumer for a single message type. Constructing a
     * handler registers it with the enclosing codec.
     */
    @SuppressWarnings("OptionalUsedAsFieldOrParameterType")
    class MessageHandler<MSG>
    {
        private final Optional<BiConsumer<MSG, PacketBuffer>> encoder;
        private final Optional<Function<PacketBuffer, MSG>> decoder;
        private final int index;
        private final BiConsumer<MSG,Supplier<NetworkEvent.Context>> messageConsumer;
        private final Class<MSG> messageType;
        private Optional<BiConsumer<MSG, Integer>> loginIndexSetter;
        private Optional<Function<MSG, Integer>> loginIndexGetter;

        /**
         * Creates the handler and registers it in both lookup maps of the
         * enclosing codec.
         *
         * @param index           registration index; only its low 8 bits are used as the
         *                        map key and written when encoding, so an index sharing
         *                        the low byte of an earlier one replaces that entry
         * @param messageType     class of the handled message; an earlier handler for
         *                        the same class is replaced
         * @param encoder         writes a message into a buffer; may be {@code null}
         * @param decoder         reads a message from a buffer; may be {@code null}
         * @param messageConsumer receives each decoded message with its context
         */
        public MessageHandler(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer)
        {
            this.index = index;
            this.messageType = messageType;
            this.encoder = Optional.ofNullable(encoder);
            this.decoder = Optional.ofNullable(decoder);
            this.messageConsumer = messageConsumer;
            this.loginIndexGetter = Optional.empty();
            this.loginIndexSetter = Optional.empty();
            indicies.put((short)(index & 0xff), this);
            types.put(messageType, this);
        }

        /**
         * Sets the function used to store a login index on a decoded message.
         *
         * @param loginIndexSetter setter to install; must not be {@code null}
         */
        void setLoginIndexSetter(BiConsumer<MSG, Integer> loginIndexSetter)
        {
            this.loginIndexSetter = Optional.of(loginIndexSetter);
        }

        /** @return the login index setter, or empty if none has been set */
        Optional<BiConsumer<MSG, Integer>> getLoginIndexSetter() {
            return this.loginIndexSetter;
        }

        /**
         * Sets the function used to read the login index from an outgoing message.
         *
         * @param loginIndexGetter getter to install; must not be {@code null}
         */
        void setLoginIndexGetter(Function<MSG, Integer> loginIndexGetter) {
            this.loginIndexGetter = Optional.of(loginIndexGetter);
        }

        /** @return the login index getter, or empty if none has been set */
        public Optional<Function<MSG, Integer>> getLoginIndexGetter() {
            return this.loginIndexGetter;
        }

        /**
         * Creates a message through the no-argument constructor of the message class.
         *
         * @return a new message instance
         * @throws RuntimeException wrapping the reflective failure if the class has no
         *         usable no-argument constructor or that constructor throws
         */
        MSG newInstance() {
            try {
                // Class.newInstance() is deprecated: it rethrows checked exceptions from
                // the constructor without declaring them. Going through the Constructor
                // wraps them in InvocationTargetException instead.
                return messageType.getDeclaredConstructor().newInstance();
            } catch (ReflectiveOperationException e) {
                LOGGER.error("Invalid login message", e);
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Decodes a message with the handler's decoder and passes it to the handler's
     * consumer. Does nothing if the handler has no decoder or the decoder returns
     * {@code null}.
     *
     * @param payload      buffer positioned after the discriminator byte
     * @param context      context supplier handed to the message consumer
     * @param payloadIndex login index to set on the message, or
     *                     {@link Integer#MIN_VALUE} to skip the login index setter
     * @param codec        handler selected for this payload
     */
    private static <M> void tryDecode(PacketBuffer payload, Supplier<NetworkEvent.Context> context, int payloadIndex, MessageHandler<M> codec)
    {
        codec.decoder.map(d -> d.apply(payload)).ifPresent(msg -> {
            // Only run the loginIndex function for payloadIndexed packets (login)
            if (payloadIndex != Integer.MIN_VALUE)
            {
                codec.getLoginIndexSetter().ifPresent(f -> f.accept(msg, payloadIndex));
            }
            codec.messageConsumer.accept(msg, context);
        });
    }

    /**
     * Writes the discriminator byte followed by the encoded message, if the
     * handler has an encoder; otherwise writes nothing.
     *
     * @param target  buffer to write into
     * @param message message to encode
     * @param codec   handler registered for the message's class
     * @return the value of the handler's login index getter for the message, or
     *         {@link Integer#MIN_VALUE} if the handler has no getter
     */
    private static <M> int tryEncode(PacketBuffer target, M message, MessageHandler<M> codec) {
        codec.encoder.ifPresent(encoder -> {
            target.writeByte(codec.index & 0xff);
            encoder.accept(message, target);
        });
        return codec.loginIndexGetter.orElse(m -> Integer.MIN_VALUE).apply(message);
    }

    /**
     * Encodes a message into the target buffer using the handler registered for
     * the message's runtime class.
     *
     * @param message message to encode
     * @param target  buffer to write into
     * @return the message's login index, or {@link Integer#MIN_VALUE} if its
     *         handler has no login index getter
     * @throws IllegalArgumentException if no handler is registered for the class
     */
    public <MSG> int build(MSG message, PacketBuffer target)
    {
        @SuppressWarnings("unchecked")
        MessageHandler<MSG> messageHandler = (MessageHandler<MSG>)types.get(message.getClass());
        if (messageHandler == null) {
            LOGGER.error(SIMPLENET, "Received invalid message {}", message.getClass().getName());
            throw new IllegalArgumentException("Invalid message "+message.getClass().getName());
        }
        return tryEncode(target, message, messageHandler);
    }

    /**
     * Reads the discriminator byte from the payload, then decodes and dispatches
     * the message through the matching handler. A missing or empty payload and an
     * unknown discriminator are logged and otherwise ignored.
     *
     * @param payload      received buffer; may be {@code null}
     * @param payloadIndex login index for login packets, or {@link Integer#MIN_VALUE}
     * @param context      context supplier handed to the message consumer
     */
    void consume(PacketBuffer payload, int payloadIndex, Supplier<NetworkEvent.Context> context) {
        // A buffer with nothing left to read has no discriminator byte; reading one
        // would throw instead of being reported as an empty payload.
        if (payload == null || !payload.isReadable()) {
            LOGGER.error(SIMPLENET, "Received empty payload");
            return;
        }
        short discriminator = payload.readUnsignedByte();
        final MessageHandler<?> messageHandler = indicies.get(discriminator);
        if (messageHandler == null) {
            LOGGER.error(SIMPLENET, "Received invalid discriminator byte {}", discriminator);
            return;
        }
        tryDecode(payload, context, payloadIndex, messageHandler);
    }

    /**
     * Creates a handler for a message type; constructing it registers it with
     * this codec.
     *
     * @param index           registration index; only its low 8 bits are significant
     * @param messageType     class of the handled message
     * @param encoder         writes a message into a buffer; may be {@code null}
     * @param decoder         reads a message from a buffer; may be {@code null}
     * @param messageConsumer receives each decoded message with its context
     * @return the newly registered handler
     */
    <MSG> MessageHandler<MSG> addCodecIndex(int index, Class<MSG> messageType, BiConsumer<MSG, PacketBuffer> encoder, Function<PacketBuffer, MSG> decoder, BiConsumer<MSG, Supplier<NetworkEvent.Context>> messageConsumer) {
        return new MessageHandler<>(index, messageType, encoder, decoder, messageConsumer);
    }
}