Support for handling SSL and non-SSL in pipeline

Motivation:

Some pipelines require support for both SSL and non-SSL messaging.

Modifications:

Add utility decoder to support both SSL and non-SSL handlers based on the initial message.

Result:

Less boilerplate code to write for developers.
This commit is contained in:
Johno Crawford 2017-03-07 23:31:05 +01:00 committed by Scott Mitchell
parent a94b23df7d
commit cfebaa36c0
2 changed files with 238 additions and 0 deletions

View File

@ -0,0 +1,117 @@
/*
* Copyright 2017 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.ssl;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.ObjectUtil;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLParameters;
import java.util.List;
/**
* {@link OptionalSslHandler} is a utility decoder to support both SSL and non-SSL handlers
* based on the first message received.
*/
public class OptionalSslHandler extends ByteToMessageDecoder {
private final SslContext sslContext;
public OptionalSslHandler(SslContext sslContext) {
this.sslContext = ObjectUtil.checkNotNull(sslContext, "sslContext");
}
@Override
protected void decode(ChannelHandlerContext context, ByteBuf in, List<Object> out) throws Exception {
if (in.readableBytes() < SslUtils.SSL_RECORD_HEADER_LENGTH) {
return;
}
if (SslHandler.isEncrypted(in)) {
handleSsl(context);
} else {
handleNonSsl(context);
}
}
private void handleSsl(ChannelHandlerContext context) {
SslHandler sslHandler = null;
try {
sslHandler = newSslHandler(context, sslContext);
context.pipeline().replace(this, newSslHandlerName(), sslHandler);
sslHandler = null;
} finally {
// Since the SslHandler was not inserted into the pipeline the ownership of the SSLEngine was not
// transferred to the SslHandler.
if (sslHandler != null) {
ReferenceCountUtil.safeRelease(sslHandler.engine());
}
}
}
private void handleNonSsl(ChannelHandlerContext context) {
ChannelHandler handler = newNonSslHandler(context);
if (handler != null) {
context.pipeline().replace(this, newNonSslHandlerName(), handler);
} else {
context.pipeline().remove(this);
}
}
/**
* Optionally specify the SSL handler name, this method may return {@code null}.
* @return the name of the SSL handler.
*/
protected String newSslHandlerName() {
return null;
}
/**
* Override to configure the SslHandler eg. {@link SSLParameters#setEndpointIdentificationAlgorithm(String)}.
* The hostname and port is not known by this method so servers may want to override this method and use the
* {@link SslContext#newHandler(ByteBufAllocator, String, int)} variant.
*
* @param context the {@link ChannelHandlerContext} to use.
* @param sslContext the {@link SSLContext} to use.
* @return the {@link SslHandler} which will replace the {@link OptionalSslHandler} in the pipeline if the
* traffic is SSL.
*/
protected SslHandler newSslHandler(ChannelHandlerContext context, SslContext sslContext) {
return sslContext.newHandler(context.alloc());
}
/**
* Optionally specify the non-SSL handler name, this method may return {@code null}.
* @return the name of the non-SSL handler.
*/
protected String newNonSslHandlerName() {
return null;
}
/**
* Override to configure the ChannelHandler.
* @param context the {@link ChannelHandlerContext} to use.
* @return the {@link ChannelHandler} which will replace the {@link OptionalSslHandler} in the pipeline
* or {@code null} to simply remove the {@link OptionalSslHandler} if the traffic is non-SSL.
*/
protected ChannelHandler newNonSslHandler(ChannelHandlerContext context) {
return null;
}
}

View File

@ -0,0 +1,121 @@
/*
* Copyright 2017 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.ssl;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import static org.mockito.Mockito.*;
@RunWith(MockitoJUnitRunner.class)
public class OptionalSslHandlerTest {
private static final String SSL_HANDLER_NAME = "sslhandler";
private static final String HANDLER_NAME = "handler";
@Mock
private ChannelHandlerContext context;
@Mock
private SslContext sslContext;
@Mock
private ChannelPipeline pipeline;
@Before
public void setUp() throws Exception {
when(context.pipeline()).thenReturn(pipeline);
}
@Test
public void handlerRemoved() throws Exception {
OptionalSslHandler handler = new OptionalSslHandler(sslContext);
final ByteBuf payload = Unpooled.copiedBuffer("plaintext".getBytes());
try {
handler.decode(context, payload, null);
verify(pipeline).remove(handler);
} finally {
payload.release();
}
}
@Test
public void handlerReplaced() throws Exception {
final ChannelHandler nonSslHandler = Mockito.mock(ChannelHandler.class);
OptionalSslHandler handler = new OptionalSslHandler(sslContext) {
@Override
protected ChannelHandler newNonSslHandler(ChannelHandlerContext context) {
return nonSslHandler;
}
@Override
protected String newNonSslHandlerName() {
return HANDLER_NAME;
}
};
final ByteBuf payload = Unpooled.copiedBuffer("plaintext".getBytes());
try {
handler.decode(context, payload, null);
verify(pipeline).replace(handler, HANDLER_NAME, nonSslHandler);
} finally {
payload.release();
}
}
@Test
public void sslHandlerReplaced() throws Exception {
final SslHandler sslHandler = Mockito.mock(SslHandler.class);
OptionalSslHandler handler = new OptionalSslHandler(sslContext) {
@Override
protected SslHandler newSslHandler(ChannelHandlerContext context, SslContext sslContext) {
return sslHandler;
}
@Override
protected String newSslHandlerName() {
return SSL_HANDLER_NAME;
}
};
final ByteBuf payload = Unpooled.wrappedBuffer(new byte[] { 22, 3, 1, 0, 5 });
try {
handler.decode(context, payload, null);
verify(pipeline).replace(handler, SSL_HANDLER_NAME, sslHandler);
} finally {
payload.release();
}
}
@Test
public void decodeBuffered() throws Exception {
OptionalSslHandler handler = new OptionalSslHandler(sslContext);
final ByteBuf payload = Unpooled.wrappedBuffer(new byte[] { 22, 3 });
try {
handler.decode(context, payload, null);
verifyZeroInteractions(pipeline);
} finally {
payload.release();
}
}
}