diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index 10301ad58b..33920e3419 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -177,7 +177,7 @@ public class SslHandler extends ByteToMessageDecoder { private final boolean startTls; private boolean sentFirstMessage; - + private boolean flushedBeforeHandshakeDone; private final LazyChannelPromise handshakePromise = new LazyChannelPromise(); private final LazyChannelPromise sslCloseFuture = new LazyChannelPromise(); private final Deque pendingUnencryptedWrites = new ArrayDeque(); @@ -367,6 +367,9 @@ public class SslHandler extends ByteToMessageDecoder { if (pendingUnencryptedWrites.isEmpty()) { pendingUnencryptedWrites.add(PendingWrite.newInstance(Unpooled.EMPTY_BUFFER, null)); } + if (!handshakePromise.isDone()) { + flushedBeforeHandshakeDone = true; + } wrap(ctx, false); ctx.flush(); } @@ -389,7 +392,6 @@ public class SslHandler extends ByteToMessageDecoder { pendingUnencryptedWrites.remove(); continue; } - ByteBuf buf = (ByteBuf) pending.msg(); SSLEngineResult result = wrap(engine, buf, out); @@ -842,7 +844,6 @@ public class SslHandler extends ByteToMessageDecoder { private void unwrapMultiple( ChannelHandlerContext ctx, ByteBuffer packet, int totalLength, int[] recordLengths, int nRecords, List out) throws SSLException { - for (int i = 0; i < nRecords; i ++) { packet.limit(packet.position() + recordLengths[i]); try { @@ -901,6 +902,14 @@ public class SslHandler extends ByteToMessageDecoder { wrapLater = true; continue; } + if (flushedBeforeHandshakeDone) { + // We need to call wrap(...) in case there was a flush done before the handshake completed. + // + // See https://github.com/netty/netty/pull/2437 + flushedBeforeHandshakeDone = false; + wrapLater = true; + } + break; default: throw new IllegalStateException("Unknown handshake status: " + handshakeStatus); diff --git a/testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketSslGreetingTest.java b/testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketSslGreetingTest.java new file mode 100644 index 0000000000..953833e79e --- /dev/null +++ b/testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketSslGreetingTest.java @@ -0,0 +1,158 @@ +/* +* Copyright 2014 The Netty Project +* +* The Netty Project licenses this file to you under the Apache License, +* version 2.0 (the "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at: +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +package io.netty.testsuite.transport.socket; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.logging.LogLevel; +import io.netty.handler.logging.LoggingHandler; +import io.netty.handler.ssl.SslHandler; +import io.netty.testsuite.util.BogusSslContextFactory; +import io.netty.util.ReferenceCountUtil; +import org.junit.Test; + +import javax.net.ssl.SSLEngine; +import java.io.IOException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.Assert.assertEquals; + + +public class SocketSslGreetingTest extends AbstractSocketTest { + + private static final LogLevel LOG_LEVEL = LogLevel.TRACE; + private final ByteBuf greeting = ReferenceCountUtil.releaseLater(Unpooled.buffer().writeByte('a')); + + // Test for https://github.com/netty/netty/pull/2437 + @Test(timeout = 30000) + public void testSslGreeting() throws Throwable { + run(); + } + + public void testSslGreeting(ServerBootstrap sb, Bootstrap cb) throws Throwable { + final ServerHandler sh = new ServerHandler(); + final ClientHandler ch = new ClientHandler(); + + sb.childHandler(new ChannelInitializer() { + @Override + public void initChannel(SocketChannel sch) throws Exception { + final SSLEngine sse = BogusSslContextFactory.getServerContext().createSSLEngine(); + sse.setUseClientMode(false); + + ChannelPipeline p = sch.pipeline(); + p.addLast(new SslHandler(sse)); + p.addLast("logger", new LoggingHandler(LOG_LEVEL)); + p.addLast(sh); + } + }); + + cb.handler(new ChannelInitializer() { + @Override + public void initChannel(SocketChannel sch) throws Exception { + final SSLEngine cse = BogusSslContextFactory.getClientContext().createSSLEngine(); + cse.setUseClientMode(true); + + ChannelPipeline p = sch.pipeline(); + p.addLast(new SslHandler(cse)); + p.addLast("logger", new LoggingHandler(LOG_LEVEL)); + p.addLast(ch); + } + }); + + Channel sc = sb.bind().sync().channel(); + Channel cc = cb.connect().sync().channel(); + + ch.latch.await(); + + sh.channel.close().awaitUninterruptibly(); + cc.close().awaitUninterruptibly(); + sc.close().awaitUninterruptibly(); + + if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) { + throw sh.exception.get(); + } + if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) { + throw ch.exception.get(); + } + if (sh.exception.get() != null) { + throw sh.exception.get(); + } + if (ch.exception.get() != null) { + throw ch.exception.get(); + } + } + + private class ClientHandler extends SimpleChannelInboundHandler { + + final AtomicReference exception = new AtomicReference(); + final CountDownLatch latch = new CountDownLatch(1); + + @Override + public void channelRead0(ChannelHandlerContext ctx, ByteBuf buf) throws Exception { + assertEquals(greeting, buf); + latch.countDown(); + ctx.close(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, + Throwable cause) throws Exception { + if (logger.isWarnEnabled()) { + logger.warn("Unexpected exception from the client side", cause); + } + + exception.compareAndSet(null, cause); + ctx.close(); + } + } + + private class ServerHandler extends SimpleChannelInboundHandler { + volatile Channel channel; + final AtomicReference exception = new AtomicReference(); + + @Override + protected void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception { + // discard + } + + @Override + public void channelActive(ChannelHandlerContext ctx) + throws Exception { + channel = ctx.channel(); + channel.writeAndFlush(greeting.duplicate().retain()); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, + Throwable cause) throws Exception { + if (logger.isWarnEnabled()) { + logger.warn("Unexpected exception from the server side", cause); + } + + exception.compareAndSet(null, cause); + ctx.close(); + } + } +}