From 0aea9eaed56968970da87684b7f8e1ca05b69b85 Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Wed, 30 Apr 2014 12:31:41 +0200 Subject: [PATCH] Correctly write pending data after ssl handshake completes. Related to [#2437] Motivation: When writing data from a server before the ssl handshake completes may not be written at all to the remote peer if nothing else is written after the handshake was done. Modification: Correctly try to write pending data after the handshake was complete Result: Correctly write out all pending data --- .../java/io/netty/handler/ssl/SslHandler.java | 15 +- .../socket/SocketSslGreetingTest.java | 158 ++++++++++++++++++ 2 files changed, 170 insertions(+), 3 deletions(-) create mode 100644 testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketSslGreetingTest.java diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index 68d1e94109..a76d35d244 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -183,7 +183,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH private final boolean startTls; private boolean sentFirstMessage; - + private boolean flushedBeforeHandshakeDone; private final LazyChannelPromise handshakePromise = new LazyChannelPromise(); private final LazyChannelPromise sslCloseFuture = new LazyChannelPromise(); private final Deque pendingUnencryptedWrites = new ArrayDeque(); @@ -415,6 +415,9 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH if (pendingUnencryptedWrites.isEmpty()) { pendingUnencryptedWrites.add(PendingWrite.newInstance(Unpooled.EMPTY_BUFFER, null)); } + if (!handshakePromise.isDone()) { + flushedBeforeHandshakeDone = true; + } wrap(ctx, false); ctx.flush(); } @@ -437,7 +440,6 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH pendingUnencryptedWrites.remove(); continue; } - ByteBuf buf = (ByteBuf) pending.msg(); SSLEngineResult result = wrap(engine, buf, out); @@ -890,7 +892,6 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH private void unwrapMultiple( ChannelHandlerContext ctx, ByteBuffer packet, int totalLength, int[] recordLengths, int nRecords, List out) throws SSLException { - for (int i = 0; i < nRecords; i ++) { packet.limit(packet.position() + recordLengths[i]); try { @@ -949,6 +950,14 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH wrapLater = true; continue; } + if (flushedBeforeHandshakeDone) { + // We need to call wrap(...) in case there was a flush done before the handshake completed. + // + // See https://github.com/netty/netty/pull/2437 + flushedBeforeHandshakeDone = false; + wrapLater = true; + } + break; default: throw new IllegalStateException("Unknown handshake status: " + handshakeStatus); diff --git a/testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketSslGreetingTest.java b/testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketSslGreetingTest.java new file mode 100644 index 0000000000..953833e79e --- /dev/null +++ b/testsuite/src/test/java/io/netty/testsuite/transport/socket/SocketSslGreetingTest.java @@ -0,0 +1,158 @@ +/* +* Copyright 2014 The Netty Project +* +* The Netty Project licenses this file to you under the Apache License, +* version 2.0 (the "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at: +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. +*/ +package io.netty.testsuite.transport.socket; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.logging.LogLevel; +import io.netty.handler.logging.LoggingHandler; +import io.netty.handler.ssl.SslHandler; +import io.netty.testsuite.util.BogusSslContextFactory; +import io.netty.util.ReferenceCountUtil; +import org.junit.Test; + +import javax.net.ssl.SSLEngine; +import java.io.IOException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.Assert.assertEquals; + + +public class SocketSslGreetingTest extends AbstractSocketTest { + + private static final LogLevel LOG_LEVEL = LogLevel.TRACE; + private final ByteBuf greeting = ReferenceCountUtil.releaseLater(Unpooled.buffer().writeByte('a')); + + // Test for https://github.com/netty/netty/pull/2437 + @Test(timeout = 30000) + public void testSslGreeting() throws Throwable { + run(); + } + + public void testSslGreeting(ServerBootstrap sb, Bootstrap cb) throws Throwable { + final ServerHandler sh = new ServerHandler(); + final ClientHandler ch = new ClientHandler(); + + sb.childHandler(new ChannelInitializer() { + @Override + public void initChannel(SocketChannel sch) throws Exception { + final SSLEngine sse = BogusSslContextFactory.getServerContext().createSSLEngine(); + sse.setUseClientMode(false); + + ChannelPipeline p = sch.pipeline(); + p.addLast(new SslHandler(sse)); + p.addLast("logger", new LoggingHandler(LOG_LEVEL)); + p.addLast(sh); + } + }); + + cb.handler(new ChannelInitializer() { + @Override + public void initChannel(SocketChannel sch) throws Exception { + final SSLEngine cse = BogusSslContextFactory.getClientContext().createSSLEngine(); + cse.setUseClientMode(true); + + ChannelPipeline p = sch.pipeline(); + p.addLast(new SslHandler(cse)); + p.addLast("logger", new LoggingHandler(LOG_LEVEL)); + p.addLast(ch); + } + }); + + Channel sc = sb.bind().sync().channel(); + Channel cc = cb.connect().sync().channel(); + + ch.latch.await(); + + sh.channel.close().awaitUninterruptibly(); + cc.close().awaitUninterruptibly(); + sc.close().awaitUninterruptibly(); + + if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) { + throw sh.exception.get(); + } + if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) { + throw ch.exception.get(); + } + if (sh.exception.get() != null) { + throw sh.exception.get(); + } + if (ch.exception.get() != null) { + throw ch.exception.get(); + } + } + + private class ClientHandler extends SimpleChannelInboundHandler { + + final AtomicReference exception = new AtomicReference(); + final CountDownLatch latch = new CountDownLatch(1); + + @Override + public void channelRead0(ChannelHandlerContext ctx, ByteBuf buf) throws Exception { + assertEquals(greeting, buf); + latch.countDown(); + ctx.close(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, + Throwable cause) throws Exception { + if (logger.isWarnEnabled()) { + logger.warn("Unexpected exception from the client side", cause); + } + + exception.compareAndSet(null, cause); + ctx.close(); + } + } + + private class ServerHandler extends SimpleChannelInboundHandler { + volatile Channel channel; + final AtomicReference exception = new AtomicReference(); + + @Override + protected void channelRead0(ChannelHandlerContext ctx, String msg) throws Exception { + // discard + } + + @Override + public void channelActive(ChannelHandlerContext ctx) + throws Exception { + channel = ctx.channel(); + channel.writeAndFlush(greeting.duplicate().retain()); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, + Throwable cause) throws Exception { + if (logger.isWarnEnabled()) { + logger.warn("Unexpected exception from the server side", cause); + } + + exception.compareAndSet(null, cause); + ctx.close(); + } + } +}