SslHandler LocalChannel read/unwrap reentry fix (#11156)

Motivation:
SslHandler invokes channel.read() during the handshake process. For some
channel implementations (e.g. LocalChannel) this may result in re-entry
conditions into unwrap. Unwrap currently defers updating the input
buffer indexes until the unwrap method returns to avoid intermediate
updates if not necessary, but this may result in unwrapping the same
contents multiple times which leads to handshake failures [1][2].

[1] ssl3_get_record:decryption failed or bad record mac
[2] ssl3_read_bytes:sslv3 alert bad record mac

Modifications:
- SslHandler#unwrap updates buffer indexes on each iteration so that if
  reentry scenario happens the correct indexes will be visible.

Result:
Fixes https://github.com/netty/netty/issues/11146
This commit is contained in:
Scott Mitchell 2021-04-16 08:21:40 -07:00
parent 42dc696c6c
commit 59867fa0fd
2 changed files with 193 additions and 177 deletions

View File

@ -15,9 +15,9 @@
*/ */
package io.netty.handler.ssl; package io.netty.handler.ssl;
import io.netty.buffer.ByteBufConvertible;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufConvertible;
import io.netty.buffer.ByteBufUtil; import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
@ -189,6 +189,7 @@ public class SslHandler extends ByteToMessageDecoder {
* when {@link ChannelConfig#isAutoRead()} is {@code false}. * when {@link ChannelConfig#isAutoRead()} is {@code false}.
*/ */
private static final int STATE_FIRE_CHANNEL_READ = 1 << 8; private static final int STATE_FIRE_CHANNEL_READ = 1 << 8;
private static final int STATE_UNWRAP_REENTRY = 1 << 9;
/** /**
* <a href="https://tools.ietf.org/html/rfc5246#section-6.2">2^14</a> which is the maximum sized plaintext chunk * <a href="https://tools.ietf.org/html/rfc5246#section-6.2">2^14</a> which is the maximum sized plaintext chunk
@ -199,8 +200,7 @@ public class SslHandler extends ByteToMessageDecoder {
private enum SslEngineType { private enum SslEngineType {
TCNATIVE(true, COMPOSITE_CUMULATOR) { TCNATIVE(true, COMPOSITE_CUMULATOR) {
@Override @Override
SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int readerIndex, int len, ByteBuf out) SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int len, ByteBuf out) throws SSLException {
throws SSLException {
int nioBufferCount = in.nioBufferCount(); int nioBufferCount = in.nioBufferCount();
int writerIndex = out.writerIndex(); int writerIndex = out.writerIndex();
final SSLEngineResult result; final SSLEngineResult result;
@ -212,14 +212,13 @@ public class SslHandler extends ByteToMessageDecoder {
*/ */
ReferenceCountedOpenSslEngine opensslEngine = (ReferenceCountedOpenSslEngine) handler.engine; ReferenceCountedOpenSslEngine opensslEngine = (ReferenceCountedOpenSslEngine) handler.engine;
try { try {
handler.singleBuffer[0] = toByteBuffer(out, writerIndex, handler.singleBuffer[0] = toByteBuffer(out, writerIndex, out.writableBytes());
out.writableBytes()); result = opensslEngine.unwrap(in.nioBuffers(in.readerIndex(), len), handler.singleBuffer);
result = opensslEngine.unwrap(in.nioBuffers(readerIndex, len), handler.singleBuffer);
} finally { } finally {
handler.singleBuffer[0] = null; handler.singleBuffer[0] = null;
} }
} else { } else {
result = handler.engine.unwrap(toByteBuffer(in, readerIndex, len), result = handler.engine.unwrap(toByteBuffer(in, in.readerIndex(), len),
toByteBuffer(out, writerIndex, out.writableBytes())); toByteBuffer(out, writerIndex, out.writableBytes()));
} }
out.writerIndex(writerIndex + result.bytesProduced()); out.writerIndex(writerIndex + result.bytesProduced());
@ -246,8 +245,7 @@ public class SslHandler extends ByteToMessageDecoder {
}, },
CONSCRYPT(true, COMPOSITE_CUMULATOR) { CONSCRYPT(true, COMPOSITE_CUMULATOR) {
@Override @Override
SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int readerIndex, int len, ByteBuf out) SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int len, ByteBuf out) throws SSLException {
throws SSLException {
int nioBufferCount = in.nioBufferCount(); int nioBufferCount = in.nioBufferCount();
int writerIndex = out.writerIndex(); int writerIndex = out.writerIndex();
final SSLEngineResult result; final SSLEngineResult result;
@ -258,13 +256,13 @@ public class SslHandler extends ByteToMessageDecoder {
try { try {
handler.singleBuffer[0] = toByteBuffer(out, writerIndex, out.writableBytes()); handler.singleBuffer[0] = toByteBuffer(out, writerIndex, out.writableBytes());
result = ((ConscryptAlpnSslEngine) handler.engine).unwrap( result = ((ConscryptAlpnSslEngine) handler.engine).unwrap(
in.nioBuffers(readerIndex, len), in.nioBuffers(in.readerIndex(), len),
handler.singleBuffer); handler.singleBuffer);
} finally { } finally {
handler.singleBuffer[0] = null; handler.singleBuffer[0] = null;
} }
} else { } else {
result = handler.engine.unwrap(toByteBuffer(in, readerIndex, len), result = handler.engine.unwrap(toByteBuffer(in, in.readerIndex(), len),
toByteBuffer(out, writerIndex, out.writableBytes())); toByteBuffer(out, writerIndex, out.writableBytes()));
} }
out.writerIndex(writerIndex + result.bytesProduced()); out.writerIndex(writerIndex + result.bytesProduced());
@ -290,10 +288,9 @@ public class SslHandler extends ByteToMessageDecoder {
}, },
JDK(false, MERGE_CUMULATOR) { JDK(false, MERGE_CUMULATOR) {
@Override @Override
SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int readerIndex, int len, ByteBuf out) SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int len, ByteBuf out) throws SSLException {
throws SSLException {
int writerIndex = out.writerIndex(); int writerIndex = out.writerIndex();
ByteBuffer inNioBuffer = toByteBuffer(in, readerIndex, len); ByteBuffer inNioBuffer = toByteBuffer(in, in.readerIndex(), len);
int position = inNioBuffer.position(); int position = inNioBuffer.position();
final SSLEngineResult result = handler.engine.unwrap(inNioBuffer, final SSLEngineResult result = handler.engine.unwrap(inNioBuffer,
toByteBuffer(out, writerIndex, out.writableBytes())); toByteBuffer(out, writerIndex, out.writableBytes()));
@ -349,8 +346,7 @@ public class SslHandler extends ByteToMessageDecoder {
this.cumulator = cumulator; this.cumulator = cumulator;
} }
abstract SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int readerIndex, int len, ByteBuf out) abstract SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int len, ByteBuf out) throws SSLException;
throws SSLException;
abstract int calculatePendingData(SslHandler handler, int guess); abstract int calculatePendingData(SslHandler handler, int guess);
@ -838,7 +834,6 @@ public class SslHandler extends ByteToMessageDecoder {
} }
SSLEngineResult result = wrap(alloc, engine, buf, out); SSLEngineResult result = wrap(alloc, engine, buf, out);
if (buf.isReadable()) { if (buf.isReadable()) {
pendingUnencryptedWrites.addFirst(buf, promise); pendingUnencryptedWrites.addFirst(buf, promise);
// When we add the buffer/promise pair back we need to be sure we don't complete the promise // When we add the buffer/promise pair back we need to be sure we don't complete the promise
@ -885,11 +880,9 @@ public class SslHandler extends ByteToMessageDecoder {
} }
break; break;
case FINISHED: case FINISHED:
case NOT_HANDSHAKING: // work around for android bug that skips the FINISHED state.
setHandshakeSuccess(); setHandshakeSuccess();
break; break;
case NOT_HANDSHAKING:
setHandshakeSuccessIfStillHandshaking();
break;
case NEED_WRAP: case NEED_WRAP:
// If we are expected to wrap again and we produced some data we need to ensure there // If we are expected to wrap again and we produced some data we need to ensure there
// is something in the queue to process as otherwise we will not try again before there // is something in the queue to process as otherwise we will not try again before there
@ -956,12 +949,11 @@ public class SslHandler extends ByteToMessageDecoder {
HandshakeStatus status = result.getHandshakeStatus(); HandshakeStatus status = result.getHandshakeStatus();
switch (status) { switch (status) {
case FINISHED: case FINISHED:
setHandshakeSuccess();
// We may be here because we read data and discovered the remote peer initiated a renegotiation // We may be here because we read data and discovered the remote peer initiated a renegotiation
// and this write is to complete the new handshake. The user may have previously done a // and this write is to complete the new handshake. The user may have previously done a
// writeAndFlush which wasn't able to wrap data due to needing the pending handshake, so we // writeAndFlush which wasn't able to wrap data due to needing the pending handshake, so we
// attempt to wrap application data here if any is pending. // attempt to wrap application data here if any is pending.
if (inUnwrap && !pendingUnencryptedWrites.isEmpty()) { if (setHandshakeSuccess() && inUnwrap && !pendingUnencryptedWrites.isEmpty()) {
wrap(ctx, true); wrap(ctx, true);
} }
return false; return false;
@ -973,19 +965,19 @@ public class SslHandler extends ByteToMessageDecoder {
} }
break; break;
case NEED_UNWRAP: case NEED_UNWRAP:
if (inUnwrap) { if (inUnwrap || unwrapNonAppData(ctx) <= 0) {
// If we asked for a wrap, the engine requested an unwrap, and we are in unwrap there is // If we asked for a wrap, the engine requested an unwrap, and we are in unwrap there is
// no use in trying to call wrap again because we have already attempted (or will after we // no use in trying to call wrap again because we have already attempted (or will after we
// return) to feed more data to the engine. // return) to feed more data to the engine.
return false; return false;
} }
unwrapNonAppData(ctx);
break; break;
case NEED_WRAP: case NEED_WRAP:
break; break;
case NOT_HANDSHAKING: case NOT_HANDSHAKING:
setHandshakeSuccessIfStillHandshaking(); if (setHandshakeSuccess() && inUnwrap && !pendingUnencryptedWrites.isEmpty()) {
wrap(ctx, true);
}
// Workaround for TLS False Start problem reported at: // Workaround for TLS False Start problem reported at:
// https://github.com/netty/netty/issues/1108#issuecomment-14266970 // https://github.com/netty/netty/issues/1108#issuecomment-14266970
if (!inUnwrap) { if (!inUnwrap) {
@ -1251,11 +1243,10 @@ public class SslHandler extends ByteToMessageDecoder {
// be consumed by the SSLEngine. // be consumed by the SSLEngine.
this.packetLength = 0; this.packetLength = 0;
try { try {
int bytesConsumed = unwrap(ctx, in, in.readerIndex(), packetLength); final int bytesConsumed = unwrap(ctx, in, packetLength);
assert bytesConsumed == packetLength || engine.isInboundDone() : assert bytesConsumed == packetLength || engine.isInboundDone() :
"we feed the SSLEngine a packets worth of data: " + packetLength + " but it only consumed: " + "we feed the SSLEngine a packets worth of data: " + packetLength + " but it only consumed: " +
bytesConsumed; bytesConsumed;
in.skipBytes(bytesConsumed);
} catch (Throwable cause) { } catch (Throwable cause) {
handleUnwrapThrowable(ctx, cause); handleUnwrapThrowable(ctx, cause);
} }
@ -1263,7 +1254,7 @@ public class SslHandler extends ByteToMessageDecoder {
private void decodeNonJdkCompatible(ChannelHandlerContext ctx, ByteBuf in) { private void decodeNonJdkCompatible(ChannelHandlerContext ctx, ByteBuf in) {
try { try {
in.skipBytes(unwrap(ctx, in, in.readerIndex(), in.readableBytes())); unwrap(ctx, in, in.readableBytes());
} catch (Throwable cause) { } catch (Throwable cause) {
handleUnwrapThrowable(ctx, cause); handleUnwrapThrowable(ctx, cause);
} }
@ -1339,139 +1330,97 @@ public class SslHandler extends ByteToMessageDecoder {
/** /**
* Calls {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} with an empty buffer to handle handshakes, etc. * Calls {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} with an empty buffer to handle handshakes, etc.
*/ */
private void unwrapNonAppData(ChannelHandlerContext ctx) throws SSLException { private int unwrapNonAppData(ChannelHandlerContext ctx) throws SSLException {
unwrap(ctx, Unpooled.EMPTY_BUFFER, 0, 0); return unwrap(ctx, Unpooled.EMPTY_BUFFER, 0);
} }
/** /**
* Unwraps inbound SSL records. * Unwraps inbound SSL records.
*/ */
private int unwrap( private int unwrap(ChannelHandlerContext ctx, ByteBuf packet, int length) throws SSLException {
ChannelHandlerContext ctx, ByteBuf packet, int offset, int length) throws SSLException {
final int originalLength = length; final int originalLength = length;
boolean wrapLater = false; boolean wrapLater = false;
boolean notifyClosure = false; boolean notifyClosure = false;
int overflowReadableBytes = -1; boolean executedRead = false;
ByteBuf decodeOut = allocate(ctx, length); ByteBuf decodeOut = allocate(ctx, length);
try { try {
// Only continue to loop if the handler was not removed in the meantime. // Only continue to loop if the handler was not removed in the meantime.
// See https://github.com/netty/netty/issues/5860 // See https://github.com/netty/netty/issues/5860
unwrapLoop: while (!ctx.isRemoved()) { do {
final SSLEngineResult result = engineType.unwrap(this, packet, offset, length, decodeOut); final SSLEngineResult result = engineType.unwrap(this, packet, length, decodeOut);
final Status status = result.getStatus(); final Status status = result.getStatus();
final HandshakeStatus handshakeStatus = result.getHandshakeStatus(); final HandshakeStatus handshakeStatus = result.getHandshakeStatus();
final int produced = result.bytesProduced(); final int produced = result.bytesProduced();
final int consumed = result.bytesConsumed(); final int consumed = result.bytesConsumed();
// Update indexes for the next iteration // Skip bytes now in case unwrap is called in a re-entry scenario. For example LocalChannel.read()
offset += consumed; // may entry this method in a re-entry fashion and if the peer is writing into a shared buffer we may
// unwrap the same data multiple times.
packet.skipBytes(consumed);
length -= consumed; length -= consumed;
switch (status) { // The expected sequence of events is:
case BUFFER_OVERFLOW: // 1. Notify of handshake success
final int readableBytes = decodeOut.readableBytes(); // 2. fireChannelRead for unwrapped data
final int previousOverflowReadableBytes = overflowReadableBytes; if (handshakeStatus == HandshakeStatus.FINISHED || handshakeStatus == HandshakeStatus.NOT_HANDSHAKING) {
overflowReadableBytes = readableBytes; wrapLater |= (decodeOut.isReadable() ?
int bufferSize = engine.getSession().getApplicationBufferSize() - readableBytes; setHandshakeSuccessUnwrapMarkReentry() : setHandshakeSuccess()) ||
if (readableBytes > 0) { handshakeStatus == HandshakeStatus.FINISHED;
}
// Dispatch decoded data after we have notified of handshake success. If this method has been invoked
// in a re-entry fashion we execute a task on the executor queue to process after the stack unwinds
// to preserve order of events.
if (decodeOut.isReadable()) {
setState(STATE_FIRE_CHANNEL_READ); setState(STATE_FIRE_CHANNEL_READ);
ctx.fireChannelRead(decodeOut); if (isStateSet(STATE_UNWRAP_REENTRY)) {
executedRead = true;
// This buffer was handled, null it out. executeChannelRead(ctx, decodeOut);
decodeOut = null;
if (bufferSize <= 0) {
// It may happen that readableBytes >= engine.getSession().getApplicationBufferSize()
// while there is still more to unwrap, in this case we will just allocate a new buffer
// with the capacity of engine.getSession().getApplicationBufferSize() and call unwrap
// again.
bufferSize = engine.getSession().getApplicationBufferSize();
}
} else { } else {
// This buffer was handled, null it out. ctx.fireChannelRead(decodeOut);
decodeOut.release(); }
decodeOut = null; decodeOut = null;
} }
if (readableBytes == 0 && previousOverflowReadableBytes == 0) {
// If there is two consecutive loops where we overflow and are not able to consume any data, if (status == Status.CLOSED) {
// assume the amount of data exceeds the maximum amount for the engine and bail notifyClosure = true; // notify about the CLOSED state of the SSLEngine. See #137
throw new IllegalStateException("Two consecutive overflows but no content was consumed. " + } else if (status == Status.BUFFER_OVERFLOW) {
SSLSession.class.getSimpleName() + " getApplicationBufferSize: " + if (decodeOut != null) {
engine.getSession().getApplicationBufferSize() + " maybe too small."); decodeOut.release();
} }
final int applicationBufferSize = engine.getSession().getApplicationBufferSize();
// Allocate a new buffer which can hold all the rest data and loop again. // Allocate a new buffer which can hold all the rest data and loop again.
// TODO: We may want to reconsider how we calculate the length here as we may // It may happen that applicationBufferSize < produced while there is still more to unwrap, in this
// have more then one ssl message to decode. // case we will just allocate a new buffer with the capacity of applicationBufferSize and call
decodeOut = allocate(ctx, engineType.calculatePendingData(this, bufferSize)); // unwrap again.
decodeOut = allocate(ctx, engineType.calculatePendingData(this, applicationBufferSize < produced ?
applicationBufferSize : applicationBufferSize - produced));
continue; continue;
case CLOSED:
// notify about the CLOSED state of the SSLEngine. See #137
notifyClosure = true;
overflowReadableBytes = -1;
break;
default:
overflowReadableBytes = -1;
break;
} }
switch (handshakeStatus) { if (handshakeStatus == HandshakeStatus.NEED_TASK) {
case NEED_UNWRAP:
break;
case NEED_WRAP:
// If the wrap operation transitions the status to NOT_HANDSHAKING and there is no more data to
// unwrap then the next call to unwrap will not produce any data. We can avoid the potentially
// costly unwrap operation and break out of the loop.
if (wrapNonAppData(ctx, true) && length == 0) {
break unwrapLoop;
}
break;
case NEED_TASK:
if (!runDelegatedTasks(true)) { if (!runDelegatedTasks(true)) {
// We scheduled a task on the delegatingTaskExecutor, so stop processing as we will // We scheduled a task on the delegatingTaskExecutor, so stop processing as we will
// resume once the task completes. // resume once the task completes.
// //
// We break out of the loop only and do NOT return here as we still may need to notify // We break out of the loop only and do NOT return here as we still may need to notify
// about the closure of the SSLEngine. // about the closure of the SSLEngine.
//
wrapLater = false; wrapLater = false;
break unwrapLoop;
}
break; break;
case FINISHED:
setHandshakeSuccess();
wrapLater = true;
// We 'break' here and NOT 'continue' as android API version 21 has a bug where they consume
// data from the buffer but NOT correctly set the SSLEngineResult.bytesConsumed().
// Because of this it will raise an exception on the next iteration of the for loop on android
// API version 21. Just doing a break will work here as produced and consumed will both be 0
// and so we break out of the complete for (;;) loop and so call decode(...) again later on.
// On other platforms this will have no negative effect as we will just continue with the
// for (;;) loop if something was either consumed or produced.
//
// See:
// - https://github.com/netty/netty/issues/4116
// - https://code.google.com/p/android/issues/detail?id=198639&thanks=198639&ts=1452501203
break;
case NOT_HANDSHAKING:
if (setHandshakeSuccessIfStillHandshaking()) {
wrapLater = true;
continue;
}
// If we are not handshaking and there is no more data to unwrap then the next call to unwrap
// will not produce any data. We can avoid the potentially costly unwrap operation and break
// out of the loop.
if (length == 0) {
break unwrapLoop;
} }
} else if (handshakeStatus == HandshakeStatus.NEED_WRAP) {
// If the wrap operation transitions the status to NOT_HANDSHAKING and there is no more data to
// unwrap then the next call to unwrap will not produce any data. We can avoid the potentially
// costly unwrap operation and break out of the loop.
if (wrapNonAppData(ctx, true) && length == 0) {
break; break;
default: }
throw new IllegalStateException("unknown handshake status: " + handshakeStatus);
} }
if (status == Status.BUFFER_UNDERFLOW || if (status == Status.BUFFER_UNDERFLOW ||
// If we processed NEED_TASK we should try again even we did not consume or produce anything. // If we processed NEED_TASK we should try again even we did not consume or produce anything.
handshakeStatus != HandshakeStatus.NEED_TASK && consumed == 0 && produced == 0) { handshakeStatus != HandshakeStatus.NEED_TASK && (consumed == 0 && produced == 0 ||
(length == 0 && handshakeStatus == HandshakeStatus.NOT_HANDSHAKING))) {
if (handshakeStatus == HandshakeStatus.NEED_UNWRAP) { if (handshakeStatus == HandshakeStatus.NEED_UNWRAP) {
// The underlying engine is starving so we need to feed it with more data. // The underlying engine is starving so we need to feed it with more data.
// See https://github.com/netty/netty/pull/5039 // See https://github.com/netty/netty/pull/5039
@ -1479,8 +1428,10 @@ public class SslHandler extends ByteToMessageDecoder {
} }
break; break;
} else if (decodeOut == null) {
decodeOut = allocate(ctx, length);
} }
} } while (!ctx.isRemoved());
if (isStateSet(STATE_FLUSHED_BEFORE_HANDSHAKE) && handshakePromise.isDone()) { if (isStateSet(STATE_FLUSHED_BEFORE_HANDSHAKE) && handshakePromise.isDone()) {
// We need to call wrap(...) in case there was a flush done before the handshake completed to ensure // We need to call wrap(...) in case there was a flush done before the handshake completed to ensure
@ -1496,22 +1447,55 @@ public class SslHandler extends ByteToMessageDecoder {
} }
} finally { } finally {
if (decodeOut != null) { if (decodeOut != null) {
if (decodeOut.isReadable()) {
setState(STATE_FIRE_CHANNEL_READ);
ctx.fireChannelRead(decodeOut);
} else {
decodeOut.release(); decodeOut.release();
} }
}
if (notifyClosure) { if (notifyClosure) {
if (executedRead) {
executeNotifyClosePromise(ctx);
} else {
notifyClosePromise(null); notifyClosePromise(null);
} }
} }
}
return originalLength - length; return originalLength - length;
} }
private boolean setHandshakeSuccessUnwrapMarkReentry() {
// setHandshakeSuccess calls out to external methods which may trigger re-entry. We need to preserve ordering of
// fireChannelRead for decodeOut relative to re-entry data.
final boolean setReentryState = !isStateSet(STATE_UNWRAP_REENTRY);
if (setReentryState) {
setState(STATE_UNWRAP_REENTRY);
}
try {
return setHandshakeSuccess();
} finally {
// It is unlikely this specific method will be re-entry because handshake completion is infrequent, but just
// in case we only clear the state if we set it in the first place.
if (setReentryState) {
clearState(STATE_UNWRAP_REENTRY);
}
}
}
private void executeNotifyClosePromise(final ChannelHandlerContext ctx) {
try {
ctx.executor().execute(() -> notifyClosePromise(null));
} catch (RejectedExecutionException e) {
notifyClosePromise(e);
}
}
private void executeChannelRead(final ChannelHandlerContext ctx, final ByteBuf decodedOut) {
try {
ctx.executor().execute(() -> ctx.fireChannelRead(decodedOut));
} catch (RejectedExecutionException e) {
decodedOut.release();
throw e;
}
}
private static ByteBuffer toByteBuffer(ByteBuf out, int index, int len) { private static ByteBuffer toByteBuffer(ByteBuf out, int index, int len) {
return out.nioBufferCount() == 1 ? out.internalNioBuffer(index, len) : return out.nioBufferCount() == 1 ? out.internalNioBuffer(index, len) :
out.nioBuffer(index, len); out.nioBuffer(index, len);
@ -1638,13 +1622,9 @@ public class SslHandler extends ByteToMessageDecoder {
// The handshake finished, lets notify about the completion of it and resume processing. // The handshake finished, lets notify about the completion of it and resume processing.
case FINISHED: case FINISHED:
setHandshakeSuccess();
// deliberate fall-through
// Not handshaking anymore, lets notify about the completion if not done yet and resume processing. // Not handshaking anymore, lets notify about the completion if not done yet and resume processing.
case NOT_HANDSHAKING: case NOT_HANDSHAKING:
setHandshakeSuccessIfStillHandshaking(); setHandshakeSuccess(); // NOT_HANDSHAKING -> workaround for android skipping FINISHED state.
try { try {
// Lets call wrap to ensure we produce the alert if there is any pending and also to // Lets call wrap to ensure we produce the alert if there is any pending and also to
// ensure we flush any queued data.. // ensure we flush any queued data..
@ -1744,32 +1724,17 @@ public class SslHandler extends ByteToMessageDecoder {
} }
} }
/**
* Works around some Android {@link SSLEngine} implementations that skip {@link HandshakeStatus#FINISHED} and
* go straight into {@link HandshakeStatus#NOT_HANDSHAKING} when handshake is finished.
*
* @return {@code true} if and only if the workaround has been applied and thus {@link #handshakeFuture} has been
* marked as success by this method
*/
private boolean setHandshakeSuccessIfStillHandshaking() {
return setHandshakeSuccess();
}
/** /**
* Notify all the handshake futures about the successfully handshake * Notify all the handshake futures about the successfully handshake
* @return {@code true} if {@link #handshakePromise} was set successfully and a {@link SslHandshakeCompletionEvent} * @return {@code true} if {@link #handshakePromise} was set successfully and a {@link SslHandshakeCompletionEvent}
* was fired. {@code false} otherwise. * was fired. {@code false} otherwise.
*/ */
private boolean setHandshakeSuccess() { private boolean setHandshakeSuccess() {
if (isStateSet(STATE_READ_DURING_HANDSHAKE) && !ctx.channel().config().isAutoRead()) {
clearState(STATE_READ_DURING_HANDSHAKE);
ctx.read();
}
// Our control flow may invoke this method multiple times for a single FINISHED event. For example // Our control flow may invoke this method multiple times for a single FINISHED event. For example
// wrapNonAppData may drain pendingUnencryptedWrites in wrap which transitions to handshake from FINISHED to // wrapNonAppData may drain pendingUnencryptedWrites in wrap which transitions to handshake from FINISHED to
// NOT_HANDSHAKING which invokes setHandshakeSuccessIfStillHandshaking, and then wrapNonAppData also directly // NOT_HANDSHAKING which invokes setHandshakeSuccess, and then wrapNonAppData also directly invokes this method.
// invokes this method. final boolean notified;
if (handshakePromise.trySuccess(ctx.channel())) { if (notified = !handshakePromise.isDone() && handshakePromise.trySuccess(ctx.channel())) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
SSLSession session = engine.getSession(); SSLSession session = engine.getSession();
logger.debug( logger.debug(
@ -1779,9 +1744,14 @@ public class SslHandler extends ByteToMessageDecoder {
session.getCipherSuite()); session.getCipherSuite());
} }
ctx.fireUserEventTriggered(SslHandshakeCompletionEvent.SUCCESS); ctx.fireUserEventTriggered(SslHandshakeCompletionEvent.SUCCESS);
return true;
} }
return false; if (isStateSet(STATE_READ_DURING_HANDSHAKE)) {
clearState(STATE_READ_DURING_HANDSHAKE);
if (!ctx.channel().config().isAutoRead()) {
ctx.read();
}
}
return notified;
} }
/** /**

View File

@ -24,11 +24,16 @@ import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup; import io.netty.channel.EventLoopGroup;
import io.netty.channel.MultithreadEventLoopGroup; import io.netty.channel.MultithreadEventLoopGroup;
import io.netty.channel.ServerChannel;
import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalHandler;
import io.netty.channel.local.LocalServerChannel;
import io.netty.channel.nio.NioHandler; import io.netty.channel.nio.NioHandler;
import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
@ -36,11 +41,11 @@ import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.handler.ssl.util.SimpleTrustManagerFactory; import io.netty.handler.ssl.util.SimpleTrustManagerFactory;
import io.netty.util.CharsetUtil; import io.netty.util.CharsetUtil;
import io.netty.util.ReferenceCountUtil; import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.ResourcesUtil;
import io.netty.util.concurrent.FutureListener; import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.Promise; import io.netty.util.concurrent.Promise;
import io.netty.util.concurrent.PromiseNotifier; import io.netty.util.concurrent.PromiseNotifier;
import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.ResourcesUtil;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.junit.runners.Parameterized; import org.junit.runners.Parameterized;
@ -50,6 +55,7 @@ import javax.net.ssl.SSLException;
import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager; import javax.net.ssl.X509TrustManager;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.security.KeyStore; import java.security.KeyStore;
import java.security.cert.CertificateException; import java.security.cert.CertificateException;
import java.security.cert.X509Certificate; import java.security.cert.X509Certificate;
@ -61,6 +67,7 @@ import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import static io.netty.buffer.ByteBufUtil.writeAscii; import static io.netty.buffer.ByteBufUtil.writeAscii;
import static java.util.concurrent.ThreadLocalRandom.current;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
@ -492,7 +499,41 @@ public class ParameterizedSslHandlerTest {
} }
@Test(timeout = 30000) @Test(timeout = 30000)
public void reentryWriteOnHandshakeComplete() throws Exception { public void reentryOnHandshakeCompleteNioChannel() throws Exception {
EventLoopGroup group = new MultithreadEventLoopGroup(NioHandler.newFactory());
try {
Class<? extends ServerChannel> serverClass = NioServerSocketChannel.class;
Class<? extends Channel> clientClass = NioSocketChannel.class;
SocketAddress bindAddress = new InetSocketAddress(0);
reentryOnHandshakeComplete(group, bindAddress, serverClass, clientClass, false, false);
reentryOnHandshakeComplete(group, bindAddress, serverClass, clientClass, false, true);
reentryOnHandshakeComplete(group, bindAddress, serverClass, clientClass, true, false);
reentryOnHandshakeComplete(group, bindAddress, serverClass, clientClass, true, true);
} finally {
group.shutdownGracefully();
}
}
@Test(timeout = 30000)
public void reentryOnHandshakeCompleteLocalChannel() throws Exception {
EventLoopGroup group = new MultithreadEventLoopGroup(LocalHandler.newFactory());
try {
Class<? extends ServerChannel> serverClass = LocalServerChannel.class;
Class<? extends Channel> clientClass = LocalChannel.class;
SocketAddress bindAddress = new LocalAddress(String.valueOf(current().nextLong()));
reentryOnHandshakeComplete(group, bindAddress, serverClass, clientClass, false, false);
reentryOnHandshakeComplete(group, bindAddress, serverClass, clientClass, false, true);
reentryOnHandshakeComplete(group, bindAddress, serverClass, clientClass, true, false);
reentryOnHandshakeComplete(group, bindAddress, serverClass, clientClass, true, true);
} finally {
group.shutdownGracefully();
}
}
private void reentryOnHandshakeComplete(EventLoopGroup group, SocketAddress bindAddress,
Class<? extends ServerChannel> serverClass,
Class<? extends Channel> clientClass, boolean serverAutoRead,
boolean clientAutoRead) throws Exception {
SelfSignedCertificate ssc = new SelfSignedCertificate(); SelfSignedCertificate ssc = new SelfSignedCertificate();
final SslContext sslServerCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) final SslContext sslServerCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey())
.sslProvider(serverProvider) .sslProvider(serverProvider)
@ -503,7 +544,6 @@ public class ParameterizedSslHandlerTest {
.sslProvider(clientProvider) .sslProvider(clientProvider)
.build(); .build();
EventLoopGroup group = new MultithreadEventLoopGroup(NioHandler.newFactory());
Channel sc = null; Channel sc = null;
Channel cc = null; Channel cc = null;
try { try {
@ -515,23 +555,25 @@ public class ParameterizedSslHandlerTest {
sc = new ServerBootstrap() sc = new ServerBootstrap()
.group(group) .group(group)
.channel(NioServerSocketChannel.class) .channel(serverClass)
.childOption(ChannelOption.AUTO_READ, serverAutoRead)
.childHandler(new ChannelInitializer<Channel>() { .childHandler(new ChannelInitializer<Channel>() {
@Override @Override
protected void initChannel(Channel ch) { protected void initChannel(Channel ch) {
ch.pipeline().addLast(sslServerCtx.newHandler(ch.alloc())); ch.pipeline().addLast(disableHandshakeTimeout(sslServerCtx.newHandler(ch.alloc())));
ch.pipeline().addLast(new ReentryWriteSslHandshakeHandler(expectedContent, serverQueue, ch.pipeline().addLast(new ReentryWriteSslHandshakeHandler(expectedContent, serverQueue,
serverLatch)); serverLatch));
} }
}).bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); }).bind(bindAddress).syncUninterruptibly().channel();
cc = new Bootstrap() cc = new Bootstrap()
.group(group) .group(group)
.channel(NioSocketChannel.class) .channel(clientClass)
.option(ChannelOption.AUTO_READ, clientAutoRead)
.handler(new ChannelInitializer<Channel>() { .handler(new ChannelInitializer<Channel>() {
@Override @Override
protected void initChannel(Channel ch) { protected void initChannel(Channel ch) {
ch.pipeline().addLast(sslClientCtx.newHandler(ch.alloc())); ch.pipeline().addLast(disableHandshakeTimeout(sslClientCtx.newHandler(ch.alloc())));
ch.pipeline().addLast(new ReentryWriteSslHandshakeHandler(expectedContent, clientQueue, ch.pipeline().addLast(new ReentryWriteSslHandshakeHandler(expectedContent, clientQueue,
clientLatch)); clientLatch));
} }
@ -548,13 +590,17 @@ public class ParameterizedSslHandlerTest {
if (sc != null) { if (sc != null) {
sc.close().syncUninterruptibly(); sc.close().syncUninterruptibly();
} }
group.shutdownGracefully();
ReferenceCountUtil.release(sslServerCtx); ReferenceCountUtil.release(sslServerCtx);
ReferenceCountUtil.release(sslClientCtx); ReferenceCountUtil.release(sslClientCtx);
} }
} }
private static SslHandler disableHandshakeTimeout(SslHandler handler) {
handler.setHandshakeTimeoutMillis(0);
return handler;
}
private static final class ReentryWriteSslHandshakeHandler extends SimpleChannelInboundHandler<ByteBuf> { private static final class ReentryWriteSslHandshakeHandler extends SimpleChannelInboundHandler<ByteBuf> {
private final String toWrite; private final String toWrite;
private final StringBuilder readQueue; private final StringBuilder readQueue;