diff --git a/transport/src/main/java/io/netty/channel/DefaultChannelHandlerInvoker.java b/transport/src/main/java/io/netty/channel/DefaultChannelHandlerInvoker.java index 7f45f0bdbd..3d536b1d3e 100644 --- a/transport/src/main/java/io/netty/channel/DefaultChannelHandlerInvoker.java +++ b/transport/src/main/java/io/netty/channel/DefaultChannelHandlerInvoker.java @@ -16,13 +16,23 @@ package io.netty.channel; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufHolder; import io.netty.util.Recycler; import io.netty.util.ReferenceCountUtil; import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.FastThreadLocal; import io.netty.util.internal.OneTimeTask; import io.netty.util.internal.RecyclableMpscLinkedQueueNode; +import java.lang.reflect.Field; +import java.lang.reflect.Modifier; import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.WeakHashMap; import static io.netty.channel.ChannelHandlerInvokerUtil.*; import static io.netty.channel.DefaultChannelPipeline.*; @@ -336,16 +346,7 @@ public class DefaultChannelHandlerInvoker implements ChannelHandlerInvoker { if (executor.inEventLoop()) { invokeWriteNow(ctx, msg, promise); } else { - AbstractChannel channel = (AbstractChannel) ctx.channel(); - int size = channel.estimatorHandle().size(msg); - if (size > 0) { - ChannelOutboundBuffer buffer = channel.unsafe().outboundBuffer(); - // Check for null as it may be set to null if the channel is closed already - if (buffer != null) { - buffer.incrementPendingOutboundBytes(size); - } - } - safeExecuteOutbound(WriteTask.newInstance(ctx, msg, size, promise), promise, msg); + safeExecuteOutbound(WriteTask.newInstance(ctx, msg, promise), promise, msg); } } @@ -401,6 +402,92 @@ public class DefaultChannelHandlerInvoker implements ChannelHandlerInvoker { static final class WriteTask extends RecyclableMpscLinkedQueueNode implements SingleThreadEventLoop.NonWakeupRunnable { + + private static final FastThreadLocal, Integer>> CLASS_SIZES = + new FastThreadLocal, Integer>>() { + @Override + protected Map, Integer> initialValue() throws Exception { + Map, Integer> map = new WeakHashMap, Integer>(); + map.put(void.class, 0); + map.put(byte.class, 1); + map.put(char.class, 2); + map.put(short.class, 2); + map.put(boolean.class, 4); // Probably an integer. + map.put(int.class, 4); + map.put(float.class, 4); + map.put(long.class, 8); + map.put(double.class, 8); + return map; + } + }; + + private static int estimateSize(Object o, Map, Integer> classSizes) { + int answer = 8 + estimateSize(o.getClass(), classSizes, null); + + if (o instanceof ByteBuf) { + answer += ((ByteBuf) o).readableBytes(); + } else if (o instanceof ByteBufHolder) { + answer += ((ByteBufHolder) o).content().readableBytes(); + } else if (o instanceof FileRegion) { + // nothing to add. + } else if (o instanceof byte[]) { + answer += ((byte[]) o).length; + } else if (o instanceof ByteBuffer) { + answer += ((ByteBuffer) o).remaining(); + } else if (o instanceof CharSequence) { + answer += ((CharSequence) o).length() << 1; + } else if (o instanceof Iterable) { + for (Object m : (Iterable) o) { + answer += estimateSize(m, classSizes); + } + } + + return align(answer); + } + + private static int estimateSize(Class clazz, Map, Integer> classSizes, + Set> visitedClasses) { + Integer objectSize = classSizes.get(clazz); + if (objectSize != null) { + return objectSize; + } + + if (visitedClasses != null) { + if (visitedClasses.contains(clazz)) { + return 0; + } + } else { + visitedClasses = new HashSet>(); + } + + visitedClasses.add(clazz); + + int answer = 8; // Basic overhead. + for (Class c = clazz; c != null; c = c.getSuperclass()) { + Field[] fields = c.getDeclaredFields(); + for (Field f : fields) { + if ((f.getModifiers() & Modifier.STATIC) != 0) { + // Ignore static fields. + continue; + } + + answer += estimateSize(f.getType(), classSizes, visitedClasses); + } + } + + visitedClasses.remove(clazz); + + // Some alignment. + answer = align(answer); + + // Put the final answer. + classSizes.put(clazz, answer); + return answer; + } + + private static int align(int size) { + return size + 8 - (size & 7); + } private ChannelHandlerContext ctx; private Object msg; private ChannelPromise promise; @@ -414,12 +501,18 @@ public class DefaultChannelHandlerInvoker implements ChannelHandlerInvoker { }; private static WriteTask newInstance( - ChannelHandlerContext ctx, Object msg, int size, ChannelPromise promise) { + ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { WriteTask task = RECYCLER.get(); task.ctx = ctx; task.msg = msg; task.promise = promise; - task.size = size; + task.size = ((AbstractChannel) ctx.channel()).estimatorHandle().size(msg) + + estimateSize(task, CLASS_SIZES.get()); + ChannelOutboundBuffer buffer = ctx.channel().unsafe().outboundBuffer(); + // Check for null as it may be set to null if the channel is closed already + if (buffer != null) { + buffer.incrementPendingOutboundBytes(task.size); + } return task; } @@ -430,12 +523,10 @@ public class DefaultChannelHandlerInvoker implements ChannelHandlerInvoker { @Override public void run() { try { - if (size > 0) { - ChannelOutboundBuffer buffer = ctx.channel().unsafe().outboundBuffer(); - // Check for null as it may be set to null if the channel is closed already - if (buffer != null) { - buffer.decrementPendingOutboundBytes(size); - } + ChannelOutboundBuffer buffer = ctx.channel().unsafe().outboundBuffer(); + // Check for null as it may be set to null if the channel is closed already + if (buffer != null) { + buffer.decrementPendingOutboundBytes(size); } invokeWriteNow(ctx, msg, promise); } finally { diff --git a/transport/src/main/java/io/netty/channel/DefaultMessageSizeEstimator.java b/transport/src/main/java/io/netty/channel/DefaultMessageSizeEstimator.java index 19193e86e7..1459743259 100644 --- a/transport/src/main/java/io/netty/channel/DefaultMessageSizeEstimator.java +++ b/transport/src/main/java/io/netty/channel/DefaultMessageSizeEstimator.java @@ -47,9 +47,9 @@ public final class DefaultMessageSizeEstimator implements MessageSizeEstimator { } /** - * Return the default implementation which returns {@code -1} for unknown messages. + * Return the default implementation which returns {@code 8} for unknown messages. */ - public static final MessageSizeEstimator DEFAULT = new DefaultMessageSizeEstimator(0); + public static final MessageSizeEstimator DEFAULT = new DefaultMessageSizeEstimator(8); private final Handle handle;