diff --git a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java index b961b3eec2..9763c59683 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslHandler.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -230,6 +230,8 @@ public class SslHandler extends ByteToMessageDecoder { */ private boolean needsFlush; + private boolean outboundClosed; + private int packetLength; /** @@ -367,6 +369,7 @@ public class SslHandler extends ByteToMessageDecoder { ctx.executor().execute(new Runnable() { @Override public void run() { + SslHandler.this.outboundClosed = true; engine.closeOutbound(); try { write(ctx, Unpooled.EMPTY_BUFFER, future); @@ -659,7 +662,7 @@ public class SslHandler extends ByteToMessageDecoder { public void channelInactive(ChannelHandlerContext ctx) throws Exception { // Make sure to release SSLEngine, // and notify the handshake future if the connection has been closed during handshake. - setHandshakeFailure(ctx, CHANNEL_CLOSED); + setHandshakeFailure(ctx, CHANNEL_CLOSED, !outboundClosed); super.channelInactive(ctx); } @@ -1157,20 +1160,29 @@ public class SslHandler extends ByteToMessageDecoder { * Notify all the handshake futures about the failure during the handshake. */ private void setHandshakeFailure(ChannelHandlerContext ctx, Throwable cause) { + setHandshakeFailure(ctx, cause, true); + } + + /** + * Notify all the handshake futures about the failure during the handshake. + */ + private void setHandshakeFailure(ChannelHandlerContext ctx, Throwable cause, boolean closeInbound) { // Release all resources such as internal buffers that SSLEngine // is managing. engine.closeOutbound(); - try { - engine.closeInbound(); - } catch (SSLException e) { - // only log in debug mode as it most likely harmless and latest chrome still trigger - // this all the time. - // - // See https://github.com/netty/netty/issues/1340 - String msg = e.getMessage(); - if (msg == null || !msg.contains("possible truncation attack")) { - logger.debug("{} SSLEngine.closeInbound() raised an exception.", ctx.channel(), e); + if (closeInbound) { + try { + engine.closeInbound(); + } catch (SSLException e) { + // only log in debug mode as it most likely harmless and latest chrome still trigger + // this all the time. + // + // See https://github.com/netty/netty/issues/1340 + String msg = e.getMessage(); + if (msg == null || !msg.contains("possible truncation attack")) { + logger.debug("{} SSLEngine.closeInbound() raised an exception.", ctx.channel(), e); + } } } notifyHandshakeFailure(cause); @@ -1195,6 +1207,7 @@ public class SslHandler extends ByteToMessageDecoder { return; } + outboundClosed = true; engine.closeOutbound(); ChannelPromise closeNotifyFuture = ctx.newPromise(); diff --git a/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslSessionReuseTest.java b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslSessionReuseTest.java new file mode 100644 index 0000000000..4e560ce158 --- /dev/null +++ b/testsuite/src/main/java/io/netty/testsuite/transport/socket/SocketSslSessionReuseTest.java @@ -0,0 +1,217 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.testsuite.transport.socket; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.socket.SocketChannel; +import io.netty.handler.ssl.JdkSslClientContext; +import io.netty.handler.ssl.JdkSslContext; +import io.netty.handler.ssl.JdkSslServerContext; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslHandler; +import io.netty.handler.ssl.util.SelfSignedCertificate; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLSessionContext; + +import java.io.File; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.security.cert.CertificateException; +import java.util.Collection; +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.Assert.*; + +@RunWith(Parameterized.class) +public class SocketSslSessionReuseTest extends AbstractSocketTest { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(SocketSslSessionReuseTest.class); + + private static final File CERT_FILE; + private static final File KEY_FILE; + + static { + SelfSignedCertificate ssc; + try { + ssc = new SelfSignedCertificate(); + } catch (CertificateException e) { + throw new Error(e); + } + CERT_FILE = ssc.certificate(); + KEY_FILE = ssc.privateKey(); + } + + @Parameters(name = "{index}: serverEngine = {0}, clientEngine = {1}") + public static Collection data() throws Exception { + return Collections.singletonList(new Object[] { + new JdkSslServerContext(CERT_FILE, KEY_FILE), + new JdkSslClientContext(CERT_FILE) + }); + } + + private final SslContext serverCtx; + private final SslContext clientCtx; + + public SocketSslSessionReuseTest(SslContext serverCtx, SslContext clientCtx) { + this.serverCtx = serverCtx; + this.clientCtx = clientCtx; + } + + @Test(timeout = 30000) + public void testSslSessionReuse() throws Throwable { + run(); + } + + public void testSslSessionReuse(ServerBootstrap sb, Bootstrap cb) throws Throwable { + final ReadAndDiscardHandler sh = new ReadAndDiscardHandler(true, true); + final ReadAndDiscardHandler ch = new ReadAndDiscardHandler(false, true); + final String[] protocols = new String[]{ "TLSv1", "TLSv1.1", "TLSv1.2" }; + + sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel sch) throws Exception { + SSLEngine engine = serverCtx.newEngine(sch.alloc()); + engine.setUseClientMode(false); + engine.setEnabledProtocols(protocols); + + sch.pipeline().addLast(new SslHandler(engine)); + sch.pipeline().addLast(sh); + } + }); + final Channel sc = sb.bind().sync().channel(); + + cb.handler(new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel sch) throws Exception { + InetSocketAddress serverAddr = (InetSocketAddress) sc.localAddress(); + SSLEngine engine = clientCtx.newEngine(sch.alloc(), serverAddr.getHostString(), serverAddr.getPort()); + engine.setUseClientMode(true); + engine.setEnabledProtocols(protocols); + + sch.pipeline().addLast(new SslHandler(engine)); + sch.pipeline().addLast(ch); + } + }); + + try { + SSLSessionContext clientSessionCtx = ((JdkSslContext) clientCtx).sessionContext(); + ByteBuf msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4); + Channel cc = cb.connect().sync().channel(); + cc.writeAndFlush(msg).sync(); + cc.closeFuture().sync(); + rethrowHandlerExceptions(sh, ch); + Set sessions = sessionIdSet(clientSessionCtx.getIds()); + + msg = Unpooled.wrappedBuffer(new byte[] { 0xa, 0xb, 0xc, 0xd }, 0, 4); + cc = cb.connect().sync().channel(); + cc.writeAndFlush(msg).sync(); + cc.closeFuture().sync(); + assertEquals("Expected no new sessions", sessions, sessionIdSet(clientSessionCtx.getIds())); + rethrowHandlerExceptions(sh, ch); + } finally { + sc.close().awaitUninterruptibly(); + } + } + + private void rethrowHandlerExceptions(ReadAndDiscardHandler sh, ReadAndDiscardHandler ch) throws Throwable { + if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) { + throw sh.exception.get(); + } + if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) { + throw ch.exception.get(); + } + if (sh.exception.get() != null) { + throw sh.exception.get(); + } + if (ch.exception.get() != null) { + throw ch.exception.get(); + } + } + + private Set sessionIdSet(Enumeration sessionIds) { + Set idSet = new HashSet(); + byte[] id; + while (sessionIds.hasMoreElements()) { + id = sessionIds.nextElement(); + idSet.add(ByteBufUtil.hexDump(Unpooled.wrappedBuffer(id))); + } + return idSet; + } + + @Sharable + private static class ReadAndDiscardHandler extends SimpleChannelInboundHandler { + final AtomicReference exception = new AtomicReference(); + private final boolean server; + private final boolean autoRead; + + ReadAndDiscardHandler(boolean server, boolean autoRead) { + this.server = server; + this.autoRead = autoRead; + } + + @Override + public void messageReceived(ChannelHandlerContext ctx, ByteBuf in) throws Exception { + byte[] actual = new byte[in.readableBytes()]; + in.readBytes(actual); + ctx.close(); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + try { + ctx.flush(); + } finally { + if (!autoRead) { + ctx.read(); + } + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, + Throwable cause) throws Exception { + if (logger.isWarnEnabled()) { + logger.warn( + "Unexpected exception from the " + + (server? "server" : "client") + " side", cause); + } + + exception.compareAndSet(null, cause); + ctx.close(); + } + } +}