Port SslHandler to the new API. Everything except of starttls works

This commit is contained in:
norman 2012-06-06 07:52:28 +02:00
parent c2e3d305b4
commit 2aea5291bd

View File

@ -21,21 +21,17 @@ import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerContext;
import io.netty.channel.ChannelOutboundHandlerContext; import io.netty.channel.ChannelOutboundHandlerContext;
import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPipeline;
import io.netty.channel.DefaultChannelFuture; import io.netty.channel.DefaultChannelFuture;
import io.netty.handler.codec.StreamToStreamCodec; import io.netty.handler.codec.StreamToStreamCodec;
import io.netty.logging.InternalLogger; import io.netty.logging.InternalLogger;
import io.netty.logging.InternalLoggerFactory; import io.netty.logging.InternalLoggerFactory;
import io.netty.util.internal.NonReentrantLock;
import io.netty.util.internal.QueueFactory;
import java.io.IOException; import java.io.IOException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.ClosedChannelException; import java.nio.channels.ClosedChannelException;
import java.util.LinkedList;
import java.util.Queue;
import java.util.concurrent.Executor; import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.regex.Pattern; import java.util.regex.Pattern;
@ -169,7 +165,10 @@ public class SslHandler extends StreamToStreamCodec {
private final SSLEngine engine; private final SSLEngine engine;
private final SslBufferPool bufferPool; private final SslBufferPool bufferPool;
private final Executor delegatedTaskExecutor; private final Executor delegatedTaskExecutor;
// TODO: Fix STARTTLS
private final boolean startTls; private final boolean startTls;
private boolean firstMessageSend;
private volatile boolean enableRenegotiation = true; private volatile boolean enableRenegotiation = true;
@ -178,13 +177,9 @@ public class SslHandler extends StreamToStreamCodec {
private volatile boolean handshaken; private volatile boolean handshaken;
private volatile ChannelFuture handshakeFuture; private volatile ChannelFuture handshakeFuture;
private final AtomicBoolean sentFirstMessage = new AtomicBoolean();
private final AtomicBoolean sentCloseNotify = new AtomicBoolean(); private final AtomicBoolean sentCloseNotify = new AtomicBoolean();
int ignoreClosedChannelException; int ignoreClosedChannelException;
final Object ignoreClosedChannelExceptionLock = new Object(); final Object ignoreClosedChannelExceptionLock = new Object();
private final Queue<PendingWrite> pendingUnencryptedWrites = new LinkedList<PendingWrite>();
private final Queue<MessageEvent> pendingEncryptedWrites = QueueFactory.createQueue(MessageEvent.class);
private final NonReentrantLock pendingEncryptedWritesLock = new NonReentrantLock();
private volatile boolean issueHandshake; private volatile boolean issueHandshake;
private final SSLEngineInboundCloseFuture sslEngineCloseFuture = new SSLEngineInboundCloseFuture(); private final SSLEngineInboundCloseFuture sslEngineCloseFuture = new SSLEngineInboundCloseFuture();
@ -327,36 +322,55 @@ public class SslHandler extends StreamToStreamCodec {
throw new IllegalStateException("renegotiation disabled"); throw new IllegalStateException("renegotiation disabled");
} }
ChannelHandlerContext ctx = this.ctx;
Channel channel = ctx.channel(); Channel channel = ctx.channel();
ChannelFuture handshakeFuture;
Exception exception = null; Exception exception = null;
synchronized (handshakeLock) { synchronized (handshakeLock) {
if (handshaking) { if (handshaking) {
return this.handshakeFuture; return handshakeFuture;
} else { } else {
handshaking = true; handshaking = true;
try { if (ctx.executor().inEventLoop()) {
engine.beginHandshake(); try {
runDelegatedTasks(); engine.beginHandshake();
handshakeFuture = this.handshakeFuture = future(channel); runDelegatedTasks();
} catch (Exception e) { wrapNonAppData(ctx, channel);
handshakeFuture = this.handshakeFuture = failedFuture(channel, e);
exception = e; } catch (Exception e) {
exception = e;
}
} else {
ctx.executor().execute(new Runnable() {
@Override
public void run() {
Throwable exception = null;
synchronized (handshakeLock) {
try {
engine.beginHandshake();
runDelegatedTasks();
wrapNonAppData(ctx, ctx.channel());
} catch (Exception e) {
exception = e;
}
}
if (exception != null) { // Failed to initiate handshake.
handshakeFuture.setFailure(exception);
ctx.fireExceptionCaught(exception);
}
}
});
} }
} }
} }
if (exception == null) { // Began handshake successfully. if (exception != null) { // Failed to initiate handshake.
try { handshakeFuture.setFailure(exception);
wrapNonAppData(ctx, channel); ctx.fireExceptionCaught(exception);
} catch (SSLException e) {
fireExceptionCaught(ctx, e);
handshakeFuture.setFailure(e);
}
} else { // Failed to initiate handshake.
fireExceptionCaught(ctx, exception);
} }
return handshakeFuture; return handshakeFuture;
@ -373,8 +387,8 @@ public class SslHandler extends StreamToStreamCodec {
engine.closeOutbound(); engine.closeOutbound();
return wrapNonAppData(ctx, channel); return wrapNonAppData(ctx, channel);
} catch (SSLException e) { } catch (SSLException e) {
fireExceptionCaught(ctx, e); ctx.fireExceptionCaught(e);
return failedFuture(channel, e); return channel.newFailedFuture(e);
} }
} }
@ -421,42 +435,105 @@ public class SslHandler extends StreamToStreamCodec {
} }
@Override @Override
public void disconnect(ChannelOutboundHandlerContext<Byte> ctx, public void disconnect(final ChannelOutboundHandlerContext<Byte> ctx,
ChannelFuture future) throws Exception { final ChannelFuture future) throws Exception {
closeOutboundAndChannel(ctx, e); closeOutboundAndChannel(ctx, future, true);
super.disconnect(ctx, future);
} }
@Override @Override
public void close(ChannelOutboundHandlerContext<Byte> ctx, public void close(final ChannelOutboundHandlerContext<Byte> ctx,
ChannelFuture future) throws Exception { final ChannelFuture future) throws Exception {
closeOutboundAndChannel(ctx, e); closeOutboundAndChannel(ctx, future, false);
super.close(ctx, future);
} }
@Override @Override
public void encode(ChannelOutboundHandlerContext<Byte> ctx, ChannelBuffer in, ChannelBuffer out) throws Exception { public void encode(ChannelOutboundHandlerContext<Byte> ctx, ChannelBuffer in, ChannelBuffer out) throws Exception {
// Otherwise, all messages are encrypted.
ChannelBuffer msg = (ChannelBuffer) e.getMessage();
PendingWrite pendingWrite;
if (msg.readable()) { ByteBuffer outNetBuf = bufferPool.acquireBuffer();
pendingWrite = new PendingWrite(evt.getFuture(), msg.toByteBuffer(msg.readerIndex(), msg.readableBytes())); boolean success = true;
} else { boolean needsUnwrap = false;
pendingWrite = new PendingWrite(evt.getFuture(), null); try {
} ByteBuffer outAppBuf = in.nioBuffer();
synchronized (pendingUnencryptedWrites) {
boolean offered = pendingUnencryptedWrites.offer(pendingWrite); while(in.readable()) {
assert offered;
int read;
int remaining = outAppBuf.remaining();
SSLEngineResult result = null;
synchronized (handshakeLock) {
result = engine.wrap(outAppBuf, outNetBuf);
}
read = remaining - outAppBuf.remaining();
in.readerIndex(in.readerIndex() + read);
if (result.bytesProduced() > 0) {
outNetBuf.flip();
out.writeBytes(outNetBuf);
outNetBuf.clear();
} else if (result.getStatus() == Status.CLOSED) {
// SSLEngine has been closed already.
// Any further write attempts should be denied.
success = false;
break;
} else {
final HandshakeStatus handshakeStatus = result.getHandshakeStatus();
handleRenegotiation(handshakeStatus);
switch (handshakeStatus) {
case NEED_WRAP:
if (outAppBuf.hasRemaining()) {
break;
} else {
break;
}
case NEED_UNWRAP:
needsUnwrap = true;
case NEED_TASK:
runDelegatedTasks();
break;
case FINISHED:
case NOT_HANDSHAKING:
if (handshakeStatus == HandshakeStatus.FINISHED) {
setHandshakeSuccess(ctx.channel());
}
if (result.getStatus() == Status.CLOSED) {
success = false;
}
break;
default:
throw new IllegalStateException("Unknown handshake status: " + handshakeStatus);
}
}
}
} catch (SSLException e) {
success = false;
setHandshakeFailure(ctx.channel(), e);
throw e;
} finally {
bufferPool.releaseBuffer(outNetBuf);
if (!success) {
// mark all bytes as read
in.readerIndex(in.readerIndex() + in.readableBytes());
throw new IllegalStateException("SSLEngine already closed");
}
} }
wrap(context, evt.getChannel()); if (needsUnwrap) {
unwrap(ctx, ctx.channel(), ChannelBuffers.EMPTY_BUFFER, 0, 0, ChannelBuffers.EMPTY_BUFFER);
}
} }
@Override @Override
public void channelDisconnected(ChannelHandlerContext ctx, public void channelInactive(ChannelInboundHandlerContext<Byte> ctx) throws Exception {
ChannelStateEvent e) throws Exception {
// Make sure the handshake future is notified when a connection has // Make sure the handshake future is notified when a connection has
// been closed during handshake. // been closed during handshake.
synchronized (handshakeLock) { synchronized (handshakeLock) {
@ -466,9 +543,9 @@ public class SslHandler extends StreamToStreamCodec {
} }
try { try {
super.channelDisconnected(ctx, e); super.channelInactive(ctx);
} finally { } finally {
unwrap(ctx, e.channel(), ChannelBuffers.EMPTY_BUFFER, 0, 0); unwrap(ctx, ctx.channel(), ChannelBuffers.EMPTY_BUFFER, 0, 0, ChannelBuffers.EMPTY_BUFFER);
engine.closeOutbound(); engine.closeOutbound();
if (!sentCloseNotify.get() && handshaken) { if (!sentCloseNotify.get() && handshaken) {
try { try {
@ -483,10 +560,7 @@ public class SslHandler extends StreamToStreamCodec {
} }
@Override @Override
public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) public void exceptionCaught(ChannelInboundHandlerContext<Byte> ctx, Throwable cause) throws Exception {
throws Exception {
Throwable cause = e.cause();
if (cause instanceof IOException) { if (cause instanceof IOException) {
if (cause instanceof ClosedChannelException) { if (cause instanceof ClosedChannelException) {
synchronized (ignoreClosedChannelExceptionLock) { synchronized (ignoreClosedChannelExceptionLock) {
@ -515,28 +589,29 @@ public class SslHandler extends StreamToStreamCodec {
// Close the connection explicitly just in case the transport // Close the connection explicitly just in case the transport
// did not close the connection automatically. // did not close the connection automatically.
Channels.close(ctx, succeededFuture(e.channel())); ctx.close(ctx.channel().newFuture());
return; return;
} }
} }
} }
super.exceptionCaught(ctx, cause);
ctx.sendUpstream(e);
} }
@Override
protected Object decode(
final ChannelHandlerContext ctx, Channel channel, ChannelBuffer buffer) throws Exception {
if (buffer.readableBytes() < 5) {
return null; @Override
public void decode(ChannelInboundHandlerContext<Byte> ctx, ChannelBuffer in, ChannelBuffer out) throws Exception {
if (in.readableBytes() < 5) {
return;
} }
in.markReaderIndex();
int packetLength = 0; int packetLength = 0;
// SSLv3 or TLS - Check ContentType // SSLv3 or TLS - Check ContentType
boolean tls; boolean tls;
switch (buffer.getUnsignedByte(buffer.readerIndex())) { switch (in.getUnsignedByte(in.readerIndex())) {
case 20: // change_cipher_spec case 20: // change_cipher_spec
case 21: // alert case 21: // alert
case 22: // handshake case 22: // handshake
@ -550,10 +625,10 @@ public class SslHandler extends StreamToStreamCodec {
if (tls) { if (tls) {
// SSLv3 or TLS - Check ProtocolVersion // SSLv3 or TLS - Check ProtocolVersion
int majorVersion = buffer.getUnsignedByte(buffer.readerIndex() + 1); int majorVersion = in.getUnsignedByte(in.readerIndex() + 1);
if (majorVersion == 3) { if (majorVersion == 3) {
// SSLv3 or TLS // SSLv3 or TLS
packetLength = (getShort(buffer, buffer.readerIndex() + 3) & 0xFFFF) + 5; packetLength = (getShort(in, in.readerIndex() + 3) & 0xFFFF) + 5;
if (packetLength <= 5) { if (packetLength <= 5) {
// Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data) // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
tls = false; tls = false;
@ -567,16 +642,16 @@ public class SslHandler extends StreamToStreamCodec {
if (!tls) { if (!tls) {
// SSLv2 or bad data - Check the version // SSLv2 or bad data - Check the version
boolean sslv2 = true; boolean sslv2 = true;
int headerLength = (buffer.getUnsignedByte( int headerLength = (in.getUnsignedByte(
buffer.readerIndex()) & 0x80) != 0 ? 2 : 3; in.readerIndex()) & 0x80) != 0 ? 2 : 3;
int majorVersion = buffer.getUnsignedByte( int majorVersion = in.getUnsignedByte(
buffer.readerIndex() + headerLength + 1); in.readerIndex() + headerLength + 1);
if (majorVersion == 2 || majorVersion == 3) { if (majorVersion == 2 || majorVersion == 3) {
// SSLv2 // SSLv2
if (headerLength == 2) { if (headerLength == 2) {
packetLength = (getShort(buffer, buffer.readerIndex()) & 0x7FFF) + 2; packetLength = (getShort(in, in.readerIndex()) & 0x7FFF) + 2;
} else { } else {
packetLength = (getShort(buffer, buffer.readerIndex()) & 0x3FFF) + 3; packetLength = (getShort(in, in.readerIndex()) & 0x3FFF) + 3;
} }
if (packetLength <= headerLength) { if (packetLength <= headerLength) {
sslv2 = false; sslv2 = false;
@ -588,16 +663,21 @@ public class SslHandler extends StreamToStreamCodec {
if (!sslv2) { if (!sslv2) {
// Bad data - discard the buffer and raise an exception. // Bad data - discard the buffer and raise an exception.
SSLException e = new SSLException( SSLException e = new SSLException(
"not an SSL/TLS record: " + ChannelBuffers.hexDump(buffer)); "not an SSL/TLS record: " + ChannelBuffers.hexDump(in));
buffer.skipBytes(buffer.readableBytes()); in.skipBytes(in.readableBytes());
throw e; throw e;
} }
} }
assert packetLength > 0; assert packetLength > 0;
if (buffer.readableBytes() < packetLength) {
return null; if (in.readableBytes() < packetLength) {
// not enough bytes so reset the reader index and return
//
// TODO: store the parsed packetlength and reuse it. This will safe us from re-parse it
in.resetReaderIndex();
return;
} }
// We advance the buffer's readerIndex before calling unwrap() because // We advance the buffer's readerIndex before calling unwrap() because
@ -614,9 +694,9 @@ public class SslHandler extends StreamToStreamCodec {
// 6) SslHandler.decode() will feed the same packet with what was // 6) SslHandler.decode() will feed the same packet with what was
// deciphered at the step 2 again if the readerIndex was not advanced // deciphered at the step 2 again if the readerIndex was not advanced
// before calling the user code. // before calling the user code.
final int packetOffset = buffer.readerIndex(); final int packetOffset = in.readerIndex();
buffer.skipBytes(packetLength); in.skipBytes(packetLength);
return unwrap(ctx, channel, buffer, packetOffset, packetLength); unwrap(ctx, ctx.channel(), in, packetOffset, packetLength, out);
} }
/** /**
@ -627,178 +707,6 @@ public class SslHandler extends StreamToStreamCodec {
return (short) (buf.getByte(offset) << 8 | buf.getByte(offset + 1) & 0xFF); return (short) (buf.getByte(offset) << 8 | buf.getByte(offset + 1) & 0xFF);
} }
private ChannelFuture wrap(ChannelHandlerContext context, Channel channel)
throws SSLException {
ChannelFuture future = null;
ChannelBuffer msg;
ByteBuffer outNetBuf = bufferPool.acquireBuffer();
boolean success = true;
boolean offered = false;
boolean needsUnwrap = false;
try {
loop:
for (;;) {
// Acquire a lock to make sure unencrypted data is polled
// in order and their encrypted counterpart is offered in
// order.
synchronized (pendingUnencryptedWrites) {
PendingWrite pendingWrite = pendingUnencryptedWrites.peek();
if (pendingWrite == null) {
break;
}
ByteBuffer outAppBuf = pendingWrite.outAppBuf;
if (outAppBuf == null) {
// A write request with an empty buffer
pendingUnencryptedWrites.remove();
offerEncryptedWriteRequest(
new DownstreamMessageEvent(
channel, pendingWrite.future,
ChannelBuffers.EMPTY_BUFFER,
channel.getRemoteAddress()));
offered = true;
} else {
SSLEngineResult result = null;
try {
synchronized (handshakeLock) {
result = engine.wrap(outAppBuf, outNetBuf);
}
} finally {
if (!outAppBuf.hasRemaining()) {
pendingUnencryptedWrites.remove();
}
}
if (result.bytesProduced() > 0) {
outNetBuf.flip();
int remaining = outNetBuf.remaining();
msg = ChannelBuffers.buffer(remaining);
msg.writeBytes(outNetBuf);
outNetBuf.clear();
if (pendingWrite.outAppBuf.hasRemaining()) {
// pendingWrite's future shouldn't be notified if
// only partial data is written.
future = succeededFuture(channel);
} else {
future = pendingWrite.future;
}
MessageEvent encryptedWrite = new DownstreamMessageEvent(
channel, future, msg, channel.getRemoteAddress());
offerEncryptedWriteRequest(encryptedWrite);
offered = true;
} else if (result.getStatus() == Status.CLOSED) {
// SSLEngine has been closed already.
// Any further write attempts should be denied.
success = false;
break;
} else {
final HandshakeStatus handshakeStatus = result.getHandshakeStatus();
handleRenegotiation(handshakeStatus);
switch (handshakeStatus) {
case NEED_WRAP:
if (outAppBuf.hasRemaining()) {
break;
} else {
break loop;
}
case NEED_UNWRAP:
needsUnwrap = true;
break loop;
case NEED_TASK:
runDelegatedTasks();
break;
case FINISHED:
case NOT_HANDSHAKING:
if (handshakeStatus == HandshakeStatus.FINISHED) {
setHandshakeSuccess(channel);
}
if (result.getStatus() == Status.CLOSED) {
success = false;
}
break loop;
default:
throw new IllegalStateException(
"Unknown handshake status: " +
handshakeStatus);
}
}
}
}
}
} catch (SSLException e) {
success = false;
setHandshakeFailure(channel, e);
throw e;
} finally {
bufferPool.releaseBuffer(outNetBuf);
if (offered) {
flushPendingEncryptedWrites(context);
}
if (!success) {
IllegalStateException cause =
new IllegalStateException("SSLEngine already closed");
// Mark all remaining pending writes as failure if anything
// wrong happened before the write requests are wrapped.
// Please note that we do not call setFailure while a lock is
// acquired, to avoid a potential dead lock.
for (;;) {
PendingWrite pendingWrite;
synchronized (pendingUnencryptedWrites) {
pendingWrite = pendingUnencryptedWrites.poll();
if (pendingWrite == null) {
break;
}
}
pendingWrite.future.setFailure(cause);
}
}
}
if (needsUnwrap) {
unwrap(context, channel, ChannelBuffers.EMPTY_BUFFER, 0, 0);
}
if (future == null) {
future = succeededFuture(channel);
}
return future;
}
private void offerEncryptedWriteRequest(MessageEvent encryptedWrite) {
final boolean locked = pendingEncryptedWritesLock.tryLock();
try {
pendingEncryptedWrites.offer(encryptedWrite);
} finally {
if (locked) {
pendingEncryptedWritesLock.unlock();
}
}
}
private void flushPendingEncryptedWrites(ChannelHandlerContext ctx) {
// Avoid possible dead lock and data integrity issue
// which is caused by cross communication between more than one channel
// in the same VM.
if (!pendingEncryptedWritesLock.tryLock()) {
return;
}
try {
MessageEvent e;
while ((e = pendingEncryptedWrites.poll()) != null) {
ctx.sendDownstream(e);
}
} finally {
pendingEncryptedWritesLock.unlock();
}
}
private ChannelFuture wrapNonAppData(ChannelHandlerContext ctx, Channel channel) throws SSLException { private ChannelFuture wrapNonAppData(ChannelHandlerContext ctx, Channel channel) throws SSLException {
ChannelFuture future = null; ChannelFuture future = null;
ByteBuffer outNetBuf = bufferPool.acquireBuffer(); ByteBuffer outNetBuf = bufferPool.acquireBuffer();
@ -820,7 +728,7 @@ public class SslHandler extends StreamToStreamCodec {
msg.writeBytes(outNetBuf); msg.writeBytes(outNetBuf);
outNetBuf.clear(); outNetBuf.clear();
future = future(channel); future = channel.newFuture();
future.addListener(new ChannelFutureListener() { future.addListener(new ChannelFutureListener() {
@Override @Override
public void operationComplete(ChannelFuture future) public void operationComplete(ChannelFuture future)
@ -833,7 +741,7 @@ public class SslHandler extends StreamToStreamCodec {
} }
}); });
write(ctx, future, msg); ctx.write(msg, future);
} }
final HandshakeStatus handshakeStatus = result.getHandshakeStatus(); final HandshakeStatus handshakeStatus = result.getHandshakeStatus();
@ -851,7 +759,7 @@ public class SslHandler extends StreamToStreamCodec {
// unwrap shouldn't be called when this method was // unwrap shouldn't be called when this method was
// called by unwrap - unwrap will keep running after // called by unwrap - unwrap will keep running after
// this method returns. // this method returns.
unwrap(ctx, channel, ChannelBuffers.EMPTY_BUFFER, 0, 0); unwrap(ctx, channel, ChannelBuffers.EMPTY_BUFFER, 0, 0, ChannelBuffers.EMPTY_BUFFER);
} }
break; break;
case NOT_HANDSHAKING: case NOT_HANDSHAKING:
@ -874,15 +782,15 @@ public class SslHandler extends StreamToStreamCodec {
} }
if (future == null) { if (future == null) {
future = succeededFuture(channel); future = channel.newSucceededFuture();
} }
return future; return future;
} }
private ChannelBuffer unwrap( private void unwrap(
ChannelHandlerContext ctx, Channel channel, ChannelBuffer buffer, int offset, int length) throws SSLException { ChannelHandlerContext ctx, Channel channel, ChannelBuffer buffer, int offset, int length, ChannelBuffer out) throws SSLException {
ByteBuffer inNetBuf = buffer.toByteBuffer(offset, length); ByteBuffer inNetBuf = buffer.nioBuffer(offset, length);
ByteBuffer outAppBuf = bufferPool.acquireBuffer(); ByteBuffer outAppBuf = bufferPool.acquireBuffer();
try { try {
@ -942,32 +850,17 @@ public class SslHandler extends StreamToStreamCodec {
} }
if (needsWrap) { if (needsWrap) {
// wrap() acquires pendingUnencryptedWrites first and then wrapNonAppData(ctx, channel);
// handshakeLock. If handshakeLock is already hold by the
// current thread, calling wrap() will lead to a dead lock
// i.e. pendingUnencryptedWrites -> handshakeLock vs.
// handshakeLock -> pendingUnencryptedLock -> handshakeLock
//
// There is also a same issue between pendingEncryptedWrites
// and pendingUnencryptedWrites.
if (!Thread.holdsLock(handshakeLock) &&
!pendingEncryptedWritesLock.isHeldByCurrentThread()) {
wrap(ctx, channel);
}
} }
outAppBuf.flip(); outAppBuf.flip();
if (outAppBuf.hasRemaining()) { if (outAppBuf.hasRemaining()) {
ChannelBuffer frame = ctx.getChannel().getConfig().getBufferFactory().getBuffer(outAppBuf.remaining());
// Transfer the bytes to the new ChannelBuffer using some safe method that will also // Transfer the bytes to the new ChannelBuffer using some safe method that will also
// work with "non" heap buffers // work with "non" heap buffers
// //
// See https://github.com/netty/netty/issues/329 // See https://github.com/netty/netty/issues/329
frame.writeBytes(outAppBuf); out.writeBytes(outAppBuf);
return frame;
} else {
return null;
} }
} catch (SSLException e) { } catch (SSLException e) {
setHandshakeFailure(channel, e); setHandshakeFailure(channel, e);
@ -1018,13 +911,12 @@ public class SslHandler extends StreamToStreamCodec {
handshake(); handshake();
} else { } else {
// Raise an exception. // Raise an exception.
fireExceptionCaught( ctx.fireExceptionCaught(new SSLException(
ctx, new SSLException(
"renegotiation attempted by peer; " + "renegotiation attempted by peer; " +
"closing the connection")); "closing the connection"));
// Close the connection to stop renegotiation. // Close the connection to stop renegotiation.
Channels.close(ctx, succeededFuture(ctx.channel())); ctx.close(ctx.channel().newSucceededFuture());
} }
} }
@ -1056,7 +948,7 @@ public class SslHandler extends StreamToStreamCodec {
handshaken = true; handshaken = true;
if (handshakeFuture == null) { if (handshakeFuture == null) {
handshakeFuture = future(channel); handshakeFuture = channel.newFuture();
} }
} }
@ -1072,7 +964,7 @@ public class SslHandler extends StreamToStreamCodec {
handshaken = false; handshaken = false;
if (handshakeFuture == null) { if (handshakeFuture == null) {
handshakeFuture = future(channel); handshakeFuture = channel.newFuture();
} }
// Release all resources such as internal buffers that SSLEngine // Release all resources such as internal buffers that SSLEngine
@ -1096,16 +988,20 @@ public class SslHandler extends StreamToStreamCodec {
} }
private void closeOutboundAndChannel( private void closeOutboundAndChannel(
final ChannelHandlerContext context, final ChannelStateEvent e) { final ChannelOutboundHandlerContext<Byte> context, final ChannelFuture future, boolean disconnect) throws Exception {
if (!e.channel().isConnected()) { if (!context.channel().isActive()) {
context.sendDownstream(e); if (disconnect) {
ctx.disconnect(future);
} else {
ctx.close(future);
}
return; return;
} }
boolean success = false; boolean success = false;
try { try {
try { try {
unwrap(context, e.channel(), ChannelBuffers.EMPTY_BUFFER, 0, 0); unwrap(context, context.channel(), ChannelBuffers.EMPTY_BUFFER, 0, 0, ChannelBuffers.EMPTY_BUFFER);
} catch (SSLException ex) { } catch (SSLException ex) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("Failed to unwrap before sending a close_notify message", ex); logger.debug("Failed to unwrap before sending a close_notify message", ex);
@ -1116,9 +1012,9 @@ public class SslHandler extends StreamToStreamCodec {
if (sentCloseNotify.compareAndSet(false, true)) { if (sentCloseNotify.compareAndSet(false, true)) {
engine.closeOutbound(); engine.closeOutbound();
try { try {
ChannelFuture closeNotifyFuture = wrapNonAppData(context, e.channel()); ChannelFuture closeNotifyFuture = wrapNonAppData(context, context.channel());
closeNotifyFuture.addListener( closeNotifyFuture.addListener(
new ClosingChannelFutureListener(context, e)); new ClosingChannelFutureListener(context, future));
success = true; success = true;
} catch (SSLException ex) { } catch (SSLException ex) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
@ -1131,38 +1027,31 @@ public class SslHandler extends StreamToStreamCodec {
} }
} finally { } finally {
if (!success) { if (!success) {
context.sendDownstream(e); if (disconnect) {
ctx.disconnect(future);
} else {
ctx.close(future);
}
} }
} }
} }
private static final class PendingWrite {
final ChannelFuture future;
final ByteBuffer outAppBuf;
PendingWrite(ChannelFuture future, ByteBuffer outAppBuf) {
this.future = future;
this.outAppBuf = outAppBuf;
}
}
private static final class ClosingChannelFutureListener implements ChannelFutureListener { private static final class ClosingChannelFutureListener implements ChannelFutureListener {
private final ChannelHandlerContext context; private final ChannelHandlerContext context;
private final ChannelStateEvent e; private final ChannelFuture f;
ClosingChannelFutureListener( ClosingChannelFutureListener(
ChannelHandlerContext context, ChannelStateEvent e) { ChannelHandlerContext context, ChannelFuture f) {
this.context = context; this.context = context;
this.e = e; this.f = f;
} }
@Override @Override
public void operationComplete(ChannelFuture closeNotifyFuture) throws Exception { public void operationComplete(ChannelFuture closeNotifyFuture) throws Exception {
if (!(closeNotifyFuture.cause() instanceof ClosedChannelException)) { if (!(closeNotifyFuture.cause() instanceof ClosedChannelException)) {
Channels.close(context, e.getFuture()); context.close(f);
} else { } else {
e.getFuture().setSuccess(); f.setSuccess();
} }
} }
} }
@ -1170,61 +1059,16 @@ public class SslHandler extends StreamToStreamCodec {
@Override @Override
public void beforeAdd(ChannelHandlerContext ctx) throws Exception { public void beforeAdd(ChannelHandlerContext ctx) throws Exception {
this.ctx = ctx; this.ctx = ctx;
this.handshakeFuture = ctx.newFuture();
} }
@Override
public void afterAdd(ChannelHandlerContext ctx) throws Exception {
// Unused
}
@Override
public void beforeRemove(ChannelHandlerContext ctx) throws Exception {
// Unused
}
/**
* Fail all pending writes which we were not able to flush out
*/
@Override
public void afterRemove(ChannelHandlerContext ctx) throws Exception {
// there is no need for synchronization here as we do not receive downstream events anymore
Throwable cause = null;
for (;;) {
PendingWrite pw = pendingUnencryptedWrites.poll();
if (pw == null) {
break;
}
if (cause == null) {
cause = new IOException("Unable to write data");
}
pw.future.setFailure(cause);
}
for (;;) {
MessageEvent ev = pendingEncryptedWrites.poll();
if (ev == null) {
break;
}
if (cause == null) {
cause = new IOException("Unable to write data");
}
ev.getFuture().setFailure(cause);
}
if (cause != null) {
fireExceptionCaughtLater(ctx, cause);
}
}
/** /**
* Calls {@link #handshake()} once the {@link Channel} is connected * Calls {@link #handshake()} once the {@link Channel} is connected
*/ */
@Override @Override
public void channelConnected(final ChannelHandlerContext ctx, final ChannelStateEvent e) throws Exception { public void channelActive(final ChannelInboundHandlerContext<Byte> ctx) throws Exception {
if (issueHandshake) { if (issueHandshake) {
// issue and handshake and add a listener to it which will fire an exception event if an exception was thrown while doing the handshake // issue and handshake and add a listener to it which will fire an exception event if an exception was thrown while doing the handshake
handshake().addListener(new ChannelFutureListener() { handshake().addListener(new ChannelFutureListener() {
@ -1232,63 +1076,21 @@ public class SslHandler extends StreamToStreamCodec {
@Override @Override
public void operationComplete(ChannelFuture future) throws Exception { public void operationComplete(ChannelFuture future) throws Exception {
if (!future.isSuccess()) { if (!future.isSuccess()) {
Channels.fireExceptionCaught(future.getChannel(), future.getCause()); ctx.pipeline().fireExceptionCaught(future.cause());
} else { } else {
// Send the event upstream after the handshake was completed without an error. // Send the event upstream after the handshake was completed without an error.
// //
// See https://github.com/netty/netty/issues/358 // See https://github.com/netty/netty/issues/358
ctx.sendUpstream(e); ctx.fireChannelActive();
} }
} }
}); });
} else { } else {
super.channelConnected(ctx, e); super.channelActive(ctx);
} }
} }
/**
* Loop over all the pending writes and fail them.
*
* See <a href="https://github.com/netty/netty/issues/305">#305</a> for more details.
*/
@Override
public void channelClosed(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception {
Throwable cause = null;
synchronized (pendingUnencryptedWrites) {
for (;;) {
PendingWrite pw = pendingUnencryptedWrites.poll();
if (pw == null) {
break;
}
if (cause == null) {
cause = new ClosedChannelException();
}
pw.future.setFailure(cause);
}
for (;;) {
MessageEvent ev = pendingEncryptedWrites.poll();
if (ev == null) {
break;
}
if (cause == null) {
cause = new ClosedChannelException();
}
ev.getFuture().setFailure(cause);
}
}
if (cause != null) {
fireExceptionCaught(ctx, cause);
}
super.channelClosed(ctx, e);
}
private final class SSLEngineInboundCloseFuture extends DefaultChannelFuture { private final class SSLEngineInboundCloseFuture extends DefaultChannelFuture {
public SSLEngineInboundCloseFuture() { public SSLEngineInboundCloseFuture() {
super(null, true); super(null, true);
@ -1299,12 +1101,12 @@ public class SslHandler extends StreamToStreamCodec {
} }
@Override @Override
public Channel getChannel() { public Channel channel() {
if (ctx == null) { if (ctx == null) {
// Maybe we should better throw an IllegalStateException() ? // Maybe we should better throw an IllegalStateException() ?
return null; return null;
} else { } else {
return ctx.getChannel(); return ctx.channel();
} }
} }
@ -1318,4 +1120,6 @@ public class SslHandler extends StreamToStreamCodec {
return false; return false;
} }
} }
} }