Fix a regression in SslHandler where delegated tasks run in a different executor makes the session hang

- Fixes #2098
- Deprecate specifying an alternative Executor for delegated tasks for SslHandler
This commit is contained in:
Trustin Lee 2014-01-09 18:07:36 +09:00
parent 0b8e732c6c
commit 53110a83b3
3 changed files with 125 additions and 64 deletions

View File

@ -51,8 +51,10 @@ import java.nio.channels.ClosedChannelException;
import java.nio.channels.DatagramChannel; import java.nio.channels.DatagramChannel;
import java.nio.channels.SocketChannel; import java.nio.channels.SocketChannel;
import java.util.ArrayDeque; import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque; import java.util.Deque;
import java.util.List; import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
@ -204,7 +206,7 @@ public class SslHandler extends ByteToMessageDecoder {
* @param engine the {@link SSLEngine} this handler will use * @param engine the {@link SSLEngine} this handler will use
*/ */
public SslHandler(SSLEngine engine) { public SslHandler(SSLEngine engine) {
this(engine, ImmediateExecutor.INSTANCE); this(engine, false);
} }
/** /**
@ -214,35 +216,23 @@ public class SslHandler extends ByteToMessageDecoder {
* @param startTls {@code true} if the first write request shouldn't be * @param startTls {@code true} if the first write request shouldn't be
* encrypted by the {@link SSLEngine} * encrypted by the {@link SSLEngine}
*/ */
@SuppressWarnings("deprecation")
public SslHandler(SSLEngine engine, boolean startTls) { public SslHandler(SSLEngine engine, boolean startTls) {
this(engine, startTls, ImmediateExecutor.INSTANCE); this(engine, startTls, ImmediateExecutor.INSTANCE);
} }
/** /**
* Creates a new instance. * @deprecated Use {@link #SslHandler(SSLEngine)} instead.
*
* @param engine
* the {@link SSLEngine} this handler will use
* @param delegatedTaskExecutor
* the {@link Executor} which will execute the delegated task
* that {@link SSLEngine#getDelegatedTask()} will return
*/ */
@Deprecated
public SslHandler(SSLEngine engine, Executor delegatedTaskExecutor) { public SslHandler(SSLEngine engine, Executor delegatedTaskExecutor) {
this(engine, false, delegatedTaskExecutor); this(engine, false, delegatedTaskExecutor);
} }
/** /**
* Creates a new instance. * @deprecated Use {@link #SslHandler(SSLEngine, boolean)} instead.
*
* @param engine
* the {@link SSLEngine} this handler will use
* @param startTls
* {@code true} if the first write request shouldn't be encrypted
* by the {@link SSLEngine}
* @param delegatedTaskExecutor
* the {@link Executor} which will execute the delegated task
* that {@link SSLEngine#getDelegatedTask()} will return
*/ */
@Deprecated
public SslHandler(SSLEngine engine, boolean startTls, Executor delegatedTaskExecutor) { public SslHandler(SSLEngine engine, boolean startTls, Executor delegatedTaskExecutor) {
if (engine == null) { if (engine == null) {
throw new NullPointerException("engine"); throw new NullPointerException("engine");
@ -947,14 +937,66 @@ public class SslHandler extends ByteToMessageDecoder {
} }
} }
/**
* Fetches all delegated tasks from the {@link SSLEngine} and runs them via the {@link #delegatedTaskExecutor}.
* If the {@link #delegatedTaskExecutor} is {@link ImmediateExecutor}, just call {@link Runnable#run()} directly
* instead of using {@link Executor#execute(Runnable)}. Otherwise, run the tasks via
* the {@link #delegatedTaskExecutor} and wait until the tasks are finished.
*/
private void runDelegatedTasks() { private void runDelegatedTasks() {
if (delegatedTaskExecutor == ImmediateExecutor.INSTANCE) {
for (;;) { for (;;) {
Runnable task = engine.getDelegatedTask(); Runnable task = engine.getDelegatedTask();
if (task == null) { if (task == null) {
break; break;
} }
delegatedTaskExecutor.execute(task); task.run();
}
} else {
final List<Runnable> tasks = new ArrayList<Runnable>(2);
for (;;) {
final Runnable task = engine.getDelegatedTask();
if (task == null) {
break;
}
tasks.add(task);
}
if (tasks.isEmpty()) {
return;
}
final CountDownLatch latch = new CountDownLatch(1);
delegatedTaskExecutor.execute(new Runnable() {
@Override
public void run() {
try {
for (Runnable task: tasks) {
task.run();
}
} catch (Exception e) {
ctx.fireExceptionCaught(e);
} finally {
latch.countDown();
}
}
});
boolean interrupted = false;
while (latch.getCount() != 0) {
try {
latch.await();
} catch (InterruptedException e) {
// Interrupt later.
interrupted = true;
}
}
if (interrupted) {
Thread.currentThread().interrupt();
}
} }
} }

View File

@ -70,8 +70,11 @@ public abstract class AbstractSocketTest {
"Running: %s %d of %d (%s + %s) with %s", "Running: %s %d of %d (%s + %s) with %s",
testName.getMethodName(), ++ i, COMBO.size(), sb, cb, StringUtil.simpleClassName(allocator))); testName.getMethodName(), ++ i, COMBO.size(), sb, cb, StringUtil.simpleClassName(allocator)));
try { try {
Method m = getClass().getDeclaredMethod( String testMethodName = testName.getMethodName();
testName.getMethodName(), ServerBootstrap.class, Bootstrap.class); if (testMethodName.contains("[")) {
testMethodName = testMethodName.substring(0, testMethodName.indexOf('['));
}
Method m = getClass().getDeclaredMethod(testMethodName, ServerBootstrap.class, Bootstrap.class);
m.invoke(this, sb, cb); m.invoke(this, sb, cb);
} catch (InvocationTargetException ex) { } catch (InvocationTargetException ex) {
throw ex.getCause(); throw ex.getCause();

View File

@ -30,15 +30,24 @@ import io.netty.handler.stream.ChunkedWriteHandler;
import io.netty.testsuite.util.BogusSslContextFactory; import io.netty.testsuite.util.BogusSslContextFactory;
import io.netty.util.concurrent.Future; import io.netty.util.concurrent.Future;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngine;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Random; import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import static org.junit.Assert.*; import static org.junit.Assert.*;
@RunWith(Parameterized.class)
public class SocketSslEchoTest extends AbstractSocketTest { public class SocketSslEchoTest extends AbstractSocketTest {
private static final int FIRST_MESSAGE_SIZE = 16384; private static final int FIRST_MESSAGE_SIZE = 16384;
@ -49,46 +58,42 @@ public class SocketSslEchoTest extends AbstractSocketTest {
random.nextBytes(data); random.nextBytes(data);
} }
@Parameters(name = "{index}: " +
"serverUsesDelegatedTaskExecutor = {0}, clientUsesDelegatedTaskExecutor = {1}, " +
"useChunkedWriteHandler = {2}, useCompositeByteBuf = {3}")
public static Collection<Object[]> data() {
List<Object[]> params = new ArrayList<Object[]>();
for (int i = 0; i < 16; i ++) {
params.add(new Object[] {
(i & 8) != 0, (i & 4) != 0, (i & 2) != 0, (i & 1) != 0
});
}
return params;
}
private final boolean serverUsesDelegatedTaskExecutor;
private final boolean clientUsesDelegatedTaskExecutor;
private final boolean useChunkedWriteHandler;
private final boolean useCompositeByteBuf;
public SocketSslEchoTest(
boolean serverUsesDelegatedTaskExecutor, boolean clientUsesDelegatedTaskExecutor,
boolean useChunkedWriteHandler, boolean useCompositeByteBuf) {
this.serverUsesDelegatedTaskExecutor = serverUsesDelegatedTaskExecutor;
this.clientUsesDelegatedTaskExecutor = clientUsesDelegatedTaskExecutor;
this.useChunkedWriteHandler = useChunkedWriteHandler;
this.useCompositeByteBuf = useCompositeByteBuf;
}
@Test @Test
public void testSslEcho() throws Throwable { public void testSslEcho() throws Throwable {
run(); run();
} }
public void testSslEcho(ServerBootstrap sb, Bootstrap cb) throws Throwable { public void testSslEcho(ServerBootstrap sb, Bootstrap cb) throws Throwable {
testSslEcho0(sb, cb, false, false); final ExecutorService delegatedTaskExecutor = Executors.newCachedThreadPool();
} final EchoHandler sh = new EchoHandler(true, useCompositeByteBuf);
final EchoHandler ch = new EchoHandler(false, useCompositeByteBuf);
@Test
public void testSslEchoComposite() throws Throwable {
run();
}
public void testSslEchoComposite(ServerBootstrap sb, Bootstrap cb) throws Throwable {
testSslEcho0(sb, cb, false, true);
}
@Test
public void testSslEchoWithChunkHandler() throws Throwable {
run();
}
public void testSslEchoWithChunkHandler(ServerBootstrap sb, Bootstrap cb) throws Throwable {
testSslEcho0(sb, cb, true, false);
}
@Test
public void testSslEchoWithChunkHandlerComposite() throws Throwable {
run();
}
public void testSslEchoWithChunkHandlerComposite(ServerBootstrap sb, Bootstrap cb) throws Throwable {
testSslEcho0(sb, cb, true, true);
}
private void testSslEcho0(ServerBootstrap sb, Bootstrap cb,
final boolean chunkWriteHandler, final boolean composite) throws Throwable {
final EchoHandler sh = new EchoHandler(true, composite);
final EchoHandler ch = new EchoHandler(false, composite);
final SSLEngine sse = BogusSslContextFactory.getServerContext().createSSLEngine(); final SSLEngine sse = BogusSslContextFactory.getServerContext().createSSLEngine();
final SSLEngine cse = BogusSslContextFactory.getClientContext().createSSLEngine(); final SSLEngine cse = BogusSslContextFactory.getClientContext().createSSLEngine();
@ -97,9 +102,14 @@ public class SocketSslEchoTest extends AbstractSocketTest {
sb.childHandler(new ChannelInitializer<SocketChannel>() { sb.childHandler(new ChannelInitializer<SocketChannel>() {
@Override @Override
@SuppressWarnings("deprecation")
public void initChannel(SocketChannel sch) throws Exception { public void initChannel(SocketChannel sch) throws Exception {
if (serverUsesDelegatedTaskExecutor) {
sch.pipeline().addFirst("ssl", new SslHandler(sse, delegatedTaskExecutor));
} else {
sch.pipeline().addFirst("ssl", new SslHandler(sse)); sch.pipeline().addFirst("ssl", new SslHandler(sse));
if (chunkWriteHandler) { }
if (useChunkedWriteHandler) {
sch.pipeline().addLast(new ChunkedWriteHandler()); sch.pipeline().addLast(new ChunkedWriteHandler());
} }
sch.pipeline().addLast("handler", sh); sch.pipeline().addLast("handler", sh);
@ -108,9 +118,14 @@ public class SocketSslEchoTest extends AbstractSocketTest {
cb.handler(new ChannelInitializer<SocketChannel>() { cb.handler(new ChannelInitializer<SocketChannel>() {
@Override @Override
@SuppressWarnings("deprecation")
public void initChannel(SocketChannel sch) throws Exception { public void initChannel(SocketChannel sch) throws Exception {
if (clientUsesDelegatedTaskExecutor) {
sch.pipeline().addFirst("ssl", new SslHandler(cse, delegatedTaskExecutor));
} else {
sch.pipeline().addFirst("ssl", new SslHandler(cse)); sch.pipeline().addFirst("ssl", new SslHandler(cse));
if (chunkWriteHandler) { }
if (useChunkedWriteHandler) {
sch.pipeline().addLast(new ChunkedWriteHandler()); sch.pipeline().addLast(new ChunkedWriteHandler());
} }
sch.pipeline().addLast("handler", ch); sch.pipeline().addLast("handler", ch);
@ -130,7 +145,7 @@ public class SocketSslEchoTest extends AbstractSocketTest {
for (int i = FIRST_MESSAGE_SIZE; i < data.length;) { for (int i = FIRST_MESSAGE_SIZE; i < data.length;) {
int length = Math.min(random.nextInt(1024 * 64), data.length - i); int length = Math.min(random.nextInt(1024 * 64), data.length - i);
ByteBuf buf = Unpooled.wrappedBuffer(data, i, length); ByteBuf buf = Unpooled.wrappedBuffer(data, i, length);
if (composite) { if (useCompositeByteBuf) {
buf = Unpooled.compositeBuffer().addComponent(buf).writerIndex(buf.writerIndex()); buf = Unpooled.compositeBuffer().addComponent(buf).writerIndex(buf.writerIndex());
} }
ChannelFuture future = cc.writeAndFlush(buf); ChannelFuture future = cc.writeAndFlush(buf);
@ -171,6 +186,7 @@ public class SocketSslEchoTest extends AbstractSocketTest {
sh.channel.close().awaitUninterruptibly(); sh.channel.close().awaitUninterruptibly();
ch.channel.close().awaitUninterruptibly(); ch.channel.close().awaitUninterruptibly();
sc.close().awaitUninterruptibly(); sc.close().awaitUninterruptibly();
delegatedTaskExecutor.shutdown();
if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) { if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
throw sh.exception.get(); throw sh.exception.get();