[#1534] Fix handling of partial frames in SslHandler

* Let SslHandler not extend ByteToMessageDecoder
This commit is contained in:
Norman Maurer 2013-07-08 14:22:54 +02:00
parent 75229e145a
commit 1d196c5b59

View File

@ -19,6 +19,7 @@ import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelException;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
@ -183,6 +184,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
private final Queue<PendingWrite> pendingUnencryptedWrites = new ArrayDeque<PendingWrite>();
private int packetLength;
private ByteBuf decodeOut;
private volatile long handshakeTimeoutMillis = 10000;
private volatile long closeNotifyTimeoutMillis = 3000;
@ -343,9 +345,17 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}
@Override
protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
if (decodeOut != null) {
decodeOut.release();
decodeOut = null;
}
for (;;) {
PendingWrite write = pendingUnencryptedWrites.poll();
if (write == null) {
break;
}
write.fail(new ChannelException("Pending write on removal of SslHandler"));
}
}
@ -495,7 +505,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}
if (unwrapLater) {
decode0(ctx);
unwrapLater(ctx);
}
} catch (SSLException e) {
setHandshakeFailure(e);
@ -517,6 +527,16 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}
}
private void unwrapLater(ChannelHandlerContext ctx) throws SSLException {
MessageList<Object> messageList = MessageList.newInstance();
decode(ctx, internalBuffer(), messageList);
if (messageList.isEmpty()) {
messageList.recycle();
} else {
ctx.fireMessageReceived(messageList);
}
}
private void flushNonAppData0(ChannelHandlerContext ctx) throws SSLException {
boolean unwrapLater = false;
ByteBuf out = null;
@ -567,7 +587,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}
if (unwrapLater) {
decode0(ctx);
unwrapLater(ctx);
}
} catch (SSLException e) {
setHandshakeFailure(e);
@ -782,26 +802,19 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
}
@Override
public void decode(final ChannelHandlerContext ctx, ByteBuf in, MessageList<Object> out) throws Exception {
decode0(ctx);
}
private ByteBuf decodeOut;
private void decode0(final ChannelHandlerContext ctx) throws SSLException {
final ByteBuf in = internalBuffer();
// Check if the packet length was parsed yet, if so we can skip the parsing
final int readableBytes = in.readableBytes();
protected void decode(ChannelHandlerContext ctx, ByteBuf in, MessageList<Object> out) throws SSLException {
int packetLength = this.packetLength;
if (packetLength == 0) {
// the previous packet was consumed so try to read the length of the next packet
final int readableBytes = in.readableBytes();
if (readableBytes < 5) {
// not enough bytes readable to read the packet length
return;
}
packetLength = getEncryptedPacketLength(in);
if (packetLength == -1) {
// Bad data - discard the buffer and raise an exception.
// Not an SSL/TLS packet
NotSslRecordException e = new NotSslRecordException(
"not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
in.skipBytes(readableBytes);
@ -814,10 +827,22 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
this.packetLength = packetLength;
}
if (readableBytes < packetLength) {
if (in.readableBytes() < packetLength) {
// wait until the whole packet can be read
return;
}
try {
// slice out the whole packet so unwrap will only be called with complete packets
int readerIndex = in.readerIndex();
in.skipBytes(packetLength);
unwrap(ctx, in.nioBuffer(readerIndex, packetLength), out);
} finally {
this.packetLength = 0;
}
}
private void unwrap(ChannelHandlerContext ctx, ByteBuffer packet, MessageList<Object> out) throws SSLException {
boolean wrapLater = false;
int bytesProduced = 0;
try {
@ -826,9 +851,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
if (decodeOut == null) {
decodeOut = ctx.alloc().buffer();
}
SSLEngineResult result = unwrap(engine, in, decodeOut);
SSLEngineResult result = unwrap(engine, packet, decodeOut);
bytesProduced += result.bytesProduced();
switch (result.getStatus()) {
case CLOSED:
// notify about the CLOSED state of the SSLEngine. See #137
@ -870,23 +894,18 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
setHandshakeFailure(e);
throw e;
} finally {
// reset the packet length so it will be parsed again on the next call
this.packetLength = 0;
if (bytesProduced > 0) {
ByteBuf decodeOut = this.decodeOut;
this.decodeOut = null;
ctx.fireMessageReceived(decodeOut);
out.add(decodeOut);
}
}
}
private static SSLEngineResult unwrap(SSLEngine engine, ByteBuf in, ByteBuf out) throws SSLException {
ByteBuffer in0 = in.nioBuffer();
private static SSLEngineResult unwrap(SSLEngine engine, ByteBuffer in, ByteBuf out) throws SSLException {
for (;;) {
ByteBuffer out0 = out.nioBuffer(out.writerIndex(), out.writableBytes());
SSLEngineResult result = engine.unwrap(in0, out0);
in.skipBytes(result.bytesConsumed());
SSLEngineResult result = engine.unwrap(in, out0);
out.writerIndex(out.writerIndex() + result.bytesProduced());
switch (result.getStatus()) {
case BUFFER_OVERFLOW: