[#504] SslHandler.flush() notifies futures prematurely.

- Make use of ChannelFlushFutureNotifier to notify flush futures
  correctly
- Improve the test case to ensure this commit fixes the bug
This commit is contained in:
Trustin Lee 2012-08-19 17:36:58 +09:00
parent 3f101ad3d1
commit 2bb114bcb7
3 changed files with 48 additions and 10 deletions

View File

@ -19,6 +19,7 @@ import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFlushFutureNotifier;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerAdapter;
@ -161,6 +162,22 @@ public class SslHandler
private volatile ChannelHandlerContext ctx;
private final SSLEngine engine;
private final Executor delegatedTaskExecutor;
private final ChannelFlushFutureNotifier flushFutureNotifier = new ChannelFlushFutureNotifier() {
@Override
public synchronized void increaseWriteCounter(long delta) {
super.increaseWriteCounter(delta);
}
@Override
public synchronized void notifyFlushFutures() {
super.notifyFlushFutures();
}
@Override
public synchronized void notifyFlushFutures(Throwable cause) {
super.notifyFlushFutures(cause);
}
};
private final boolean startTls;
private boolean sentFirstMessage;
@ -330,7 +347,6 @@ public class SslHandler
closeOutboundAndChannel(ctx, future, false);
}
@Override
public void flush(final ChannelHandlerContext ctx, ChannelFuture future) throws Exception {
final ByteBuf in = ctx.outboundByteBuffer();
@ -347,6 +363,8 @@ public class SslHandler
return;
}
flushFutureNotifier.addFlushFuture(future, in.readableBytes());
boolean unwrapLater = false;
int bytesProduced = 0;
try {
@ -399,7 +417,8 @@ public class SslHandler
throw e;
} finally {
in.unsafe().discardSomeReadBytes();
ctx.flush(future);
flushFutureNotifier.increaseWriteCounter(bytesProduced);
ctx.flush(ctx.newFuture().addListener(flushFutureNotifier));
}
}

View File

@ -22,6 +22,7 @@ import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundByteHandlerAdapter;
import io.netty.channel.ChannelInitializer;
@ -38,6 +39,7 @@ import java.security.Security;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Random;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.KeyManagerFactory;
@ -52,6 +54,7 @@ import org.junit.Test;
public class SocketSslEchoTest extends AbstractSocketTest {
private static final int FIRST_MESSAGE_SIZE = 16384;
private static final Random random = new Random();
static final byte[] data = new byte[1048576];
@ -92,9 +95,20 @@ public class SocketSslEchoTest extends AbstractSocketTest {
Channel sc = sb.bind().sync().channel();
Channel cc = cb.connect().sync().channel();
ChannelFuture hf = cc.pipeline().get(SslHandler.class).handshake();
final ChannelFuture firstByteWriteFuture =
cc.write(Unpooled.wrappedBuffer(data, 0, FIRST_MESSAGE_SIZE));
final AtomicBoolean firstByteWriteFutureDone = new AtomicBoolean();
hf.addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) throws Exception {
firstByteWriteFutureDone.set(firstByteWriteFuture.isDone());
}
});
hf.sync();
for (int i = 0; i < data.length;) {
assertFalse(firstByteWriteFutureDone.get());
for (int i = FIRST_MESSAGE_SIZE; i < data.length;) {
int length = Math.min(random.nextInt(1024 * 64), data.length - i);
cc.write(Unpooled.wrappedBuffer(data, i, length));
i += length;

View File

@ -18,13 +18,13 @@ package io.netty.channel;
import java.util.ArrayDeque;
import java.util.Deque;
public final class ChannelFlushFutureNotifier {
public class ChannelFlushFutureNotifier implements ChannelFutureListener {
private long writeCounter;
private final Deque<FlushCheckpoint> flushCheckpoints = new ArrayDeque<FlushCheckpoint>();
public void addFlushFuture(ChannelFuture future, int size) {
long checkpoint = writeCounter + size;
public void addFlushFuture(ChannelFuture future, int pendingDataSize) {
long checkpoint = writeCounter + pendingDataSize;
if (future instanceof FlushCheckpoint) {
FlushCheckpoint cp = (FlushCheckpoint) future;
cp.flushCheckpoint(checkpoint);
@ -34,10 +34,6 @@ public final class ChannelFlushFutureNotifier {
}
}
public long writeCounter() {
return writeCounter;
}
public void increaseWriteCounter(long delta) {
writeCounter += delta;
}
@ -91,6 +87,15 @@ public final class ChannelFlushFutureNotifier {
}
}
@Override
public void operationComplete(ChannelFuture future) throws Exception {
if (future.isSuccess()) {
notifyFlushFutures();
} else {
notifyFlushFutures(future.cause());
}
}
abstract static class FlushCheckpoint {
abstract long flushCheckpoint();
abstract void flushCheckpoint(long checkpoint);