Read if needed on NEED_UNWRAP

Motivation:

There are some use cases when a client may only be willing to read from a channel once
its previous write is finished (eg: serial dispatchers in Finagle). In this case, a
connection with SslHandler installed and ctx.channel().config().isAutoRead() == false
will stall in 100% of cases no matter what order of "channel active", "write", "flush"
events was.

The use case is following (how Finagle serial dispatchers work):

1. Client writeAndFlushes and waits on a write-promise to perform read() once it's satisfied.
2. A write-promise will only be satisfied once SslHandler finishes with handshaking and
   sends the unencrypted queued message.
3. The handshaking process itself requires a number of read()s done by a client but the
   SslHandler doesn't request them explicitly assuming that either auto-read is enabled
   or client requested at least one read() already.
4. At this point a client will stall with NEED_UNWRAP status returned from underlying engine.

Modifiations:

Always request a read() on NEED_UNWRAP returned from engine if

a) it's handshaking and
b) auto read is disabled and
c) it wasn't requested already.

Result:

SslHandler is now completely tolerant of whether or not auto-read is enabled and client
is explicitly reading a channel.
This commit is contained in:
Vladimir Kostyukov 2016-03-24 20:29:15 -07:00 committed by Norman Maurer
parent f0f014d0c7
commit 84bbbf7e09
2 changed files with 74 additions and 6 deletions

View File

@ -510,6 +510,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
ByteBuf out = null; ByteBuf out = null;
ChannelPromise promise = null; ChannelPromise promise = null;
ByteBufAllocator alloc = ctx.alloc(); ByteBufAllocator alloc = ctx.alloc();
boolean needUnwrap = false;
try { try {
for (;;) { for (;;) {
Object msg = pendingUnencryptedWrites.current(); Object msg = pendingUnencryptedWrites.current();
@ -546,11 +547,12 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
setHandshakeSuccessIfStillHandshaking(); setHandshakeSuccessIfStillHandshaking();
// deliberate fall-through // deliberate fall-through
case NEED_WRAP: case NEED_WRAP:
finishWrap(ctx, out, promise, inUnwrap); finishWrap(ctx, out, promise, inUnwrap, false);
promise = null; promise = null;
out = null; out = null;
break; break;
case NEED_UNWRAP: case NEED_UNWRAP:
needUnwrap = true;
return; return;
default: default:
throw new IllegalStateException( throw new IllegalStateException(
@ -562,11 +564,12 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
setHandshakeFailure(ctx, e); setHandshakeFailure(ctx, e);
throw e; throw e;
} finally { } finally {
finishWrap(ctx, out, promise, inUnwrap); finishWrap(ctx, out, promise, inUnwrap, needUnwrap);
} }
} }
private void finishWrap(ChannelHandlerContext ctx, ByteBuf out, ChannelPromise promise, boolean inUnwrap) { private void finishWrap(ChannelHandlerContext ctx, ByteBuf out, ChannelPromise promise, boolean inUnwrap,
boolean needUnwrap) {
if (out == null) { if (out == null) {
out = Unpooled.EMPTY_BUFFER; out = Unpooled.EMPTY_BUFFER;
} else if (!out.isReadable()) { } else if (!out.isReadable()) {
@ -583,6 +586,12 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
if (inUnwrap) { if (inUnwrap) {
needsFlush = true; needsFlush = true;
} }
if (needUnwrap) {
// The underlying engine is starving so we need to feed it with more data.
// See https://github.com/netty/netty/pull/5039
readIfNeeded(ctx);
}
} }
private void wrapNonAppData(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException { private void wrapNonAppData(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException {
@ -917,16 +926,19 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
discardSomeReadBytes(); discardSomeReadBytes();
flushIfNeeded(ctx); flushIfNeeded(ctx);
readIfNeeded(ctx);
firedChannelRead = false;
ctx.fireChannelReadComplete();
}
private void readIfNeeded(ChannelHandlerContext ctx) {
// If handshake is not finished yet, we need more data. // If handshake is not finished yet, we need more data.
if (!ctx.channel().config().isAutoRead() && (!firedChannelRead || !handshakePromise.isDone())) { if (!ctx.channel().config().isAutoRead() && (!firedChannelRead || !handshakePromise.isDone())) {
// No auto-read used and no message passed through the ChannelPipeline or the handhshake was not complete // No auto-read used and no message passed through the ChannelPipeline or the handhshake was not complete
// yet, which means we need to trigger the read to ensure we not encounter any stalls. // yet, which means we need to trigger the read to ensure we not encounter any stalls.
ctx.read(); ctx.read();
} }
firedChannelRead = false;
ctx.fireChannelReadComplete();
} }
private void flushIfNeeded(ChannelHandlerContext ctx) { private void flushIfNeeded(ChannelHandlerContext ctx) {
@ -1031,6 +1043,12 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
} }
if (status == Status.BUFFER_UNDERFLOW || consumed == 0 && produced == 0) { if (status == Status.BUFFER_UNDERFLOW || consumed == 0 && produced == 0) {
if (handshakeStatus == HandshakeStatus.NEED_UNWRAP) {
// The underlying engine is starving so we need to feed it with more data.
// See https://github.com/netty/netty/pull/5039
readIfNeeded(ctx);
}
break; break;
} }
} }

View File

@ -17,6 +17,10 @@
package io.netty.handler.ssl; package io.netty.handler.ssl;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.DecoderException; import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.UnsupportedMessageTypeException; import io.netty.handler.codec.UnsupportedMessageTypeException;
@ -63,4 +67,50 @@ public class SslHandlerTest {
ch.writeOutbound(new Object()); ch.writeOutbound(new Object());
} }
private static final class TlsReadTest extends ChannelOutboundHandlerAdapter {
private volatile boolean readIssued;
@Override
public void read(ChannelHandlerContext ctx) throws Exception {
readIssued = true;
super.read(ctx);
}
public void test(final boolean dropChannelActive) throws Exception {
SSLEngine engine = SSLContext.getDefault().createSSLEngine();
engine.setUseClientMode(true);
EmbeddedChannel ch = new EmbeddedChannel(
this,
new SslHandler(engine),
new ChannelInboundHandlerAdapter() {
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
if (!dropChannelActive) {
ctx.fireChannelActive();
}
}
}
);
ch.config().setAutoRead(false);
assertFalse(ch.config().isAutoRead());
assertTrue(ch.writeOutbound(Unpooled.EMPTY_BUFFER));
assertTrue(readIssued);
assertTrue(ch.finishAndReleaseAll());
}
}
@Test
public void testIssueReadAfterActiveWriteFlush() throws Exception {
// the handshake is initiated by channelActive
new TlsReadTest().test(false);
}
@Test
public void testIssueReadAfterWriteFlushActive() throws Exception {
// the handshake is initiated by flush
new TlsReadTest().test(true);
}
} }