diff --git a/handler/src/main/java/io/netty/handler/ssl/OptionalSslHandler.java b/handler/src/main/java/io/netty/handler/ssl/OptionalSslHandler.java new file mode 100644 index 0000000000..bc1b4b3623 --- /dev/null +++ b/handler/src/main/java/io/netty/handler/ssl/OptionalSslHandler.java @@ -0,0 +1,117 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.ObjectUtil; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; +import java.util.List; + +/** + * {@link OptionalSslHandler} is a utility decoder to support both SSL and non-SSL handlers + * based on the first message received. + */ +public class OptionalSslHandler extends ByteToMessageDecoder { + + private final SslContext sslContext; + + public OptionalSslHandler(SslContext sslContext) { + this.sslContext = ObjectUtil.checkNotNull(sslContext, "sslContext"); + } + + @Override + protected void decode(ChannelHandlerContext context, ByteBuf in, List out) throws Exception { + if (in.readableBytes() < SslUtils.SSL_RECORD_HEADER_LENGTH) { + return; + } + if (SslHandler.isEncrypted(in)) { + handleSsl(context); + } else { + handleNonSsl(context); + } + } + + private void handleSsl(ChannelHandlerContext context) { + SslHandler sslHandler = null; + try { + sslHandler = newSslHandler(context, sslContext); + context.pipeline().replace(this, newSslHandlerName(), sslHandler); + sslHandler = null; + } finally { + // Since the SslHandler was not inserted into the pipeline the ownership of the SSLEngine was not + // transferred to the SslHandler. + if (sslHandler != null) { + ReferenceCountUtil.safeRelease(sslHandler.engine()); + } + } + } + + private void handleNonSsl(ChannelHandlerContext context) { + ChannelHandler handler = newNonSslHandler(context); + if (handler != null) { + context.pipeline().replace(this, newNonSslHandlerName(), handler); + } else { + context.pipeline().remove(this); + } + } + + /** + * Optionally specify the SSL handler name, this method may return {@code null}. + * @return the name of the SSL handler. + */ + protected String newSslHandlerName() { + return null; + } + + /** + * Override to configure the SslHandler eg. {@link SSLParameters#setEndpointIdentificationAlgorithm(String)}. + * The hostname and port is not known by this method so servers may want to override this method and use the + * {@link SslContext#newHandler(ByteBufAllocator, String, int)} variant. + * + * @param context the {@link ChannelHandlerContext} to use. + * @param sslContext the {@link SSLContext} to use. + * @return the {@link SslHandler} which will replace the {@link OptionalSslHandler} in the pipeline if the + * traffic is SSL. + */ + protected SslHandler newSslHandler(ChannelHandlerContext context, SslContext sslContext) { + return sslContext.newHandler(context.alloc()); + } + + /** + * Optionally specify the non-SSL handler name, this method may return {@code null}. + * @return the name of the non-SSL handler. + */ + protected String newNonSslHandlerName() { + return null; + } + + /** + * Override to configure the ChannelHandler. + * @param context the {@link ChannelHandlerContext} to use. + * @return the {@link ChannelHandler} which will replace the {@link OptionalSslHandler} in the pipeline + * or {@code null} to simply remove the {@link OptionalSslHandler} if the traffic is non-SSL. + */ + protected ChannelHandler newNonSslHandler(ChannelHandlerContext context) { + return null; + } +} diff --git a/handler/src/test/java/io/netty/handler/ssl/OptionalSslHandlerTest.java b/handler/src/test/java/io/netty/handler/ssl/OptionalSslHandlerTest.java new file mode 100644 index 0000000000..8705cf2e9c --- /dev/null +++ b/handler/src/test/java/io/netty/handler/ssl/OptionalSslHandlerTest.java @@ -0,0 +1,121 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; + +import static org.mockito.Mockito.*; + +@RunWith(MockitoJUnitRunner.class) +public class OptionalSslHandlerTest { + + private static final String SSL_HANDLER_NAME = "sslhandler"; + private static final String HANDLER_NAME = "handler"; + + @Mock + private ChannelHandlerContext context; + + @Mock + private SslContext sslContext; + + @Mock + private ChannelPipeline pipeline; + + @Before + public void setUp() throws Exception { + when(context.pipeline()).thenReturn(pipeline); + } + + @Test + public void handlerRemoved() throws Exception { + OptionalSslHandler handler = new OptionalSslHandler(sslContext); + final ByteBuf payload = Unpooled.copiedBuffer("plaintext".getBytes()); + try { + handler.decode(context, payload, null); + verify(pipeline).remove(handler); + } finally { + payload.release(); + } + } + + @Test + public void handlerReplaced() throws Exception { + final ChannelHandler nonSslHandler = Mockito.mock(ChannelHandler.class); + OptionalSslHandler handler = new OptionalSslHandler(sslContext) { + @Override + protected ChannelHandler newNonSslHandler(ChannelHandlerContext context) { + return nonSslHandler; + } + + @Override + protected String newNonSslHandlerName() { + return HANDLER_NAME; + } + }; + final ByteBuf payload = Unpooled.copiedBuffer("plaintext".getBytes()); + try { + handler.decode(context, payload, null); + verify(pipeline).replace(handler, HANDLER_NAME, nonSslHandler); + } finally { + payload.release(); + } + } + + @Test + public void sslHandlerReplaced() throws Exception { + final SslHandler sslHandler = Mockito.mock(SslHandler.class); + OptionalSslHandler handler = new OptionalSslHandler(sslContext) { + @Override + protected SslHandler newSslHandler(ChannelHandlerContext context, SslContext sslContext) { + return sslHandler; + } + + @Override + protected String newSslHandlerName() { + return SSL_HANDLER_NAME; + } + }; + final ByteBuf payload = Unpooled.wrappedBuffer(new byte[] { 22, 3, 1, 0, 5 }); + try { + handler.decode(context, payload, null); + verify(pipeline).replace(handler, SSL_HANDLER_NAME, sslHandler); + } finally { + payload.release(); + } + } + + @Test + public void decodeBuffered() throws Exception { + OptionalSslHandler handler = new OptionalSslHandler(sslContext); + final ByteBuf payload = Unpooled.wrappedBuffer(new byte[] { 22, 3 }); + try { + handler.decode(context, payload, null); + verifyZeroInteractions(pipeline); + } finally { + payload.release(); + } + } +}