From 7d80182e512ab01d7348ef0ca71c85073d837a0a Mon Sep 17 00:00:00 2001 From: Trustin Lee Date: Tue, 1 Jan 2013 15:03:37 +0900 Subject: [PATCH] Fix a bug where SslHandler does not respect the startTls flag - Fixes #856 - Add a dedicated test case: SocketStartTlsTest --- .../java/io/netty/handler/ssl/SslHandler.java | 12 +- .../transport/socket/SocketStartTlsTest.java | 226 ++++++++++++++++++ 2 files changed, 234 insertions(+), 4 deletions(-) create mode 100644 testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketStartTlsTest.java diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index 524cf08aea..5afcdb66bc 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -319,7 +319,7 @@ public class SslHandler } engine.beginHandshake(); handshakePromises.add(promise); - flush(ctx, ctx.newPromise()); + flush0(ctx, ctx.newPromise(), true); } catch (Exception e) { if (promise.tryFailure(e)) { ctx.fireExceptionCaught(e); @@ -408,6 +408,10 @@ public class SslHandler @Override public void flush(final ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + flush0(ctx, promise, false); + } + + private void flush0(ChannelHandlerContext ctx, ChannelPromise promise, boolean internal) throws Exception { final ByteBuf in = ctx.outboundByteBuffer(); final ByteBuf out = ctx.nextOutboundByteBuffer(); @@ -415,7 +419,7 @@ public class SslHandler // Do not encrypt the first write request if this handler is // created with startTLS flag turned on. - if (startTls && !sentFirstMessage) { + if (!internal && startTls && !sentFirstMessage) { sentFirstMessage = true; out.writeBytes(in); ctx.flush(promise); @@ -824,7 +828,7 @@ public class SslHandler } if (wrapLater) { - flush(ctx, ctx.newPromise()); + flush0(ctx, ctx.newPromise(), true); } } catch (SSLException e) { setHandshakeFailure(e); @@ -932,7 +936,7 @@ public class SslHandler engine.closeOutbound(); ChannelPromise closeNotifyFuture = ctx.newPromise(); - flush(ctx, closeNotifyFuture); + flush0(ctx, closeNotifyFuture, true); safeClose(ctx, closeNotifyFuture, promise); } diff --git a/testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketStartTlsTest.java b/testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketStartTlsTest.java new file mode 100644 index 0000000000..ce9f045a50 --- /dev/null +++ b/testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketStartTlsTest.java @@ -0,0 +1,226 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.testsuite.transport.socket; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundMessageHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.DefaultEventExecutorGroup; +import io.netty.channel.EventExecutorGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.codec.LineBasedFrameDecoder; +import io.netty.handler.codec.string.StringDecoder; +import io.netty.handler.codec.string.StringEncoder; +import io.netty.handler.logging.ByteLoggingHandler; +import io.netty.handler.logging.LogLevel; +import io.netty.handler.ssl.SslHandler; +import io.netty.testsuite.util.BogusSslContextFactory; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import javax.net.ssl.SSLEngine; +import java.io.IOException; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.Assert.*; + +public class SocketStartTlsTest extends AbstractSocketTest { + + private static final LogLevel LOG_LEVEL = LogLevel.TRACE; + private static EventExecutorGroup executor; + + @BeforeClass + public static void createExecutor() { + executor = new DefaultEventExecutorGroup(2); + } + + @AfterClass + public static void shutdownExecutor() { + executor.shutdown(); + } + + @Test(timeout = 30000) + public void testStartTls() throws Throwable { + run(); + } + + public void testStartTls(ServerBootstrap sb, Bootstrap cb) throws Throwable { + final EventExecutorGroup executor = SocketStartTlsTest.executor; + final SSLEngine sse = BogusSslContextFactory.getServerContext().createSSLEngine(); + final SSLEngine cse = BogusSslContextFactory.getClientContext().createSSLEngine(); + + final StartTlsServerHandler sh = new StartTlsServerHandler(sse); + final StartTlsClientHandler ch = new StartTlsClientHandler(cse); + + sb.childHandler(new ChannelInitializer() { + @Override + public void initChannel(SocketChannel sch) throws Exception { + ChannelPipeline p = sch.pipeline(); + p.addLast("logger", new ByteLoggingHandler(LOG_LEVEL)); + p.addLast(new LineBasedFrameDecoder(64), new StringDecoder(), new StringEncoder()); + p.addLast(executor, sh); + } + }); + + cb.handler(new ChannelInitializer() { + @Override + public void initChannel(SocketChannel sch) throws Exception { + ChannelPipeline p = sch.pipeline(); + p.addLast("logger", new ByteLoggingHandler(LOG_LEVEL)); + p.addLast(new LineBasedFrameDecoder(64), new StringDecoder(), new StringEncoder()); + p.addLast(executor, ch); + } + }); + + Channel sc = sb.bind().sync().channel(); + Channel cc = cb.connect().sync().channel(); + + while (cc.isActive()) { + if (sh.exception.get() != null) { + break; + } + if (ch.exception.get() != null) { + break; + } + + try { + Thread.sleep(1); + } catch (InterruptedException e) { + // Ignore. + } + } + + while (sh.channel.isActive()) { + if (sh.exception.get() != null) { + break; + } + if (ch.exception.get() != null) { + break; + } + + try { + Thread.sleep(1); + } catch (InterruptedException e) { + // Ignore. + } + } + + sh.channel.close().awaitUninterruptibly(); + cc.close().awaitUninterruptibly(); + sc.close().awaitUninterruptibly(); + + if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) { + throw sh.exception.get(); + } + if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) { + throw ch.exception.get(); + } + if (sh.exception.get() != null) { + throw sh.exception.get(); + } + if (ch.exception.get() != null) { + throw ch.exception.get(); + } + } + + private class StartTlsClientHandler extends ChannelInboundMessageHandlerAdapter { + private final SslHandler sslHandler; + private ChannelFuture handshakeFuture; + final AtomicReference exception = new AtomicReference(); + + StartTlsClientHandler(SSLEngine engine) { + engine.setUseClientMode(true); + sslHandler = new SslHandler(engine); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) + throws Exception { + ctx.write("StartTlsRequest\n"); + } + + @Override + protected void messageReceived(final ChannelHandlerContext ctx, String msg) throws Exception { + if ("StartTlsResponse".equals(msg)) { + ctx.pipeline().addAfter("logger", "ssl", sslHandler); + handshakeFuture = sslHandler.handshake(); + ctx.write("EncryptedRequest\n"); + return; + } + + assertEquals("EncryptedResponse", msg); + assertNotNull(handshakeFuture); + assertTrue(handshakeFuture.isSuccess()); + ctx.close(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, + Throwable cause) throws Exception { + if (logger.isWarnEnabled()) { + logger.warn("Unexpected exception from the client side", cause); + } + + exception.compareAndSet(null, cause); + ctx.close(); + } + } + + private class StartTlsServerHandler extends ChannelInboundMessageHandlerAdapter { + private final SslHandler sslHandler; + volatile Channel channel; + final AtomicReference exception = new AtomicReference(); + + StartTlsServerHandler(SSLEngine engine) { + engine.setUseClientMode(false); + sslHandler = new SslHandler(engine, true); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + channel = ctx.channel(); + } + + @Override + protected void messageReceived(final ChannelHandlerContext ctx, String msg) throws Exception { + if ("StartTlsRequest".equals(msg)) { + ctx.pipeline().addAfter("logger", "ssl", sslHandler); + ctx.write("StartTlsResponse\n"); + return; + } + + assertEquals("EncryptedRequest", msg); + ctx.write("EncryptedResponse\n"); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, + Throwable cause) throws Exception { + if (logger.isWarnEnabled()) { + logger.warn("Unexpected exception from the server side", cause); + } + + exception.compareAndSet(null, cause); + ctx.close(); + } + } +}