diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index cea6e8d4d9..66d671f3da 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -164,26 +164,26 @@ public class SslHandler extends StreamToStreamCodec { private final SSLEngine engine; private final SslBufferPool bufferPool; private final Executor delegatedTaskExecutor; - + private final boolean startTls; private boolean sentFirstMessage; - + private volatile boolean enableRenegotiation = true; final Object handshakeLock = new Object(); - + private boolean handshaking; private volatile boolean handshaken; private ChannelFuture handshakeFuture; private boolean sentCloseNotify; - + int ignoreClosedChannelException; final Object ignoreClosedChannelExceptionLock = new Object(); private volatile boolean issueHandshake; private final SSLEngineInboundCloseFuture sslEngineCloseFuture = new SSLEngineInboundCloseFuture(); - + private int packetLength = -1; /** @@ -337,36 +337,36 @@ public class SslHandler extends StreamToStreamCodec { engine.beginHandshake(); runDelegatedTasks(); wrapNonAppData(ctx, channel); - + } catch (Exception e) { exception = e; } - + } else { ctx.executor().execute(new Runnable() { - + @Override public void run() { Throwable exception = null; synchronized (handshakeLock) { try { - + engine.beginHandshake(); runDelegatedTasks(); wrapNonAppData(ctx, ctx.channel()); - + } catch (Exception e) { exception = e; - } + } } if (exception != null) { // Failed to initiate handshake. handshakeFuture.setFailure(exception); ctx.fireExceptionCaught(exception); - } + } } }); } - + } } @@ -459,7 +459,7 @@ public class SslHandler extends StreamToStreamCodec { out.writeBytes(in); return; } - + ByteBuffer outNetBuf = bufferPool.acquireBuffer(); boolean success = true; boolean needsUnwrap = false; @@ -467,7 +467,7 @@ public class SslHandler extends StreamToStreamCodec { ByteBuffer outAppBuf = in.nioBuffer(); while(in.readable()) { - + int read; int remaining = outAppBuf.remaining(); SSLEngineResult result = null; @@ -477,14 +477,14 @@ public class SslHandler extends StreamToStreamCodec { } read = remaining - outAppBuf.remaining(); in.readerIndex(in.readerIndex() + read); - - + + if (result.bytesProduced() > 0) { outNetBuf.flip(); out.writeBytes(outNetBuf); outNetBuf.clear(); - + } else if (result.getStatus() == Status.CLOSED) { // SSLEngine has been closed already. @@ -519,7 +519,7 @@ public class SslHandler extends StreamToStreamCodec { throw new IllegalStateException("Unknown handshake status: " + handshakeStatus); } } - + } } catch (SSLException e) { success = false; @@ -527,14 +527,14 @@ public class SslHandler extends StreamToStreamCodec { throw e; } finally { bufferPool.releaseBuffer(outNetBuf); - + if (!success) { - // mark all bytes as read + // mark all bytes as read in.readerIndex(in.readerIndex() + in.readableBytes()); throw new IllegalStateException("SSLEngine already closed"); - - + + } } @@ -608,7 +608,7 @@ public class SslHandler extends StreamToStreamCodec { super.exceptionCaught(ctx, cause); } - + @Override public void decode(ChannelInboundHandlerContext ctx, ChannelBuffer in, ChannelBuffer out) throws Exception { @@ -647,7 +647,7 @@ public class SslHandler extends StreamToStreamCodec { tls = false; } } - + if (!tls) { // SSLv2 or bad data - Check the version boolean sslv2 = true; @@ -677,11 +677,11 @@ public class SslHandler extends StreamToStreamCodec { throw e; } } - + assert packetLength > 0; } - - + + if (in.readableBytes() < packetLength) { // not enough bytes left to read the packet // so return here for now @@ -1073,7 +1073,7 @@ public class SslHandler extends StreamToStreamCodec { @Override public void beforeAdd(ChannelHandlerContext ctx) throws Exception { this.ctx = ctx; - this.handshakeFuture = ctx.newFuture(); + handshakeFuture = ctx.newFuture(); } @@ -1102,7 +1102,7 @@ public class SslHandler extends StreamToStreamCodec { }); } else { super.channelActive(ctx); - } + } } private final class SSLEngineInboundCloseFuture extends DefaultChannelFuture { @@ -1134,6 +1134,6 @@ public class SslHandler extends StreamToStreamCodec { return false; } } - - + + } diff --git a/testsuite/src/test/java/io/netty/testsuite/transport/socket/AbstractSocketSslEchoTest.java b/testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketSslEchoTest.java similarity index 87% rename from testsuite/src/test/java/io/netty/testsuite/transport/socket/AbstractSocketSslEchoTest.java rename to testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketSslEchoTest.java index 7f65e7171c..f11602f082 100644 --- a/testsuite/src/test/java/io/netty/testsuite/transport/socket/AbstractSocketSslEchoTest.java +++ b/testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketSslEchoTest.java @@ -16,22 +16,21 @@ package io.netty.testsuite.transport.socket; import static org.junit.Assert.*; +import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; import io.netty.buffer.ChannelBuffer; import io.netty.buffer.ChannelBuffers; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerContext; +import io.netty.channel.ChannelInboundStreamHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.socket.SocketChannel; import io.netty.handler.ssl.SslHandler; -import io.netty.logging.InternalLogger; -import io.netty.logging.InternalLoggerFactory; -import io.netty.util.SocketAddresses; -import io.netty.util.internal.ExecutorUtil; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; -import java.net.InetSocketAddress; import java.security.InvalidAlgorithmParameterException; import java.security.KeyStore; import java.security.KeyStoreException; @@ -39,9 +38,6 @@ import java.security.Security; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.Random; -import java.util.concurrent.Executor; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicReference; import javax.net.ssl.KeyManagerFactory; @@ -52,83 +48,49 @@ import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactorySpi; import javax.net.ssl.X509TrustManager; -import org.junit.AfterClass; -import org.junit.BeforeClass; import org.junit.Test; -public abstract class AbstractSocketSslEchoTest { - static final InternalLogger logger = - InternalLoggerFactory.getInstance(AbstractSocketSslEchoTest.class); +public class SocketSslEchoTest extends AbstractSocketTest { private static final Random random = new Random(); static final byte[] data = new byte[1048576]; - private static ExecutorService executor; - private static ExecutorService eventExecutor; - static { random.nextBytes(data); } - @BeforeClass - public static void init() { - executor = Executors.newCachedThreadPool(); - eventExecutor = new OrderedMemoryAwareThreadPoolExecutor(16, 0, 0); - } - - @AfterClass - public static void destroy() { - ExecutorUtil.terminate(executor, eventExecutor); - } - - protected abstract ChannelFactory newServerSocketChannelFactory(Executor executor); - protected abstract ChannelFactory newClientSocketChannelFactory(Executor executor); - - protected boolean isExecutorRequired() { - return false; - } - @Test public void testSslEcho() throws Throwable { - ServerBootstrap sb = new ServerBootstrap(newServerSocketChannelFactory(executor)); - ClientBootstrap cb = new ClientBootstrap(newClientSocketChannelFactory(executor)); + run(); + } - EchoHandler sh = new EchoHandler(true); - EchoHandler ch = new EchoHandler(false); + public void testSslEcho(ServerBootstrap sb, Bootstrap cb) throws Throwable { + final EchoHandler sh = new EchoHandler(true); + final EchoHandler ch = new EchoHandler(false); - SSLEngine sse = BogusSslContextFactory.getServerContext().createSSLEngine(); - SSLEngine cse = BogusSslContextFactory.getClientContext().createSSLEngine(); + final SSLEngine sse = BogusSslContextFactory.getServerContext().createSSLEngine(); + final SSLEngine cse = BogusSslContextFactory.getClientContext().createSSLEngine(); sse.setUseClientMode(false); cse.setUseClientMode(true); - // Workaround for blocking I/O transport write-write dead lock. - sb.setOption("receiveBufferSize", 1048576); - sb.setOption("receiveBufferSize", 1048576); - - sb.pipeline().addFirst("ssl", new SslHandler(sse)); - sb.pipeline().addLast("handler", sh); - cb.pipeline().addFirst("ssl", new SslHandler(cse)); - cb.pipeline().addLast("handler", ch); - - if (isExecutorRequired()) { - sb.pipeline().addFirst("executor", new ExecutionHandler(eventExecutor)); - cb.pipeline().addFirst("executor", new ExecutionHandler(eventExecutor)); - } - - Channel sc = sb.bind(new InetSocketAddress(0)); - int port = ((InetSocketAddress) sc.getLocalAddress()).getPort(); - - ChannelFuture ccf = cb.connect(new InetSocketAddress(SocketAddresses.LOCALHOST, port)); - ccf.awaitUninterruptibly(); - if (!ccf.isSuccess()) { - if(logger.isErrorEnabled()) { - logger.error("Connection attempt failed", ccf.cause()); + sb.childHandler(new ChannelInitializer() { + @Override + public void initChannel(SocketChannel sch) throws Exception { + sch.pipeline().addFirst("ssl", new SslHandler(sse)); + sch.pipeline().addLast("handler", sh); } - sc.close().awaitUninterruptibly(); - } - assertTrue(ccf.isSuccess()); + }); - Channel cc = ccf.channel(); + cb.handler(new ChannelInitializer() { + @Override + public void initChannel(SocketChannel sch) throws Exception { + sch.pipeline().addFirst("ssl", new SslHandler(cse)); + sch.pipeline().addLast("handler", ch); + } + }); + + Channel sc = sb.bind().sync().channel(); + Channel cc = cb.connect().sync().channel(); ChannelFuture hf = cc.pipeline().get(SslHandler.class).handshake(); hf.awaitUninterruptibly(); if (!hf.isSuccess()) { @@ -194,7 +156,7 @@ public abstract class AbstractSocketSslEchoTest { } } - private static class EchoHandler extends SimpleChannelUpstreamHandler { + private class EchoHandler extends ChannelInboundStreamHandlerAdapter { volatile Channel channel; final AtomicReference exception = new AtomicReference(); volatile int counter; @@ -205,41 +167,41 @@ public abstract class AbstractSocketSslEchoTest { } @Override - public void channelOpen(ChannelHandlerContext ctx, ChannelStateEvent e) + public void channelActive(ChannelInboundHandlerContext ctx) throws Exception { - channel = e.channel(); + channel = ctx.channel(); } @Override - public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) + public void inboundBufferUpdated( + ChannelInboundHandlerContext ctx, ChannelBuffer in) throws Exception { - ChannelBuffer m = (ChannelBuffer) e.getMessage(); - byte[] actual = new byte[m.readableBytes()]; - m.getBytes(0, actual); + byte[] actual = new byte[in.readableBytes()]; + in.readBytes(actual); int lastIdx = counter; for (int i = 0; i < actual.length; i ++) { assertEquals(data[i + lastIdx], actual[i]); } - if (channel.getParent() != null) { - channel.write(m); + if (channel.parent() != null) { + channel.write(ChannelBuffers.wrappedBuffer(actual)); } counter += actual.length; } @Override - public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) - throws Exception { + public void exceptionCaught(ChannelInboundHandlerContext ctx, + Throwable cause) throws Exception { if (logger.isWarnEnabled()) { logger.warn( "Unexpected exception from the " + - (server? "server" : "client") + " side", e.cause()); + (server? "server" : "client") + " side", cause); } - exception.compareAndSet(null, e.cause()); - e.channel().close(); + exception.compareAndSet(null, cause); + ctx.close(); } }