[#1534] Fix handling of partial frames in SslHandler

* Let SslHandler not extend ByteToMessageDecoder
This commit is contained in:
Norman Maurer 2013-07-08 14:22:54 +02:00
parent 75229e145a
commit 1d196c5b59

View File

@ -19,6 +19,7 @@ import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil; import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled; import io.netty.buffer.Unpooled;
import io.netty.channel.Channel; import io.netty.channel.Channel;
import io.netty.channel.ChannelException;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
@ -183,6 +184,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
private final Queue<PendingWrite> pendingUnencryptedWrites = new ArrayDeque<PendingWrite>(); private final Queue<PendingWrite> pendingUnencryptedWrites = new ArrayDeque<PendingWrite>();
private int packetLength; private int packetLength;
private ByteBuf decodeOut;
private volatile long handshakeTimeoutMillis = 10000; private volatile long handshakeTimeoutMillis = 10000;
private volatile long closeNotifyTimeoutMillis = 3000; private volatile long closeNotifyTimeoutMillis = 3000;
@ -343,9 +345,17 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
} }
@Override @Override
protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
if (decodeOut != null) { if (decodeOut != null) {
decodeOut.release(); decodeOut.release();
decodeOut = null;
}
for (;;) {
PendingWrite write = pendingUnencryptedWrites.poll();
if (write == null) {
break;
}
write.fail(new ChannelException("Pending write on removal of SslHandler"));
} }
} }
@ -495,7 +505,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
} }
if (unwrapLater) { if (unwrapLater) {
decode0(ctx); unwrapLater(ctx);
} }
} catch (SSLException e) { } catch (SSLException e) {
setHandshakeFailure(e); setHandshakeFailure(e);
@ -517,6 +527,16 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
} }
} }
private void unwrapLater(ChannelHandlerContext ctx) throws SSLException {
MessageList<Object> messageList = MessageList.newInstance();
decode(ctx, internalBuffer(), messageList);
if (messageList.isEmpty()) {
messageList.recycle();
} else {
ctx.fireMessageReceived(messageList);
}
}
private void flushNonAppData0(ChannelHandlerContext ctx) throws SSLException { private void flushNonAppData0(ChannelHandlerContext ctx) throws SSLException {
boolean unwrapLater = false; boolean unwrapLater = false;
ByteBuf out = null; ByteBuf out = null;
@ -567,7 +587,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
} }
if (unwrapLater) { if (unwrapLater) {
decode0(ctx); unwrapLater(ctx);
} }
} catch (SSLException e) { } catch (SSLException e) {
setHandshakeFailure(e); setHandshakeFailure(e);
@ -782,26 +802,19 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
} }
@Override @Override
public void decode(final ChannelHandlerContext ctx, ByteBuf in, MessageList<Object> out) throws Exception { protected void decode(ChannelHandlerContext ctx, ByteBuf in, MessageList<Object> out) throws SSLException {
decode0(ctx);
}
private ByteBuf decodeOut;
private void decode0(final ChannelHandlerContext ctx) throws SSLException {
final ByteBuf in = internalBuffer();
// Check if the packet length was parsed yet, if so we can skip the parsing
final int readableBytes = in.readableBytes();
int packetLength = this.packetLength; int packetLength = this.packetLength;
if (packetLength == 0) { if (packetLength == 0) {
// the previous packet was consumed so try to read the length of the next packet
final int readableBytes = in.readableBytes();
if (readableBytes < 5) { if (readableBytes < 5) {
// not enough bytes readable to read the packet length
return; return;
} }
packetLength = getEncryptedPacketLength(in); packetLength = getEncryptedPacketLength(in);
if (packetLength == -1) { if (packetLength == -1) {
// Bad data - discard the buffer and raise an exception. // Not an SSL/TLS packet
NotSslRecordException e = new NotSslRecordException( NotSslRecordException e = new NotSslRecordException(
"not an SSL/TLS record: " + ByteBufUtil.hexDump(in)); "not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
in.skipBytes(readableBytes); in.skipBytes(readableBytes);
@ -814,10 +827,22 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
this.packetLength = packetLength; this.packetLength = packetLength;
} }
if (readableBytes < packetLength) { if (in.readableBytes() < packetLength) {
// wait until the whole packet can be read
return; return;
} }
try {
// slice out the whole packet so unwrap will only be called with complete packets
int readerIndex = in.readerIndex();
in.skipBytes(packetLength);
unwrap(ctx, in.nioBuffer(readerIndex, packetLength), out);
} finally {
this.packetLength = 0;
}
}
private void unwrap(ChannelHandlerContext ctx, ByteBuffer packet, MessageList<Object> out) throws SSLException {
boolean wrapLater = false; boolean wrapLater = false;
int bytesProduced = 0; int bytesProduced = 0;
try { try {
@ -826,9 +851,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
if (decodeOut == null) { if (decodeOut == null) {
decodeOut = ctx.alloc().buffer(); decodeOut = ctx.alloc().buffer();
} }
SSLEngineResult result = unwrap(engine, in, decodeOut); SSLEngineResult result = unwrap(engine, packet, decodeOut);
bytesProduced += result.bytesProduced(); bytesProduced += result.bytesProduced();
switch (result.getStatus()) { switch (result.getStatus()) {
case CLOSED: case CLOSED:
// notify about the CLOSED state of the SSLEngine. See #137 // notify about the CLOSED state of the SSLEngine. See #137
@ -870,23 +894,18 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
setHandshakeFailure(e); setHandshakeFailure(e);
throw e; throw e;
} finally { } finally {
// reset the packet length so it will be parsed again on the next call
this.packetLength = 0;
if (bytesProduced > 0) { if (bytesProduced > 0) {
ByteBuf decodeOut = this.decodeOut; ByteBuf decodeOut = this.decodeOut;
this.decodeOut = null; this.decodeOut = null;
ctx.fireMessageReceived(decodeOut); out.add(decodeOut);
} }
} }
} }
private static SSLEngineResult unwrap(SSLEngine engine, ByteBuf in, ByteBuf out) throws SSLException { private static SSLEngineResult unwrap(SSLEngine engine, ByteBuffer in, ByteBuf out) throws SSLException {
ByteBuffer in0 = in.nioBuffer();
for (;;) { for (;;) {
ByteBuffer out0 = out.nioBuffer(out.writerIndex(), out.writableBytes()); ByteBuffer out0 = out.nioBuffer(out.writerIndex(), out.writableBytes());
SSLEngineResult result = engine.unwrap(in0, out0); SSLEngineResult result = engine.unwrap(in, out0);
in.skipBytes(result.bytesConsumed());
out.writerIndex(out.writerIndex() + result.bytesProduced()); out.writerIndex(out.writerIndex() + result.bytesProduced());
switch (result.getStatus()) { switch (result.getStatus()) {
case BUFFER_OVERFLOW: case BUFFER_OVERFLOW: