diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index 7bcc6255fb..2d1634fbeb 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -510,6 +510,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH ByteBuf out = null; ChannelPromise promise = null; ByteBufAllocator alloc = ctx.alloc(); + boolean needUnwrap = false; try { for (;;) { Object msg = pendingUnencryptedWrites.current(); @@ -546,11 +547,12 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH setHandshakeSuccessIfStillHandshaking(); // deliberate fall-through case NEED_WRAP: - finishWrap(ctx, out, promise, inUnwrap); + finishWrap(ctx, out, promise, inUnwrap, false); promise = null; out = null; break; case NEED_UNWRAP: + needUnwrap = true; return; default: throw new IllegalStateException( @@ -562,11 +564,12 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH setHandshakeFailure(ctx, e); throw e; } finally { - finishWrap(ctx, out, promise, inUnwrap); + finishWrap(ctx, out, promise, inUnwrap, needUnwrap); } } - private void finishWrap(ChannelHandlerContext ctx, ByteBuf out, ChannelPromise promise, boolean inUnwrap) { + private void finishWrap(ChannelHandlerContext ctx, ByteBuf out, ChannelPromise promise, boolean inUnwrap, + boolean needUnwrap) { if (out == null) { out = Unpooled.EMPTY_BUFFER; } else if (!out.isReadable()) { @@ -583,6 +586,12 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH if (inUnwrap) { needsFlush = true; } + + if (needUnwrap) { + // The underlying engine is starving so we need to feed it with more data. + // See https://github.com/netty/netty/pull/5039 + readIfNeeded(ctx); + } } private void wrapNonAppData(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException { @@ -917,16 +926,19 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH discardSomeReadBytes(); flushIfNeeded(ctx); + readIfNeeded(ctx); + firedChannelRead = false; + ctx.fireChannelReadComplete(); + } + + private void readIfNeeded(ChannelHandlerContext ctx) { // If handshake is not finished yet, we need more data. if (!ctx.channel().config().isAutoRead() && (!firedChannelRead || !handshakePromise.isDone())) { // No auto-read used and no message passed through the ChannelPipeline or the handhshake was not complete // yet, which means we need to trigger the read to ensure we not encounter any stalls. ctx.read(); } - - firedChannelRead = false; - ctx.fireChannelReadComplete(); } private void flushIfNeeded(ChannelHandlerContext ctx) { @@ -1031,6 +1043,12 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH } if (status == Status.BUFFER_UNDERFLOW || consumed == 0 && produced == 0) { + if (handshakeStatus == HandshakeStatus.NEED_UNWRAP) { + // The underlying engine is starving so we need to feed it with more data. + // See https://github.com/netty/netty/pull/5039 + readIfNeeded(ctx); + } + break; } } diff --git a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java index 151f0a7233..35ec497cf1 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SslHandlerTest.java @@ -17,6 +17,10 @@ package io.netty.handler.ssl; import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelOutboundHandlerAdapter; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.DecoderException; import io.netty.handler.codec.UnsupportedMessageTypeException; @@ -63,4 +67,50 @@ public class SslHandlerTest { ch.writeOutbound(new Object()); } + + private static final class TlsReadTest extends ChannelOutboundHandlerAdapter { + private volatile boolean readIssued; + + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + readIssued = true; + super.read(ctx); + } + + public void test(final boolean dropChannelActive) throws Exception { + SSLEngine engine = SSLContext.getDefault().createSSLEngine(); + engine.setUseClientMode(true); + + EmbeddedChannel ch = new EmbeddedChannel( + this, + new SslHandler(engine), + new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + if (!dropChannelActive) { + ctx.fireChannelActive(); + } + } + } + ); + ch.config().setAutoRead(false); + assertFalse(ch.config().isAutoRead()); + + assertTrue(ch.writeOutbound(Unpooled.EMPTY_BUFFER)); + assertTrue(readIssued); + assertTrue(ch.finishAndReleaseAll()); + } + } + + @Test + public void testIssueReadAfterActiveWriteFlush() throws Exception { + // the handshake is initiated by channelActive + new TlsReadTest().test(false); + } + + @Test + public void testIssueReadAfterWriteFlushActive() throws Exception { + // the handshake is initiated by flush + new TlsReadTest().test(true); + } }