diff --git a/src/main/java/org/jboss/netty/handler/ssl/SslHandler.java b/src/main/java/org/jboss/netty/handler/ssl/SslHandler.java index 26d1a2dc7e..d2378ee210 100644 --- a/src/main/java/org/jboss/netty/handler/ssl/SslHandler.java +++ b/src/main/java/org/jboss/netty/handler/ssl/SslHandler.java @@ -124,8 +124,8 @@ public class SslHandler extends FrameDecoder { private final boolean startTls; final Object handshakeLock = new Object(); - private volatile boolean needsFirstHandshake = true; - private volatile boolean handshaking; + private boolean initialHandshake; + private boolean handshaking; private volatile boolean handshaken; private volatile ChannelFuture handshakeFuture; @@ -275,6 +275,7 @@ public class SslHandler extends FrameDecoder { return this.handshakeFuture; } else { engine.beginHandshake(); + runDelegatedTasks(); handshakeFuture = this.handshakeFuture = newHandshakeFuture(channel); handshaking = true; } @@ -438,11 +439,7 @@ public class SslHandler extends FrameDecoder { SSLEngineResult result; try { - if (handshaking || needsFirstHandshake) { - synchronized (handshakeLock) { - result = engine.wrap(outAppBuf, outNetBuf); - } - } else { + synchronized (handshakeLock) { result = engine.wrap(outAppBuf, outNetBuf); } } finally { @@ -504,9 +501,7 @@ public class SslHandler extends FrameDecoder { } } catch (SSLException e) { success = false; - if (handshaking) { - setHandshakeFailure(channel, e); - } + setHandshakeFailure(channel, e); throw e; } finally { bufferPool.release(outNetBuf); @@ -574,11 +569,7 @@ public class SslHandler extends FrameDecoder { SSLEngineResult result; try { for (;;) { - if (handshaking || needsFirstHandshake) { - synchronized (handshakeLock) { - result = engine.wrap(EMPTY_BUFFER, outNetBuf); - } - } else { + synchronized (handshakeLock) { result = engine.wrap(EMPTY_BUFFER, outNetBuf); } @@ -600,7 +591,12 @@ public class SslHandler extends FrameDecoder { runDelegatedTasks(); break; case NEED_UNWRAP: - unwrap(ctx, channel, ChannelBuffers.EMPTY_BUFFER, 0, 0); + if (!Thread.holdsLock(handshakeLock)) { + // unwrap shouldn't be called when this method was + // called by unwrap - unwrap will keep running after + // this method returns. + unwrap(ctx, channel, ChannelBuffers.EMPTY_BUFFER, 0, 0); + } break; case NOT_HANDSHAKING: case NEED_WRAP: @@ -616,9 +612,7 @@ public class SslHandler extends FrameDecoder { } } } catch (SSLException e) { - if (handshaking) { - setHandshakeFailure(channel, e); - } + setHandshakeFailure(channel, e); throw e; } finally { bufferPool.release(outNetBuf); @@ -636,44 +630,55 @@ public class SslHandler extends FrameDecoder { ByteBuffer outAppBuf = bufferPool.acquire(); try { + boolean needsWrap = false; loop: for (;;) { SSLEngineResult result; - if (handshaking || needsFirstHandshake) { - synchronized (handshakeLock) { + synchronized (handshakeLock) { + if (initialHandshake && !engine.getUseClientMode() && + !engine.isInboundDone() && !engine.isOutboundDone()) { + handshake(channel); + initialHandshake = false; + } + try { result = engine.unwrap(inNetBuf, outAppBuf); + } catch (SSLException e) { + System.err.println(engine.getUseClientMode()); + throw e; } - } else { - result = engine.unwrap(inNetBuf, outAppBuf); - } - switch (result.getHandshakeStatus()) { - case NEED_UNWRAP: - if (inNetBuf.hasRemaining()) { + switch (result.getHandshakeStatus()) { + case NEED_UNWRAP: + if (inNetBuf.hasRemaining()) { + break; + } else { + break loop; + } + case NEED_WRAP: + wrapNonAppData(ctx, channel); break; - } else { + case NEED_TASK: + runDelegatedTasks(); + break; + case FINISHED: + setHandshakeSuccess(channel); + needsWrap = true; break loop; + case NOT_HANDSHAKING: + needsWrap = true; + break loop; + default: + throw new IllegalStateException( + "Unknown handshake status: " + + result.getHandshakeStatus()); } - case NEED_WRAP: - wrapNonAppData(ctx, channel); - break; - case NEED_TASK: - runDelegatedTasks(); - break; - case FINISHED: - setHandshakeSuccess(channel); - wrap(ctx, channel); - break loop; - case NOT_HANDSHAKING: - wrap(ctx, channel); - break loop; - default: - throw new IllegalStateException( - "Unknown handshake status: " + - result.getHandshakeStatus()); } } + if (needsWrap) { + wrap(ctx, channel); + } + outAppBuf.flip(); if (outAppBuf.hasRemaining()) { @@ -684,9 +689,7 @@ public class SslHandler extends FrameDecoder { return null; } } catch (SSLException e) { - if (handshaking) { - setHandshakeFailure(channel, e); - } + setHandshakeFailure(channel, e); throw e; } finally { bufferPool.release(outAppBuf); @@ -694,13 +697,20 @@ public class SslHandler extends FrameDecoder { } private void runDelegatedTasks() { - Runnable task; - while ((task = engine.getDelegatedTask()) != null) { - final Runnable t = task; + for (;;) { + final Runnable task; + synchronized (handshakeLock) { + task = engine.getDelegatedTask(); + } + + if (task == null) { + break; + } + delegatedTaskExecutor.execute(new Runnable() { public void run() { synchronized (handshakeLock) { - t.run(); + task.run(); } } }); @@ -711,7 +721,6 @@ public class SslHandler extends FrameDecoder { synchronized (handshakeLock) { handshaking = false; handshaken = true; - needsFirstHandshake = false; // Will not set to true again if (handshakeFuture == null) { handshakeFuture = newHandshakeFuture(channel); @@ -723,6 +732,9 @@ public class SslHandler extends FrameDecoder { private void setHandshakeFailure(Channel channel, SSLException cause) { synchronized (handshakeLock) { + if (!handshaking) { + return; + } handshaking = false; handshaken = false;