diff --git a/src/test/java/org/jboss/netty/handler/ssl/SslHandshakeRaceTester.java b/src/test/java/org/jboss/netty/handler/ssl/SslHandshakeRaceTester.java new file mode 100644 index 0000000000..fdf99d897e --- /dev/null +++ b/src/test/java/org/jboss/netty/handler/ssl/SslHandshakeRaceTester.java @@ -0,0 +1,197 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package org.jboss.netty.handler.ssl; + +import org.jboss.netty.bootstrap.ClientBootstrap; +import org.jboss.netty.bootstrap.ServerBootstrap; +import org.jboss.netty.buffer.ChannelBuffers; +import org.jboss.netty.channel.Channel; +import org.jboss.netty.channel.ChannelFuture; +import org.jboss.netty.channel.ChannelFutureListener; +import org.jboss.netty.channel.ChannelHandlerContext; +import org.jboss.netty.channel.ChannelPipeline; +import org.jboss.netty.channel.ChannelPipelineFactory; +import org.jboss.netty.channel.ChannelStateEvent; +import org.jboss.netty.channel.Channels; +import org.jboss.netty.channel.ExceptionEvent; +import org.jboss.netty.channel.MessageEvent; +import org.jboss.netty.channel.SimpleChannelUpstreamHandler; +import org.jboss.netty.channel.socket.ClientSocketChannelFactory; +import org.jboss.netty.channel.socket.ServerSocketChannelFactory; +import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory; +import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory; +import org.jboss.netty.channel.socket.oio.OioServerSocketChannelFactory; +import org.jboss.netty.channel.socket.oio.OioClientSocketChannelFactory; +import org.jboss.netty.example.securechat.SecureChatSslContextFactory; + +import javax.net.ssl.SSLEngine; +import java.net.InetSocketAddress; +import java.util.Random; +import java.util.concurrent.atomic.AtomicReference; + +public class SslHandshakeRaceTester { + + private static final Random random = new Random(); + static final byte[] data = new byte[1048576]; + private int count; + + static { + random.nextBytes(data); + } + + public void run(int rounds, boolean nio) throws Throwable { + ClientSocketChannelFactory clientFactory; + if (nio) { + clientFactory = new NioClientSocketChannelFactory(); + } else { + clientFactory = new OioClientSocketChannelFactory(); + } + ClientBootstrap cb = new ClientBootstrap(clientFactory); + cb.setPipelineFactory(new ChannelPipelineFactory() { + public ChannelPipeline getPipeline() throws Exception { + ChannelPipeline cp = Channels.pipeline(); + + SSLEngine cse = SecureChatSslContextFactory.getClientContext().createSSLEngine(); + cse.setUseClientMode(true); + cp.addFirst("ssl", new SslHandler(cse)); + + cp.addLast("handler", new TestHandler()); + return cp; + } + }); + + ServerSocketChannelFactory serverFactory; + if (nio) { + serverFactory = new NioServerSocketChannelFactory(); + } else { + serverFactory = new OioServerSocketChannelFactory(); + } + ServerBootstrap sb = new ServerBootstrap(serverFactory); + + sb.setPipelineFactory(new ChannelPipelineFactory() { + public ChannelPipeline getPipeline() throws Exception { + ChannelPipeline cp = Channels.pipeline(); + + cp.addFirst("counter", new SimpleChannelUpstreamHandler() { + @Override + public void channelConnected(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception { + ctx.getPipeline().get(SslHandler.class).handshake().addListener(new ChannelFutureListener() { + public void operationComplete(ChannelFuture future) throws Exception { + if (!future.isSuccess()) { + future.getCause().printStackTrace(); + future.getChannel().close(); + } + } + }); + + ++count; + System.out.println("Connection #" + count); + } + }); + + SSLEngine sse = SecureChatSslContextFactory.getServerContext().createSSLEngine(); + sse.setUseClientMode(false); + cp.addFirst("ssl", new SslHandler(sse)); + + cp.addLast("handler", new TestHandler()); + return cp; + } + }); + + Channel sc = sb.bind(new InetSocketAddress(0)); + int port = ((InetSocketAddress) sc.getLocalAddress()).getPort(); + + for (int i = 0; i < rounds; i++) { + connectAndSend(cb, port); + } + + cb.shutdown(); + cb.releaseExternalResources(); + sc.close().awaitUninterruptibly(); + sb.shutdown(); + sb.releaseExternalResources(); + } + + private static void connectAndSend(ClientBootstrap cb, int port) throws Throwable { + ChannelFuture ccf = cb.connect(new InetSocketAddress("127.0.0.1", port)); + ccf.awaitUninterruptibly(); + if (!ccf.isSuccess()) { + ccf.getCause().printStackTrace(); + throw ccf.getCause(); + } + TestHandler ch = ccf.getChannel().getPipeline().get(TestHandler.class); + + Channel cc = ccf.getChannel(); + ChannelFuture hf = cc.getPipeline().get(SslHandler.class).handshake(); + hf.awaitUninterruptibly(); + if (!hf.isSuccess()) { + hf.getCause().printStackTrace(); + ch.channel.close(); + throw hf.getCause(); + } + + for (int i = 0; i < data.length;) { + int length = Math.min(random.nextInt(1024 * 64), data.length - i); + ChannelFuture future = cc.write(ChannelBuffers.wrappedBuffer(data, i, length)); + i += length; + if (i >= data.length) { + future.awaitUninterruptibly(); + } + } + + ch.channel.close().awaitUninterruptibly(); + + if (ch.exception.get() != null) { + throw ch.exception.get(); + } + } + + private static class TestHandler extends SimpleChannelUpstreamHandler { + volatile Channel channel; + final AtomicReference exception = new AtomicReference(); + + @Override + public void channelOpen(ChannelHandlerContext ctx, ChannelStateEvent e) + throws Exception { + channel = e.getChannel(); + } + + @Override + public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) + throws Exception { + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) + throws Exception { + e.getCause().printStackTrace(); + + exception.compareAndSet(null, e.getCause()); + e.getChannel().close(); + } + } + + public static void main(String[] args) throws Throwable { + int count = 20000; + boolean nio = false; + if (args.length == 2) { + count = Integer.parseInt(args[0]); + nio = Boolean.parseBoolean(args[1]); + } + SslHandshakeRaceTester test = new SslHandshakeRaceTester(); + test.run(count, nio); + } +}