Refactor Object Decoder/Encoder

Prevent crash on netplay
This commit is contained in:
Anthony Calosa
2020-04-05 07:07:30 +08:00
parent 9f1b93373d
commit c714187eee
10 changed files with 201 additions and 148 deletions

View File

@@ -0,0 +1,55 @@
package forge.net;
import io.netty.handler.codec.serialization.ClassResolver;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectStreamClass;
import java.io.StreamCorruptedException;
public class CObjectInputStream extends ObjectInputStream {
private final ClassResolver classResolver;
CObjectInputStream(InputStream in, ClassResolver classResolver) throws IOException {
super(in);
this.classResolver = classResolver;
}
protected void readStreamHeader() throws IOException {
int version = this.readByte() & 255;
if (version != 5) {
throw new StreamCorruptedException("Unsupported version: " + version);
}
}
protected ObjectStreamClass readClassDescriptor() throws IOException, ClassNotFoundException {
int type = this.read();
if (type < 0) {
throw new EOFException();
} else {
switch(type) {
case 0:
return super.readClassDescriptor();
case 1:
String className = this.readUTF();
Class<?> clazz = this.classResolver.resolve(className);
return ObjectStreamClass.lookupAny(clazz);
default:
throw new StreamCorruptedException("Unexpected class descriptor type: " + type);
}
}
}
protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException {
Class clazz;
try {
clazz = this.classResolver.resolve(desc.getName());
} catch (ClassNotFoundException var4) {
clazz = super.resolveClass(desc);
}
return clazz;
}
}

View File

@@ -0,0 +1,31 @@
package forge.net;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.ObjectStreamClass;
import java.io.OutputStream;
public class CObjectOutputStream extends ObjectOutputStream {
static final int TYPE_FAT_DESCRIPTOR = 0;
static final int TYPE_THIN_DESCRIPTOR = 1;
CObjectOutputStream(OutputStream out) throws IOException {
super(out);
}
protected void writeStreamHeader() throws IOException {
this.writeByte(5);
}
protected void writeClassDescriptor(ObjectStreamClass desc) throws IOException {
Class<?> clazz = desc.forClass();
if (!clazz.isPrimitive() && !clazz.isArray() && !clazz.isInterface() && desc.getSerialVersionUID() != 0L) {
this.write(1);
this.writeUTF(desc.getName());
} else {
this.write(0);
super.writeClassDescriptor(desc);
}
}
}

View File

@@ -0,0 +1,46 @@
package forge.net;
import forge.GuiBase;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufInputStream;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.serialization.ClassResolver;
import java.io.ObjectInputStream;
import java.io.StreamCorruptedException;
public class CompatibleObjectDecoder extends LengthFieldBasedFrameDecoder {
private final ClassResolver classResolver;
public CompatibleObjectDecoder(ClassResolver classResolver) {
this(1048576, classResolver);
}
public CompatibleObjectDecoder(int maxObjectSize, ClassResolver classResolver) {
super(maxObjectSize, 0, 4, 0, 4);
this.classResolver = classResolver;
}
protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
ByteBuf frame = (ByteBuf)super.decode(ctx, in);
if (frame == null) {
return null;
} else {
ObjectInputStream ois = GuiBase.hasPropertyConfig() ?
new ObjectInputStream(new ByteBufInputStream(frame, true)):
new CObjectInputStream(new ByteBufInputStream(frame, true),this.classResolver);
Object var5 = null;
try {
var5 = ois.readObject();
} catch (StreamCorruptedException e) {
e.printStackTrace();
} finally {
ois.close();
}
return var5;
}
}
}

View File

@@ -0,0 +1,37 @@
package forge.net;
import forge.GuiBase;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufOutputStream;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToByteEncoder;
import java.io.ObjectOutputStream;
import java.io.Serializable;
public class CompatibleObjectEncoder extends MessageToByteEncoder<Serializable> {
private static final byte[] LENGTH_PLACEHOLDER = new byte[4];
@Override
protected void encode(ChannelHandlerContext ctx, Serializable msg, ByteBuf out) throws Exception {
int startIdx = out.writerIndex();
ByteBufOutputStream bout = new ByteBufOutputStream(out);
ObjectOutputStream oout = null;
try {
bout.write(LENGTH_PLACEHOLDER);
oout = GuiBase.hasPropertyConfig() ? new ObjectOutputStream(bout) : new CObjectOutputStream(bout);
oout.writeObject(msg);
oout.flush();
} finally {
if (oout != null) {
oout.close();
} else {
bout.close();
}
}
int endIdx = out.writerIndex();
out.setInt(startIdx, endIdx - startIdx - 4);
}
}

View File

@@ -1,58 +0,0 @@
package forge.net;
import forge.GuiBase;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufInputStream;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.serialization.ClassResolver;
import org.mapdb.elsa.ElsaObjectInputStream;
import java.io.ObjectInputStream;
public class CustomObjectDecoder extends LengthFieldBasedFrameDecoder {
private final ClassResolver classResolver;
public CustomObjectDecoder(ClassResolver classResolver) {
this(1048576, classResolver);
}
public CustomObjectDecoder(int maxObjectSize, ClassResolver classResolver) {
super(maxObjectSize, 0, 4, 0, 4);
this.classResolver = classResolver;
}
protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
ByteBuf frame = (ByteBuf) super.decode(ctx, in);
if (frame == null) {
return null;
} else {
if (GuiBase.hasPropertyConfig()){
ElsaObjectInputStream ois = new ElsaObjectInputStream(new ByteBufInputStream(frame, true));
Object var5;
try {
var5 = ois.readObject();
} finally {
ois.close();
}
return var5;
}
else {
ObjectInputStream ois = new ObjectInputStream(new ByteBufInputStream(frame, true));
Object var5;
try {
var5 = ois.readObject();
} finally {
ois.close();
}
return var5;
}
}
}
public static int maxObjectsize = 10000000; //10megabyte???
}

View File

@@ -1,56 +0,0 @@
package forge.net;
import forge.GuiBase;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufOutputStream;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToByteEncoder;
import org.mapdb.elsa.ElsaObjectOutputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
public class CustomObjectEncoder extends MessageToByteEncoder<Serializable> {
private static final byte[] LENGTH_PLACEHOLDER = new byte[4];
public CustomObjectEncoder() {
}
protected void encode(ChannelHandlerContext ctx, Serializable msg, ByteBuf out) throws Exception {
int startIdx = out.writerIndex();
ByteBufOutputStream bout = new ByteBufOutputStream(out);
if (GuiBase.hasPropertyConfig()){
ElsaObjectOutputStream oout = null;
try {
bout.write(LENGTH_PLACEHOLDER);
oout = new ElsaObjectOutputStream(bout);
oout.writeObject(msg);
oout.flush();
} finally {
if (oout != null) {
oout.close();
} else {
bout.close();
}
}
} else {
ObjectOutputStream oout = null;
try {
bout.write(LENGTH_PLACEHOLDER);
oout = new ObjectOutputStream(bout);
oout.writeObject(msg);
oout.flush();
} finally {
if (oout != null) {
oout.close();
} else {
bout.close();
}
}
}
int endIdx = out.writerIndex();
out.setInt(startIdx, endIdx - startIdx - 4);
}
}

View File

@@ -1,8 +1,10 @@
package forge.net; package forge.net;
import forge.FThreads; import forge.FThreads;
import forge.assets.FSkinProp;
import forge.net.event.GuiGameEvent; import forge.net.event.GuiGameEvent;
import forge.net.event.ReplyEvent; import forge.net.event.ReplyEvent;
import forge.util.gui.SOptionPane;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInboundHandlerAdapter;
@@ -25,6 +27,7 @@ public abstract class GameProtocolHandler<T> extends ChannelInboundHandlerAdapte
@Override @Override
public final void channelRead(final ChannelHandlerContext ctx, final Object msg) { public final void channelRead(final ChannelHandlerContext ctx, final Object msg) {
final String[] catchedError = {""};
System.out.println("Received: " + msg); System.out.println("Received: " + msg);
if (msg instanceof ReplyEvent) { if (msg instanceof ReplyEvent) {
final ReplyEvent event = (ReplyEvent) msg; final ReplyEvent event = (ReplyEvent) msg;
@@ -36,7 +39,9 @@ public abstract class GameProtocolHandler<T> extends ChannelInboundHandlerAdapte
final Method method = protocolMethod.getMethod(); final Method method = protocolMethod.getMethod();
if (method == null) { if (method == null) {
throw new IllegalStateException(String.format("Method %s not found", protocolMethod.name())); //throw new IllegalStateException(String.format("Method %s not found", protocolMethod.name()));
catchedError[0] += String.format("IllegalStateException: Method %s not found (GameProtocolHandler.java Line 43)\n", protocolMethod.name());
System.err.println(String.format("Method %s not found", protocolMethod.name()));
} }
final Object[] args = event.getObjects(); final Object[] args = event.getObjects();
@@ -56,7 +61,9 @@ public abstract class GameProtocolHandler<T> extends ChannelInboundHandlerAdapte
} catch (final IllegalAccessException | IllegalArgumentException e) { } catch (final IllegalAccessException | IllegalArgumentException e) {
System.err.println(String.format("Unknown protocol method %s with %d args", methodName, args == null ? 0 : args.length)); System.err.println(String.format("Unknown protocol method %s with %d args", methodName, args == null ? 0 : args.length));
} catch (final InvocationTargetException e) { } catch (final InvocationTargetException e) {
throw new RuntimeException(e.getTargetException()); //throw new RuntimeException(e.getTargetException());
catchedError[0] += (String.format("RuntimeException: %s (GameProtocolHandler.java Line 65)\n", e.getTargetException().toString()));
System.err.println(e.getTargetException().toString());
} }
} else { } else {
Serializable reply = null; Serializable reply = null;
@@ -70,8 +77,11 @@ public abstract class GameProtocolHandler<T> extends ChannelInboundHandlerAdapte
} }
} catch (final IllegalAccessException | IllegalArgumentException e) { } catch (final IllegalAccessException | IllegalArgumentException e) {
System.err.println(String.format("Unknown protocol method %s with %d args, replying with null", methodName, args == null ? 0 : args.length)); System.err.println(String.format("Unknown protocol method %s with %d args, replying with null", methodName, args == null ? 0 : args.length));
} catch (final InvocationTargetException e) { } catch (final NullPointerException | InvocationTargetException e) {
throw new RuntimeException(e.getTargetException()); //throw new RuntimeException(e.getTargetException());
catchedError[0] += e.toString();
SOptionPane.showMessageDialog(catchedError[0], "Error", FSkinProp.ICO_WARNING);
System.err.println(e.toString());
} }
getRemote(ctx).send(new ReplyEvent(event.getId(), reply)); getRemote(ctx).send(new ReplyEvent(event.getId(), reply));
} }

View File

@@ -1,7 +1,6 @@
package forge.net; package forge.net;
import com.google.common.base.Function; import com.google.common.base.Function;
import forge.GuiBase;
import forge.assets.FSkinProp; import forge.assets.FSkinProp;
import forge.deck.CardPool; import forge.deck.CardPool;
import forge.game.GameEntityView; import forge.game.GameEntityView;
@@ -18,12 +17,9 @@ import forge.player.PlayerZoneUpdates;
import forge.trackable.TrackableCollection; import forge.trackable.TrackableCollection;
import forge.util.ITriggerEvent; import forge.util.ITriggerEvent;
import forge.util.ReflectionUtil; import forge.util.ReflectionUtil;
import org.apache.commons.lang3.SerializationUtils;
import java.io.Serializable;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.Collection; import java.util.Collection;
import java.util.ConcurrentModificationException;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@@ -159,24 +155,15 @@ public enum ProtocolMethod {
} }
public void checkArgs(final Object[] args) { public void checkArgs(final Object[] args) {
if (GuiBase.hasPropertyConfig())
return; //uses custom serializer for Android 8+..
for (int iArg = 0; iArg < args.length; iArg++) { for (int iArg = 0; iArg < args.length; iArg++) {
Object arg = null; final Object arg = args[iArg];
Class<?> type = null; final Class<?> type = this.args[iArg];
try {
arg = args[iArg];
if (this.args.length > iArg)
type = this.args[iArg];
}
catch (ArrayIndexOutOfBoundsException ex){ ex.printStackTrace(); }
catch(ConcurrentModificationException ex) { ex.printStackTrace(); }
if (arg != null)
if (type != null)
if (!ReflectionUtil.isInstance(arg, type)) { if (!ReflectionUtil.isInstance(arg, type)) {
throw new InternalError(String.format("Protocol method %s: illegal argument (%d) of type %s, %s expected", name(), iArg, arg.getClass().getName(), type.getName())); //throw new InternalError(String.format("Protocol method %s: illegal argument (%d) of type %s, %s expected", name(), iArg, arg.getClass().getName(), type.getName()));
System.err.println(String.format("InternalError: Protocol method %s: illegal argument (%d) of type %s, %s expected (ProtocolMethod.java Line 163)", name(), iArg, arg.getClass().getName(), type.getName()));
} }
if (arg != null) { //this should be handled via decoder or it will process them twice
/*if (arg != null) {
// attempt to Serialize each argument, this will throw an exception if it can't. // attempt to Serialize each argument, this will throw an exception if it can't.
try { try {
byte[] serialized = SerializationUtils.serialize((Serializable) arg); byte[] serialized = SerializationUtils.serialize((Serializable) arg);
@@ -189,7 +176,7 @@ public enum ProtocolMethod {
// can't seem to avoid this from periodically happening // can't seem to avoid this from periodically happening
ex.printStackTrace(); ex.printStackTrace();
} }
} }*/
} }
} }
@@ -199,7 +186,8 @@ public enum ProtocolMethod {
return; return;
} }
if (!ReflectionUtil.isInstance(value, returnType)) { if (!ReflectionUtil.isInstance(value, returnType)) {
throw new IllegalStateException(String.format("Protocol method %s: illegal return object type %s returned by client, expected %s", name(), value.getClass().getName(), getReturnType().getName())); //throw new IllegalStateException(String.format("Protocol method %s: illegal return object type %s returned by client, expected %s", name(), value.getClass().getName(), getReturnType().getName()));
System.err.println(String.format("IllegalStateException: Protocol method %s: illegal return object type %s returned by client, expected %s (ProtocolMethod.java Line 190)", name(), value.getClass().getName(), getReturnType().getName()));
} }
} }
} }

View File

@@ -1,5 +1,7 @@
package forge.net.client; package forge.net.client;
import forge.net.CompatibleObjectDecoder;
import forge.net.CompatibleObjectEncoder;
import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
@@ -26,8 +28,6 @@ import forge.net.event.IdentifiableNetEvent;
import forge.net.event.LobbyUpdateEvent; import forge.net.event.LobbyUpdateEvent;
import forge.net.event.MessageEvent; import forge.net.event.MessageEvent;
import forge.net.event.NetEvent; import forge.net.event.NetEvent;
import io.netty.handler.codec.serialization.ObjectDecoder;
import io.netty.handler.codec.serialization.ObjectEncoder;
public class FGameClient implements IToServer { public class FGameClient implements IToServer {
@@ -58,8 +58,8 @@ public class FGameClient implements IToServer {
public void initChannel(final SocketChannel ch) throws Exception { public void initChannel(final SocketChannel ch) throws Exception {
final ChannelPipeline pipeline = ch.pipeline(); final ChannelPipeline pipeline = ch.pipeline();
pipeline.addLast( pipeline.addLast(
new ObjectEncoder(), new CompatibleObjectEncoder(),
new ObjectDecoder(9766*1024, ClassResolvers.cacheDisabled(null)), new CompatibleObjectDecoder(9766*1024, ClassResolvers.cacheDisabled(null)),
new MessageHandler(), new MessageHandler(),
new LobbyUpdateHandler(), new LobbyUpdateHandler(),
new GameClientHandler(FGameClient.this)); new GameClientHandler(FGameClient.this));

View File

@@ -6,6 +6,8 @@ import forge.interfaces.IGuiGame;
import forge.interfaces.ILobbyListener; import forge.interfaces.ILobbyListener;
import forge.match.LobbySlot; import forge.match.LobbySlot;
import forge.match.LobbySlotType; import forge.match.LobbySlotType;
import forge.net.CompatibleObjectDecoder;
import forge.net.CompatibleObjectEncoder;
import forge.net.event.LobbyUpdateEvent; import forge.net.event.LobbyUpdateEvent;
import forge.net.event.LoginEvent; import forge.net.event.LoginEvent;
import forge.net.event.LogoutEvent; import forge.net.event.LogoutEvent;
@@ -24,8 +26,6 @@ import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel; import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.serialization.ClassResolvers; import io.netty.handler.codec.serialization.ClassResolvers;
import io.netty.handler.codec.serialization.ObjectDecoder;
import io.netty.handler.codec.serialization.ObjectEncoder;
import io.netty.handler.logging.LogLevel; import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler; import io.netty.handler.logging.LoggingHandler;
@@ -99,8 +99,8 @@ public final class FServerManager {
public final void initChannel(final SocketChannel ch) throws Exception { public final void initChannel(final SocketChannel ch) throws Exception {
final ChannelPipeline p = ch.pipeline(); final ChannelPipeline p = ch.pipeline();
p.addLast( p.addLast(
new ObjectEncoder(), new CompatibleObjectEncoder(),
new ObjectDecoder(9766*1024, ClassResolvers.cacheDisabled(null)), new CompatibleObjectDecoder(9766*1024, ClassResolvers.cacheDisabled(null)),
new MessageHandler(), new MessageHandler(),
new RegisterClientHandler(), new RegisterClientHandler(),
new LobbyInputHandler(), new LobbyInputHandler(),