[#1534] Fix handling of partial frames in SslHandler
* Let SslHandler not extend ByteToMessageDecoder
This commit is contained in:
parent
75229e145a
commit
1d196c5b59
@ -19,6 +19,7 @@ import io.netty.buffer.ByteBuf;
|
|||||||
import io.netty.buffer.ByteBufUtil;
|
import io.netty.buffer.ByteBufUtil;
|
||||||
import io.netty.buffer.Unpooled;
|
import io.netty.buffer.Unpooled;
|
||||||
import io.netty.channel.Channel;
|
import io.netty.channel.Channel;
|
||||||
|
import io.netty.channel.ChannelException;
|
||||||
import io.netty.channel.ChannelFuture;
|
import io.netty.channel.ChannelFuture;
|
||||||
import io.netty.channel.ChannelFutureListener;
|
import io.netty.channel.ChannelFutureListener;
|
||||||
import io.netty.channel.ChannelHandlerContext;
|
import io.netty.channel.ChannelHandlerContext;
|
||||||
@ -183,6 +184,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
|||||||
private final Queue<PendingWrite> pendingUnencryptedWrites = new ArrayDeque<PendingWrite>();
|
private final Queue<PendingWrite> pendingUnencryptedWrites = new ArrayDeque<PendingWrite>();
|
||||||
|
|
||||||
private int packetLength;
|
private int packetLength;
|
||||||
|
private ByteBuf decodeOut;
|
||||||
|
|
||||||
private volatile long handshakeTimeoutMillis = 10000;
|
private volatile long handshakeTimeoutMillis = 10000;
|
||||||
private volatile long closeNotifyTimeoutMillis = 3000;
|
private volatile long closeNotifyTimeoutMillis = 3000;
|
||||||
@ -343,9 +345,17 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
|
public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
|
||||||
if (decodeOut != null) {
|
if (decodeOut != null) {
|
||||||
decodeOut.release();
|
decodeOut.release();
|
||||||
|
decodeOut = null;
|
||||||
|
}
|
||||||
|
for (;;) {
|
||||||
|
PendingWrite write = pendingUnencryptedWrites.poll();
|
||||||
|
if (write == null) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
write.fail(new ChannelException("Pending write on removal of SslHandler"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -495,7 +505,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (unwrapLater) {
|
if (unwrapLater) {
|
||||||
decode0(ctx);
|
unwrapLater(ctx);
|
||||||
}
|
}
|
||||||
} catch (SSLException e) {
|
} catch (SSLException e) {
|
||||||
setHandshakeFailure(e);
|
setHandshakeFailure(e);
|
||||||
@ -517,6 +527,16 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void unwrapLater(ChannelHandlerContext ctx) throws SSLException {
|
||||||
|
MessageList<Object> messageList = MessageList.newInstance();
|
||||||
|
decode(ctx, internalBuffer(), messageList);
|
||||||
|
if (messageList.isEmpty()) {
|
||||||
|
messageList.recycle();
|
||||||
|
} else {
|
||||||
|
ctx.fireMessageReceived(messageList);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private void flushNonAppData0(ChannelHandlerContext ctx) throws SSLException {
|
private void flushNonAppData0(ChannelHandlerContext ctx) throws SSLException {
|
||||||
boolean unwrapLater = false;
|
boolean unwrapLater = false;
|
||||||
ByteBuf out = null;
|
ByteBuf out = null;
|
||||||
@ -567,7 +587,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (unwrapLater) {
|
if (unwrapLater) {
|
||||||
decode0(ctx);
|
unwrapLater(ctx);
|
||||||
}
|
}
|
||||||
} catch (SSLException e) {
|
} catch (SSLException e) {
|
||||||
setHandshakeFailure(e);
|
setHandshakeFailure(e);
|
||||||
@ -782,26 +802,19 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void decode(final ChannelHandlerContext ctx, ByteBuf in, MessageList<Object> out) throws Exception {
|
protected void decode(ChannelHandlerContext ctx, ByteBuf in, MessageList<Object> out) throws SSLException {
|
||||||
decode0(ctx);
|
|
||||||
}
|
|
||||||
|
|
||||||
private ByteBuf decodeOut;
|
|
||||||
|
|
||||||
private void decode0(final ChannelHandlerContext ctx) throws SSLException {
|
|
||||||
final ByteBuf in = internalBuffer();
|
|
||||||
|
|
||||||
// Check if the packet length was parsed yet, if so we can skip the parsing
|
|
||||||
final int readableBytes = in.readableBytes();
|
|
||||||
int packetLength = this.packetLength;
|
int packetLength = this.packetLength;
|
||||||
if (packetLength == 0) {
|
if (packetLength == 0) {
|
||||||
|
// the previous packet was consumed so try to read the length of the next packet
|
||||||
|
final int readableBytes = in.readableBytes();
|
||||||
if (readableBytes < 5) {
|
if (readableBytes < 5) {
|
||||||
|
// not enough bytes readable to read the packet length
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
packetLength = getEncryptedPacketLength(in);
|
packetLength = getEncryptedPacketLength(in);
|
||||||
if (packetLength == -1) {
|
if (packetLength == -1) {
|
||||||
// Bad data - discard the buffer and raise an exception.
|
// Not an SSL/TLS packet
|
||||||
NotSslRecordException e = new NotSslRecordException(
|
NotSslRecordException e = new NotSslRecordException(
|
||||||
"not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
|
"not an SSL/TLS record: " + ByteBufUtil.hexDump(in));
|
||||||
in.skipBytes(readableBytes);
|
in.skipBytes(readableBytes);
|
||||||
@ -814,10 +827,22 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
|||||||
this.packetLength = packetLength;
|
this.packetLength = packetLength;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (readableBytes < packetLength) {
|
if (in.readableBytes() < packetLength) {
|
||||||
|
// wait until the whole packet can be read
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
// slice out the whole packet so unwrap will only be called with complete packets
|
||||||
|
int readerIndex = in.readerIndex();
|
||||||
|
in.skipBytes(packetLength);
|
||||||
|
unwrap(ctx, in.nioBuffer(readerIndex, packetLength), out);
|
||||||
|
} finally {
|
||||||
|
this.packetLength = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void unwrap(ChannelHandlerContext ctx, ByteBuffer packet, MessageList<Object> out) throws SSLException {
|
||||||
boolean wrapLater = false;
|
boolean wrapLater = false;
|
||||||
int bytesProduced = 0;
|
int bytesProduced = 0;
|
||||||
try {
|
try {
|
||||||
@ -826,9 +851,8 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
|||||||
if (decodeOut == null) {
|
if (decodeOut == null) {
|
||||||
decodeOut = ctx.alloc().buffer();
|
decodeOut = ctx.alloc().buffer();
|
||||||
}
|
}
|
||||||
SSLEngineResult result = unwrap(engine, in, decodeOut);
|
SSLEngineResult result = unwrap(engine, packet, decodeOut);
|
||||||
bytesProduced += result.bytesProduced();
|
bytesProduced += result.bytesProduced();
|
||||||
|
|
||||||
switch (result.getStatus()) {
|
switch (result.getStatus()) {
|
||||||
case CLOSED:
|
case CLOSED:
|
||||||
// notify about the CLOSED state of the SSLEngine. See #137
|
// notify about the CLOSED state of the SSLEngine. See #137
|
||||||
@ -870,23 +894,18 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
|
|||||||
setHandshakeFailure(e);
|
setHandshakeFailure(e);
|
||||||
throw e;
|
throw e;
|
||||||
} finally {
|
} finally {
|
||||||
// reset the packet length so it will be parsed again on the next call
|
|
||||||
this.packetLength = 0;
|
|
||||||
|
|
||||||
if (bytesProduced > 0) {
|
if (bytesProduced > 0) {
|
||||||
ByteBuf decodeOut = this.decodeOut;
|
ByteBuf decodeOut = this.decodeOut;
|
||||||
this.decodeOut = null;
|
this.decodeOut = null;
|
||||||
ctx.fireMessageReceived(decodeOut);
|
out.add(decodeOut);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static SSLEngineResult unwrap(SSLEngine engine, ByteBuf in, ByteBuf out) throws SSLException {
|
private static SSLEngineResult unwrap(SSLEngine engine, ByteBuffer in, ByteBuf out) throws SSLException {
|
||||||
ByteBuffer in0 = in.nioBuffer();
|
|
||||||
for (;;) {
|
for (;;) {
|
||||||
ByteBuffer out0 = out.nioBuffer(out.writerIndex(), out.writableBytes());
|
ByteBuffer out0 = out.nioBuffer(out.writerIndex(), out.writableBytes());
|
||||||
SSLEngineResult result = engine.unwrap(in0, out0);
|
SSLEngineResult result = engine.unwrap(in, out0);
|
||||||
in.skipBytes(result.bytesConsumed());
|
|
||||||
out.writerIndex(out.writerIndex() + result.bytesProduced());
|
out.writerIndex(out.writerIndex() + result.bytesProduced());
|
||||||
switch (result.getStatus()) {
|
switch (result.getStatus()) {
|
||||||
case BUFFER_OVERFLOW:
|
case BUFFER_OVERFLOW:
|
||||||
|
Loading…
Reference in New Issue
Block a user