Feed only a single SSL record to SSLEngine.unwrap()

Motivation:

Some SSLEngine implementations violate the contract and raises an
exception when SslHandler feeds an input buffer that contains multiple
SSL records to SSLEngine.unwrap(), while the expected behavior is to
decode the first record and return.

Modification:

- Modify SslHandler.decode() to keep the lengths of each record and feed
  SSLEngine.unwrap() record by record to work around the forementioned
  issue.
- Rename unwrap() to unwrapMultiple() and unwrapNonApp()
- Rename unwrap0() to unwrapSingle()

Result:

SslHandler now works OpenSSLEngine from finagle-native.  Performance
impact remains unnoticeable.  Slightly better readability. Fixes #2116.
This commit is contained in:
Trustin Lee 2014-04-20 17:20:00 +09:00
parent 104faba509
commit e74fa1b5d5

View File

@ -34,7 +34,6 @@ import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.PendingWrite; import io.netty.util.internal.PendingWrite;
import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.RecyclableArrayList;
import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory; import io.netty.util.internal.logging.InternalLoggerFactory;
@ -50,6 +49,7 @@ import java.nio.channels.ClosedChannelException;
import java.nio.channels.DatagramChannel; import java.nio.channels.DatagramChannel;
import java.nio.channels.SocketChannel; import java.nio.channels.SocketChannel;
import java.util.ArrayDeque; import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque; import java.util.Deque;
import java.util.List; import java.util.List;
import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledFuture;
@ -448,7 +448,8 @@ public class SslHandler extends ByteToMessageDecoder {
case NEED_UNWRAP: case NEED_UNWRAP:
return; return;
default: default:
throw new IllegalStateException("Unknown handshake status: " + result.getHandshakeStatus()); throw new IllegalStateException(
"Unknown handshake status: " + result.getHandshakeStatus());
} }
} }
} }
@ -505,7 +506,7 @@ public class SslHandler extends ByteToMessageDecoder {
break; break;
case NEED_UNWRAP: case NEED_UNWRAP:
if (!inUnwrap) { if (!inUnwrap) {
unwrap(ctx); unwrapNonApp(ctx);
} }
break; break;
case NEED_WRAP: case NEED_WRAP:
@ -515,7 +516,7 @@ public class SslHandler extends ByteToMessageDecoder {
// Workaround for TLS False Start problem reported at: // Workaround for TLS False Start problem reported at:
// https://github.com/netty/netty/issues/1108#issuecomment-14266970 // https://github.com/netty/netty/issues/1108#issuecomment-14266970
if (!inUnwrap) { if (!inUnwrap) {
unwrap(ctx); unwrapNonApp(ctx);
} }
break; break;
default: default:
@ -742,6 +743,11 @@ public class SslHandler extends ByteToMessageDecoder {
@Override @Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws SSLException { protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws SSLException {
// Keeps the list of the length of every SSL record in the input buffer.
int[] recordLengths = null;
int nRecords = 0;
final int startOffset = in.readerIndex(); final int startOffset = in.readerIndex();
final int endOffset = in.writerIndex(); final int endOffset = in.writerIndex();
int offset = startOffset; int offset = startOffset;
@ -751,6 +757,10 @@ public class SslHandler extends ByteToMessageDecoder {
if (endOffset - startOffset < packetLength) { if (endOffset - startOffset < packetLength) {
return; return;
} else { } else {
recordLengths = new int[4];
recordLengths[0] = packetLength;
nRecords = 1;
offset += packetLength; offset += packetLength;
packetLength = 0; packetLength = 0;
} }
@ -778,11 +788,22 @@ public class SslHandler extends ByteToMessageDecoder {
break; break;
} }
// We have a whole packet.
// Remember the length of the current packet.
if (recordLengths == null) {
recordLengths = new int[4];
}
if (nRecords == recordLengths.length) {
recordLengths = Arrays.copyOf(recordLengths, recordLengths.length << 1);
}
recordLengths[nRecords ++] = packetLength;
// Increment the offset to handle the next packet.
offset += packetLength; offset += packetLength;
} }
final int length = offset - startOffset; final int totalLength = offset - startOffset;
if (length > 0) { if (totalLength > 0) {
// The buffer contains one or more full SSL records. // The buffer contains one or more full SSL records.
// Slice out the whole packet so unwrap will only be called with complete packets. // Slice out the whole packet so unwrap will only be called with complete packets.
// Also directly reset the packetLength. This is needed as unwrap(..) may trigger // Also directly reset the packetLength. This is needed as unwrap(..) may trigger
@ -793,9 +814,9 @@ public class SslHandler extends ByteToMessageDecoder {
// 4) unwrapLater(...) calls decode(...) // 4) unwrapLater(...) calls decode(...)
// //
// See https://github.com/netty/netty/issues/1534 // See https://github.com/netty/netty/issues/1534
in.skipBytes(length); in.skipBytes(totalLength);
ByteBuffer buffer = in.nioBuffer(startOffset, length); ByteBuffer buffer = in.nioBuffer(startOffset, totalLength);
unwrap(ctx, buffer, out); unwrapMultiple(ctx, buffer, totalLength, recordLengths, nRecords, out);
} }
if (nonSslRecord) { if (nonSslRecord) {
@ -817,26 +838,54 @@ public class SslHandler extends ByteToMessageDecoder {
super.channelReadComplete(ctx); super.channelReadComplete(ctx);
} }
private void unwrap(ChannelHandlerContext ctx) throws SSLException { /**
RecyclableArrayList out = RecyclableArrayList.newInstance(); * Calls {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} with an empty buffer to handle handshakes, etc.
*/
private void unwrapNonApp(ChannelHandlerContext ctx) throws SSLException {
try { try {
unwrap(ctx, Unpooled.EMPTY_BUFFER.nioBuffer(), out); unwrapSingle(ctx, Unpooled.EMPTY_BUFFER.nioBuffer(), 0);
final int size = out.size();
for (int i = 0; i < size; i++) {
ctx.fireChannelRead(out.get(i));
}
} finally { } finally {
out.recycle(); ByteBuf decodeOut = this.decodeOut;
if (decodeOut != null && decodeOut.isReadable()) {
this.decodeOut = null;
ctx.fireChannelRead(decodeOut);
}
} }
} }
private void unwrap(ChannelHandlerContext ctx, ByteBuffer packet, List<Object> out) throws SSLException { /**
* Unwraps multiple inbound SSL records.
*/
private void unwrapMultiple(
ChannelHandlerContext ctx, ByteBuffer packet, int totalLength,
int[] recordLengths, int nRecords, List<Object> out) throws SSLException {
for (int i = 0; i < nRecords; i ++) {
packet.limit(packet.position() + recordLengths[i]);
try {
unwrapSingle(ctx, packet, totalLength);
assert !packet.hasRemaining();
} finally {
ByteBuf decodeOut = this.decodeOut;
if (decodeOut != null && decodeOut.isReadable()) {
this.decodeOut = null;
out.add(decodeOut);
}
}
}
}
/**
* Unwraps a single SSL record.
*/
private void unwrapSingle(
ChannelHandlerContext ctx, ByteBuffer packet, int initialOutAppBufCapacity) throws SSLException {
boolean wrapLater = false; boolean wrapLater = false;
int totalProduced = 0;
try { try {
for (;;) { for (;;) {
if (decodeOut == null) { if (decodeOut == null) {
decodeOut = ctx.alloc().buffer(packet.remaining()); decodeOut = ctx.alloc().buffer(initialOutAppBufCapacity);
} }
final SSLEngineResult result = unwrap(engine, packet, decodeOut); final SSLEngineResult result = unwrap(engine, packet, decodeOut);
@ -845,7 +894,6 @@ public class SslHandler extends ByteToMessageDecoder {
final int produced = result.bytesProduced(); final int produced = result.bytesProduced();
final int consumed = result.bytesConsumed(); final int consumed = result.bytesConsumed();
totalProduced += produced;
if (status == Status.CLOSED) { if (status == Status.CLOSED) {
// notify about the CLOSED state of the SSLEngine. See #137 // notify about the CLOSED state of the SSLEngine. See #137
sslCloseFuture.trySuccess(ctx.channel()); sslCloseFuture.trySuccess(ctx.channel());
@ -886,12 +934,6 @@ public class SslHandler extends ByteToMessageDecoder {
} catch (SSLException e) { } catch (SSLException e) {
setHandshakeFailure(e); setHandshakeFailure(e);
throw e; throw e;
} finally {
if (totalProduced > 0) {
ByteBuf decodeOut = this.decodeOut;
this.decodeOut = null;
out.add(decodeOut);
}
} }
} }