Feed only a single SSL record to SSLEngine.unwrap()

Motivation:

Some SSLEngine implementations violate the contract and raises an
exception when SslHandler feeds an input buffer that contains multiple
SSL records to SSLEngine.unwrap(), while the expected behavior is to
decode the first record and return.

Modification:

- Modify SslHandler.decode() to keep the lengths of each record and feed
  SSLEngine.unwrap() record by record to work around the forementioned
  issue.
- Rename unwrap() to unwrapMultiple() and unwrapNonApp()
- Rename unwrap0() to unwrapSingle()

Result:

SslHandler now works OpenSSLEngine from finagle-native.  Performance
impact remains unnoticeable.  Slightly better readability. Fixes #2116.
This commit is contained in:
Trustin Lee 2014-04-20 17:20:00 +09:00
parent d854d3a617
commit af9ab8b370

View File

@ -36,7 +36,6 @@ import io.netty.util.concurrent.ImmediateExecutor;
import io.netty.util.internal.EmptyArrays; import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.PendingWrite; import io.netty.util.internal.PendingWrite;
import io.netty.util.internal.PlatformDependent; import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.RecyclableArrayList;
import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory; import io.netty.util.internal.logging.InternalLoggerFactory;
@ -53,6 +52,7 @@ import java.nio.channels.DatagramChannel;
import java.nio.channels.SocketChannel; import java.nio.channels.SocketChannel;
import java.util.ArrayDeque; import java.util.ArrayDeque;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque; import java.util.Deque;
import java.util.List; import java.util.List;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
@ -480,7 +480,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
case NEED_UNWRAP: case NEED_UNWRAP:
return; return;
default: default:
throw new IllegalStateException("Unknown handshake status: " + result.getHandshakeStatus()); throw new IllegalStateException(
"Unknown handshake status: " + result.getHandshakeStatus());
} }
} }
} }
@ -537,7 +538,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
break; break;
case NEED_UNWRAP: case NEED_UNWRAP:
if (!inUnwrap) { if (!inUnwrap) {
unwrap(ctx); unwrapNonApp(ctx);
} }
break; break;
case NEED_WRAP: case NEED_WRAP:
@ -547,7 +548,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
// Workaround for TLS False Start problem reported at: // Workaround for TLS False Start problem reported at:
// https://github.com/netty/netty/issues/1108#issuecomment-14266970 // https://github.com/netty/netty/issues/1108#issuecomment-14266970
if (!inUnwrap) { if (!inUnwrap) {
unwrap(ctx); unwrapNonApp(ctx);
} }
break; break;
default: default:
@ -774,6 +775,11 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
@Override @Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws SSLException { protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws SSLException {
// Keeps the list of the length of every SSL record in the input buffer.
int[] recordLengths = null;
int nRecords = 0;
final int startOffset = in.readerIndex(); final int startOffset = in.readerIndex();
final int endOffset = in.writerIndex(); final int endOffset = in.writerIndex();
int offset = startOffset; int offset = startOffset;
@ -783,6 +789,10 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
if (endOffset - startOffset < packetLength) { if (endOffset - startOffset < packetLength) {
return; return;
} else { } else {
recordLengths = new int[4];
recordLengths[0] = packetLength;
nRecords = 1;
offset += packetLength; offset += packetLength;
packetLength = 0; packetLength = 0;
} }
@ -810,11 +820,22 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
break; break;
} }
// We have a whole packet.
// Remember the length of the current packet.
if (recordLengths == null) {
recordLengths = new int[4];
}
if (nRecords == recordLengths.length) {
recordLengths = Arrays.copyOf(recordLengths, recordLengths.length << 1);
}
recordLengths[nRecords ++] = packetLength;
// Increment the offset to handle the next packet.
offset += packetLength; offset += packetLength;
} }
final int length = offset - startOffset; final int totalLength = offset - startOffset;
if (length > 0) { if (totalLength > 0) {
// The buffer contains one or more full SSL records. // The buffer contains one or more full SSL records.
// Slice out the whole packet so unwrap will only be called with complete packets. // Slice out the whole packet so unwrap will only be called with complete packets.
// Also directly reset the packetLength. This is needed as unwrap(..) may trigger // Also directly reset the packetLength. This is needed as unwrap(..) may trigger
@ -825,9 +846,9 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
// 4) unwrapLater(...) calls decode(...) // 4) unwrapLater(...) calls decode(...)
// //
// See https://github.com/netty/netty/issues/1534 // See https://github.com/netty/netty/issues/1534
in.skipBytes(length); in.skipBytes(totalLength);
ByteBuffer buffer = in.nioBuffer(startOffset, length); ByteBuffer buffer = in.nioBuffer(startOffset, totalLength);
unwrap(ctx, buffer, out); unwrapMultiple(ctx, buffer, totalLength, recordLengths, nRecords, out);
} }
if (nonSslRecord) { if (nonSslRecord) {
@ -849,26 +870,54 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
super.channelReadComplete(ctx); super.channelReadComplete(ctx);
} }
private void unwrap(ChannelHandlerContext ctx) throws SSLException { /**
RecyclableArrayList out = RecyclableArrayList.newInstance(); * Calls {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} with an empty buffer to handle handshakes, etc.
*/
private void unwrapNonApp(ChannelHandlerContext ctx) throws SSLException {
try { try {
unwrap(ctx, Unpooled.EMPTY_BUFFER.nioBuffer(), out); unwrapSingle(ctx, Unpooled.EMPTY_BUFFER.nioBuffer(), 0);
final int size = out.size();
for (int i = 0; i < size; i++) {
ctx.fireChannelRead(out.get(i));
}
} finally { } finally {
out.recycle(); ByteBuf decodeOut = this.decodeOut;
if (decodeOut != null && decodeOut.isReadable()) {
this.decodeOut = null;
ctx.fireChannelRead(decodeOut);
}
} }
} }
private void unwrap(ChannelHandlerContext ctx, ByteBuffer packet, List<Object> out) throws SSLException { /**
* Unwraps multiple inbound SSL records.
*/
private void unwrapMultiple(
ChannelHandlerContext ctx, ByteBuffer packet, int totalLength,
int[] recordLengths, int nRecords, List<Object> out) throws SSLException {
for (int i = 0; i < nRecords; i ++) {
packet.limit(packet.position() + recordLengths[i]);
try {
unwrapSingle(ctx, packet, totalLength);
assert !packet.hasRemaining();
} finally {
ByteBuf decodeOut = this.decodeOut;
if (decodeOut != null && decodeOut.isReadable()) {
this.decodeOut = null;
out.add(decodeOut);
}
}
}
}
/**
* Unwraps a single SSL record.
*/
private void unwrapSingle(
ChannelHandlerContext ctx, ByteBuffer packet, int initialOutAppBufCapacity) throws SSLException {
boolean wrapLater = false; boolean wrapLater = false;
int totalProduced = 0;
try { try {
for (;;) { for (;;) {
if (decodeOut == null) { if (decodeOut == null) {
decodeOut = ctx.alloc().buffer(packet.remaining()); decodeOut = ctx.alloc().buffer(initialOutAppBufCapacity);
} }
final SSLEngineResult result = unwrap(engine, packet, decodeOut); final SSLEngineResult result = unwrap(engine, packet, decodeOut);
@ -877,7 +926,6 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
final int produced = result.bytesProduced(); final int produced = result.bytesProduced();
final int consumed = result.bytesConsumed(); final int consumed = result.bytesConsumed();
totalProduced += produced;
if (status == Status.CLOSED) { if (status == Status.CLOSED) {
// notify about the CLOSED state of the SSLEngine. See #137 // notify about the CLOSED state of the SSLEngine. See #137
sslCloseFuture.trySuccess(ctx.channel()); sslCloseFuture.trySuccess(ctx.channel());
@ -918,12 +966,6 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
} catch (SSLException e) { } catch (SSLException e) {
setHandshakeFailure(e); setHandshakeFailure(e);
throw e; throw e;
} finally {
if (totalProduced > 0) {
ByteBuf decodeOut = this.decodeOut;
this.decodeOut = null;
out.add(decodeOut);
}
} }
} }