Fix a bug where SslHandler does not respect the startTls flag
- Fixes #856 - Add a dedicated test case: SocketStartTlsTest
This commit is contained in:
parent
2960af851a
commit
7d80182e51
@ -319,7 +319,7 @@ public class SslHandler
|
||||
}
|
||||
engine.beginHandshake();
|
||||
handshakePromises.add(promise);
|
||||
flush(ctx, ctx.newPromise());
|
||||
flush0(ctx, ctx.newPromise(), true);
|
||||
} catch (Exception e) {
|
||||
if (promise.tryFailure(e)) {
|
||||
ctx.fireExceptionCaught(e);
|
||||
@ -408,6 +408,10 @@ public class SslHandler
|
||||
|
||||
@Override
|
||||
public void flush(final ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
|
||||
flush0(ctx, promise, false);
|
||||
}
|
||||
|
||||
private void flush0(ChannelHandlerContext ctx, ChannelPromise promise, boolean internal) throws Exception {
|
||||
final ByteBuf in = ctx.outboundByteBuffer();
|
||||
final ByteBuf out = ctx.nextOutboundByteBuffer();
|
||||
|
||||
@ -415,7 +419,7 @@ public class SslHandler
|
||||
|
||||
// Do not encrypt the first write request if this handler is
|
||||
// created with startTLS flag turned on.
|
||||
if (startTls && !sentFirstMessage) {
|
||||
if (!internal && startTls && !sentFirstMessage) {
|
||||
sentFirstMessage = true;
|
||||
out.writeBytes(in);
|
||||
ctx.flush(promise);
|
||||
@ -824,7 +828,7 @@ public class SslHandler
|
||||
}
|
||||
|
||||
if (wrapLater) {
|
||||
flush(ctx, ctx.newPromise());
|
||||
flush0(ctx, ctx.newPromise(), true);
|
||||
}
|
||||
} catch (SSLException e) {
|
||||
setHandshakeFailure(e);
|
||||
@ -932,7 +936,7 @@ public class SslHandler
|
||||
engine.closeOutbound();
|
||||
|
||||
ChannelPromise closeNotifyFuture = ctx.newPromise();
|
||||
flush(ctx, closeNotifyFuture);
|
||||
flush0(ctx, closeNotifyFuture, true);
|
||||
safeClose(ctx, closeNotifyFuture, promise);
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,226 @@
|
||||
/*
|
||||
* Copyright 2012 The Netty Project
|
||||
*
|
||||
* The Netty Project licenses this file to you under the Apache License,
|
||||
* version 2.0 (the "License"); you may not use this file except in compliance
|
||||
* with the License. You may obtain a copy of the License at:
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*/
|
||||
package io.netty.testsuite.transport.socket;
|
||||
|
||||
import io.netty.bootstrap.Bootstrap;
|
||||
import io.netty.bootstrap.ServerBootstrap;
|
||||
import io.netty.channel.Channel;
|
||||
import io.netty.channel.ChannelFuture;
|
||||
import io.netty.channel.ChannelHandlerContext;
|
||||
import io.netty.channel.ChannelInboundMessageHandlerAdapter;
|
||||
import io.netty.channel.ChannelInitializer;
|
||||
import io.netty.channel.ChannelPipeline;
|
||||
import io.netty.channel.DefaultEventExecutorGroup;
|
||||
import io.netty.channel.EventExecutorGroup;
|
||||
import io.netty.channel.socket.SocketChannel;
|
||||
import io.netty.handler.codec.LineBasedFrameDecoder;
|
||||
import io.netty.handler.codec.string.StringDecoder;
|
||||
import io.netty.handler.codec.string.StringEncoder;
|
||||
import io.netty.handler.logging.ByteLoggingHandler;
|
||||
import io.netty.handler.logging.LogLevel;
|
||||
import io.netty.handler.ssl.SslHandler;
|
||||
import io.netty.testsuite.util.BogusSslContextFactory;
|
||||
import org.junit.AfterClass;
|
||||
import org.junit.BeforeClass;
|
||||
import org.junit.Test;
|
||||
|
||||
import javax.net.ssl.SSLEngine;
|
||||
import java.io.IOException;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
public class SocketStartTlsTest extends AbstractSocketTest {
|
||||
|
||||
private static final LogLevel LOG_LEVEL = LogLevel.TRACE;
|
||||
private static EventExecutorGroup executor;
|
||||
|
||||
@BeforeClass
|
||||
public static void createExecutor() {
|
||||
executor = new DefaultEventExecutorGroup(2);
|
||||
}
|
||||
|
||||
@AfterClass
|
||||
public static void shutdownExecutor() {
|
||||
executor.shutdown();
|
||||
}
|
||||
|
||||
@Test(timeout = 30000)
|
||||
public void testStartTls() throws Throwable {
|
||||
run();
|
||||
}
|
||||
|
||||
public void testStartTls(ServerBootstrap sb, Bootstrap cb) throws Throwable {
|
||||
final EventExecutorGroup executor = SocketStartTlsTest.executor;
|
||||
final SSLEngine sse = BogusSslContextFactory.getServerContext().createSSLEngine();
|
||||
final SSLEngine cse = BogusSslContextFactory.getClientContext().createSSLEngine();
|
||||
|
||||
final StartTlsServerHandler sh = new StartTlsServerHandler(sse);
|
||||
final StartTlsClientHandler ch = new StartTlsClientHandler(cse);
|
||||
|
||||
sb.childHandler(new ChannelInitializer<SocketChannel>() {
|
||||
@Override
|
||||
public void initChannel(SocketChannel sch) throws Exception {
|
||||
ChannelPipeline p = sch.pipeline();
|
||||
p.addLast("logger", new ByteLoggingHandler(LOG_LEVEL));
|
||||
p.addLast(new LineBasedFrameDecoder(64), new StringDecoder(), new StringEncoder());
|
||||
p.addLast(executor, sh);
|
||||
}
|
||||
});
|
||||
|
||||
cb.handler(new ChannelInitializer<SocketChannel>() {
|
||||
@Override
|
||||
public void initChannel(SocketChannel sch) throws Exception {
|
||||
ChannelPipeline p = sch.pipeline();
|
||||
p.addLast("logger", new ByteLoggingHandler(LOG_LEVEL));
|
||||
p.addLast(new LineBasedFrameDecoder(64), new StringDecoder(), new StringEncoder());
|
||||
p.addLast(executor, ch);
|
||||
}
|
||||
});
|
||||
|
||||
Channel sc = sb.bind().sync().channel();
|
||||
Channel cc = cb.connect().sync().channel();
|
||||
|
||||
while (cc.isActive()) {
|
||||
if (sh.exception.get() != null) {
|
||||
break;
|
||||
}
|
||||
if (ch.exception.get() != null) {
|
||||
break;
|
||||
}
|
||||
|
||||
try {
|
||||
Thread.sleep(1);
|
||||
} catch (InterruptedException e) {
|
||||
// Ignore.
|
||||
}
|
||||
}
|
||||
|
||||
while (sh.channel.isActive()) {
|
||||
if (sh.exception.get() != null) {
|
||||
break;
|
||||
}
|
||||
if (ch.exception.get() != null) {
|
||||
break;
|
||||
}
|
||||
|
||||
try {
|
||||
Thread.sleep(1);
|
||||
} catch (InterruptedException e) {
|
||||
// Ignore.
|
||||
}
|
||||
}
|
||||
|
||||
sh.channel.close().awaitUninterruptibly();
|
||||
cc.close().awaitUninterruptibly();
|
||||
sc.close().awaitUninterruptibly();
|
||||
|
||||
if (sh.exception.get() != null && !(sh.exception.get() instanceof IOException)) {
|
||||
throw sh.exception.get();
|
||||
}
|
||||
if (ch.exception.get() != null && !(ch.exception.get() instanceof IOException)) {
|
||||
throw ch.exception.get();
|
||||
}
|
||||
if (sh.exception.get() != null) {
|
||||
throw sh.exception.get();
|
||||
}
|
||||
if (ch.exception.get() != null) {
|
||||
throw ch.exception.get();
|
||||
}
|
||||
}
|
||||
|
||||
private class StartTlsClientHandler extends ChannelInboundMessageHandlerAdapter<String> {
|
||||
private final SslHandler sslHandler;
|
||||
private ChannelFuture handshakeFuture;
|
||||
final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
|
||||
|
||||
StartTlsClientHandler(SSLEngine engine) {
|
||||
engine.setUseClientMode(true);
|
||||
sslHandler = new SslHandler(engine);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelActive(ChannelHandlerContext ctx)
|
||||
throws Exception {
|
||||
ctx.write("StartTlsRequest\n");
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void messageReceived(final ChannelHandlerContext ctx, String msg) throws Exception {
|
||||
if ("StartTlsResponse".equals(msg)) {
|
||||
ctx.pipeline().addAfter("logger", "ssl", sslHandler);
|
||||
handshakeFuture = sslHandler.handshake();
|
||||
ctx.write("EncryptedRequest\n");
|
||||
return;
|
||||
}
|
||||
|
||||
assertEquals("EncryptedResponse", msg);
|
||||
assertNotNull(handshakeFuture);
|
||||
assertTrue(handshakeFuture.isSuccess());
|
||||
ctx.close();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void exceptionCaught(ChannelHandlerContext ctx,
|
||||
Throwable cause) throws Exception {
|
||||
if (logger.isWarnEnabled()) {
|
||||
logger.warn("Unexpected exception from the client side", cause);
|
||||
}
|
||||
|
||||
exception.compareAndSet(null, cause);
|
||||
ctx.close();
|
||||
}
|
||||
}
|
||||
|
||||
private class StartTlsServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
|
||||
private final SslHandler sslHandler;
|
||||
volatile Channel channel;
|
||||
final AtomicReference<Throwable> exception = new AtomicReference<Throwable>();
|
||||
|
||||
StartTlsServerHandler(SSLEngine engine) {
|
||||
engine.setUseClientMode(false);
|
||||
sslHandler = new SslHandler(engine, true);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void channelActive(ChannelHandlerContext ctx) throws Exception {
|
||||
channel = ctx.channel();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void messageReceived(final ChannelHandlerContext ctx, String msg) throws Exception {
|
||||
if ("StartTlsRequest".equals(msg)) {
|
||||
ctx.pipeline().addAfter("logger", "ssl", sslHandler);
|
||||
ctx.write("StartTlsResponse\n");
|
||||
return;
|
||||
}
|
||||
|
||||
assertEquals("EncryptedRequest", msg);
|
||||
ctx.write("EncryptedResponse\n");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void exceptionCaught(ChannelHandlerContext ctx,
|
||||
Throwable cause) throws Exception {
|
||||
if (logger.isWarnEnabled()) {
|
||||
logger.warn("Unexpected exception from the server side", cause);
|
||||
}
|
||||
|
||||
exception.compareAndSet(null, cause);
|
||||
ctx.close();
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user