Fix a bug where SslHandler doesn't sometimes handle renegotiation correctly

- Fixes #1964
This commit is contained in:
Trustin Lee 2013-11-04 16:52:07 +09:00
parent de2c6acecf
commit 2eb5d4f0dd

View File

@ -42,9 +42,9 @@ import io.netty.util.internal.logging.InternalLoggerFactory;
import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult; import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status; import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.SSLException; import javax.net.ssl.SSLException;
import java.io.IOException; import java.io.IOException;
import java.net.SocketAddress; import java.net.SocketAddress;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
@ -391,7 +391,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
@Override @Override
public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
pendingUnencryptedWrites.add(PendingWrite.newInstance((ByteBuf) msg, promise)); pendingUnencryptedWrites.add(PendingWrite.newInstance(msg, promise));
} }
@Override @Override
@ -713,7 +713,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
int majorVersion = buffer.getUnsignedByte(first + 1); int majorVersion = buffer.getUnsignedByte(first + 1);
if (majorVersion == 3) { if (majorVersion == 3) {
// SSLv3 or TLS // SSLv3 or TLS
packetLength = (buffer.getUnsignedShort(first + 3)) + 5; packetLength = buffer.getUnsignedShort(first + 3) + 5;
if (packetLength <= 5) { if (packetLength <= 5) {
// Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data) // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
tls = false; tls = false;
@ -814,25 +814,27 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
private void unwrap(ChannelHandlerContext ctx, ByteBuffer packet, List<Object> out) throws SSLException { private void unwrap(ChannelHandlerContext ctx, ByteBuffer packet, List<Object> out) throws SSLException {
boolean wrapLater = false; boolean wrapLater = false;
int bytesProduced = 0; int totalProduced = 0;
try { try {
loop:
for (;;) { for (;;) {
if (decodeOut == null) { if (decodeOut == null) {
decodeOut = ctx.alloc().buffer(); decodeOut = ctx.alloc().buffer();
} }
SSLEngineResult result = unwrap(engine, packet, decodeOut);
bytesProduced += result.bytesProduced(); final SSLEngineResult result = unwrap(engine, packet, decodeOut);
switch (result.getStatus()) { final Status status = result.getStatus();
case CLOSED: final HandshakeStatus handshakeStatus = result.getHandshakeStatus();
final int produced = result.bytesProduced();
final int consumed = result.bytesConsumed();
totalProduced += produced;
if (status == Status.CLOSED) {
// notify about the CLOSED state of the SSLEngine. See #137 // notify about the CLOSED state of the SSLEngine. See #137
sslCloseFuture.trySuccess(ctx.channel()); sslCloseFuture.trySuccess(ctx.channel());
break; break;
case BUFFER_UNDERFLOW:
break loop;
} }
switch (result.getHandshakeStatus()) { switch (handshakeStatus) {
case NEED_UNWRAP: case NEED_UNWRAP:
break; break;
case NEED_WRAP: case NEED_WRAP:
@ -848,11 +850,10 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
case NOT_HANDSHAKING: case NOT_HANDSHAKING:
break; break;
default: default:
throw new IllegalStateException( throw new IllegalStateException("Unknown handshake status: " + handshakeStatus);
"Unknown handshake status: " + result.getHandshakeStatus());
} }
if (result.bytesConsumed() == 0 && result.bytesProduced() == 0) { if (status == Status.BUFFER_UNDERFLOW || consumed == 0 && produced == 0) {
break; break;
} }
} }
@ -864,7 +865,7 @@ public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundH
setHandshakeFailure(e); setHandshakeFailure(e);
throw e; throw e;
} finally { } finally {
if (bytesProduced > 0) { if (totalProduced > 0) {
ByteBuf decodeOut = this.decodeOut; ByteBuf decodeOut = this.decodeOut;
this.decodeOut = null; this.decodeOut = null;
out.add(decodeOut); out.add(decodeOut);