diff --git a/src/main/java/org/jboss/netty/channel/socket/http/HttpTunnelingClientSocketChannel.java b/src/main/java/org/jboss/netty/channel/socket/http/HttpTunnelingClientSocketChannel.java index fe3f7f762a..734e425afd 100644 --- a/src/main/java/org/jboss/netty/channel/socket/http/HttpTunnelingClientSocketChannel.java +++ b/src/main/java/org/jboss/netty/channel/socket/http/HttpTunnelingClientSocketChannel.java @@ -30,6 +30,9 @@ import java.net.URI; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; + import org.jboss.netty.buffer.ChannelBuffer; import org.jboss.netty.buffer.ChannelBuffers; import org.jboss.netty.channel.AbstractChannel; @@ -39,14 +42,15 @@ import org.jboss.netty.channel.ChannelHandlerContext; import org.jboss.netty.channel.ChannelPipeline; import org.jboss.netty.channel.ChannelPipelineCoverage; import org.jboss.netty.channel.ChannelSink; +import org.jboss.netty.channel.ChannelStateEvent; import org.jboss.netty.channel.DefaultChannelPipeline; import org.jboss.netty.channel.ExceptionEvent; import org.jboss.netty.channel.MessageEvent; import org.jboss.netty.channel.SimpleChannelUpstreamHandler; import org.jboss.netty.channel.socket.ClientSocketChannelFactory; import org.jboss.netty.channel.socket.SocketChannel; -import org.jboss.netty.channel.socket.SocketChannelConfig; import org.jboss.netty.handler.codec.frame.DelimiterBasedFrameDecoder; +import org.jboss.netty.handler.ssl.SslHandler; import org.jboss.netty.logging.InternalLogger; import org.jboss.netty.logging.InternalLoggerFactory; import org.jboss.netty.util.internal.LinkedTransferQueue; @@ -63,6 +67,7 @@ class HttpTunnelingClientSocketChannel extends AbstractChannel static final InternalLogger logger = InternalLoggerFactory.getInstance(HttpTunnelingClientSocketChannel.class); + private final HttpTunnelingSocketChannelConfig config; private final Lock reconnectLock = new ReentrantLock(); volatile boolean awaitingInitialResponse = true; @@ -82,9 +87,8 @@ class HttpTunnelingClientSocketChannel extends AbstractChannel SocketChannel channel; - private final DelimiterBasedFrameDecoder handler = new DelimiterBasedFrameDecoder(8092, ChannelBuffers.wrappedBuffer(new byte[] { '\r', '\n' })); - - private final HttpTunnelingClientSocketChannel.ServletChannelHandler servletHandler = new ServletChannelHandler(); + private final DelimiterBasedFrameDecoder decoder = new DelimiterBasedFrameDecoder(8092, ChannelBuffers.wrappedBuffer(new byte[] { '\r', '\n' })); + private final HttpTunnelingClientSocketChannel.ServletChannelHandler handler = new ServletChannelHandler(); private HttpTunnelAddress remoteAddress; @@ -94,18 +98,16 @@ class HttpTunnelingClientSocketChannel extends AbstractChannel ChannelSink sink, ClientSocketChannelFactory clientSocketChannelFactory) { super(null, factory, pipeline, sink); + this.clientSocketChannelFactory = clientSocketChannelFactory; - DefaultChannelPipeline channelPipeline = new DefaultChannelPipeline(); - channelPipeline.addLast("DelimiterBasedFrameDecoder", handler); - channelPipeline.addLast("servletHandler", servletHandler); - channel = clientSocketChannelFactory.newChannel(channelPipeline); - + createSocketChannel(); + config = new HttpTunnelingSocketChannelConfig(this); fireChannelOpen(this); } - public SocketChannelConfig getConfig() { - return channel.getConfig(); + public HttpTunnelingSocketChannelConfig getConfig() { + return config; } public InetSocketAddress getLocalAddress() { @@ -117,7 +119,7 @@ class HttpTunnelingClientSocketChannel extends AbstractChannel } public boolean isBound() { - return channel.isOpen(); + return channel.isBound(); } public boolean isConnected() { @@ -149,10 +151,7 @@ class HttpTunnelingClientSocketChannel extends AbstractChannel URI url = remoteAddress.getUri(); if (reconnect) { closeSocket(); - DefaultChannelPipeline channelPipeline = new DefaultChannelPipeline(); - channelPipeline.addLast("DelimiterBasedFrameDecoder", handler); - channelPipeline.addLast("servletHandler", servletHandler); - channel = clientSocketChannelFactory.newChannel(channelPipeline); + createSocketChannel(); } SocketAddress connectAddress = new InetSocketAddress(url.getHost(), url.getPort()); channel.connect(connectAddress).awaitUninterruptibly(); @@ -170,6 +169,16 @@ class HttpTunnelingClientSocketChannel extends AbstractChannel channel.write(ChannelBuffers.copiedBuffer(msg, "ASCII")); } + /** + * + */ + private void createSocketChannel() { + DefaultChannelPipeline channelPipeline = new DefaultChannelPipeline(); + channelPipeline.addLast("decoder", decoder); + channelPipeline.addLast("handler", handler); + channel = clientSocketChannelFactory.newChannel(channelPipeline); + } + int sendChunk(ChannelBuffer a) { int size = a.readableBytes(); String hex = Integer.toHexString(size) + HttpTunnelingClientSocketPipelineSink.LINE_TERMINATOR; @@ -241,6 +250,23 @@ class HttpTunnelingClientSocketChannel extends AbstractChannel class ServletChannelHandler extends SimpleChannelUpstreamHandler { int nextChunkSize = -1; + @Override + public void channelConnected(ChannelHandlerContext ctx, + ChannelStateEvent e) throws Exception { + SSLContext sslContext = getConfig().getSslContext(); + if (sslContext != null) { + // FIXME: specify peer host and port. + SSLEngine engine = sslContext.createSSLEngine(); + engine.setUseClientMode(true); + + SocketChannel ch = (SocketChannel) e.getChannel(); + SslHandler sslHandler = new SslHandler(engine); + ch.getPipeline().addFirst("ssl", sslHandler); + + sslHandler.handshake(channel); + } + } + @Override public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception { ChannelBuffer buf = (ChannelBuffer) e.getMessage(); diff --git a/src/main/java/org/jboss/netty/channel/socket/http/HttpTunnelingSocketChannelConfig.java b/src/main/java/org/jboss/netty/channel/socket/http/HttpTunnelingSocketChannelConfig.java index c321824bff..81e222d085 100644 --- a/src/main/java/org/jboss/netty/channel/socket/http/HttpTunnelingSocketChannelConfig.java +++ b/src/main/java/org/jboss/netty/channel/socket/http/HttpTunnelingSocketChannelConfig.java @@ -22,243 +22,154 @@ */ package org.jboss.netty.channel.socket.http; -import java.net.Socket; -import java.net.SocketException; +import java.util.Map; -import org.jboss.netty.channel.ChannelException; -import org.jboss.netty.channel.DefaultChannelConfig; +import javax.net.ssl.SSLContext; + +import org.jboss.netty.buffer.ChannelBufferFactory; +import org.jboss.netty.channel.ChannelPipelineFactory; import org.jboss.netty.channel.socket.SocketChannelConfig; -import org.jboss.netty.util.internal.ConversionUtil; /** * @author The Netty Project (netty-dev@lists.jboss.org) * @author Andy Taylor (andy.taylor@jboss.org) * @version $Rev$, $Date$ */ -final class HttpTunnelingSocketChannelConfig extends DefaultChannelConfig - implements SocketChannelConfig { +final class HttpTunnelingSocketChannelConfig implements SocketChannelConfig { - final Socket socket; - private Integer trafficClass; - private Boolean tcpNoDelay; - private Integer soLinger; - private Integer sendBufferSize; - private Boolean reuseAddress; - private Integer receiveBufferSize; - private Integer connectionTime; - private Integer latency; - private Integer bandwidth; - private Boolean keepAlive; + private final HttpTunnelingClientSocketChannel channel; + private volatile SSLContext sslContext; /** * Creates a new instance. */ - HttpTunnelingSocketChannelConfig(Socket socket) { - this.socket = socket; + HttpTunnelingSocketChannelConfig(HttpTunnelingClientSocketChannel channel) { + this.channel = channel; + } + + public SSLContext getSslContext() { + return sslContext; + } + + public void setSslContext(SSLContext sslContext) { + this.sslContext = sslContext; + } + + public void setOptions(Map options) { + channel.channel.getConfig().setOptions(options); + SSLContext sslContext = (SSLContext) options.get("sslContext"); + if (sslContext != null) { + setSslContext(sslContext); + } } - @Override public boolean setOption(String key, Object value) { - if (key.equals("receiveBufferSize")) { - setReceiveBufferSize(ConversionUtil.toInt(value)); - } else if (key.equals("sendBufferSize")) { - setSendBufferSize(ConversionUtil.toInt(value)); - } else if (key.equals("tcpNoDelay")) { - setTcpNoDelay(ConversionUtil.toBoolean(value)); - } else if (key.equals("keepAlive")) { - setKeepAlive(ConversionUtil.toBoolean(value)); - } else if (key.equals("reuseAddress")) { - setReuseAddress(ConversionUtil.toBoolean(value)); - } else if (key.equals("soLinger")) { - setSoLinger(ConversionUtil.toInt(value)); - } else if (key.equals("trafficClass")) { - setTrafficClass(ConversionUtil.toInt(value)); + if (channel.channel.getConfig().setOption(key, value)) { + return true; + } + + if (key.equals("sslContext")) { + setSslContext((SSLContext) value); } else { return false; } + return true; } public int getReceiveBufferSize() { - try { - return socket.getReceiveBufferSize(); - } - catch (SocketException e) { - throw new ChannelException(e); - } + return channel.channel.getConfig().getReceiveBufferSize(); } public int getSendBufferSize() { - try { - return socket.getSendBufferSize(); - } - catch (SocketException e) { - throw new ChannelException(e); - } + return channel.channel.getConfig().getSendBufferSize(); } public int getSoLinger() { - try { - return socket.getSoLinger(); - } - catch (SocketException e) { - throw new ChannelException(e); - } + return channel.channel.getConfig().getSoLinger(); } public int getTrafficClass() { - try { - return socket.getTrafficClass(); - } - catch (SocketException e) { - throw new ChannelException(e); - } + return channel.channel.getConfig().getTrafficClass(); } public boolean isKeepAlive() { - try { - return socket.getKeepAlive(); - } - catch (SocketException e) { - throw new ChannelException(e); - } + return channel.channel.getConfig().isKeepAlive(); } public boolean isReuseAddress() { - try { - return socket.getReuseAddress(); - } - catch (SocketException e) { - throw new ChannelException(e); - } + return channel.channel.getConfig().isReuseAddress(); } public boolean isTcpNoDelay() { - try { - return socket.getTcpNoDelay(); - } - catch (SocketException e) { - throw new ChannelException(e); - } + return channel.channel.getConfig().isTcpNoDelay(); } public void setKeepAlive(boolean keepAlive) { - this.keepAlive = keepAlive; - try { - socket.setKeepAlive(keepAlive); - } - catch (SocketException e) { - throw new ChannelException(e); - } + channel.channel.getConfig().setKeepAlive(keepAlive); } public void setPerformancePreferences( int connectionTime, int latency, int bandwidth) { - this.connectionTime = connectionTime; - this.latency = latency; - this.bandwidth = bandwidth; - socket.setPerformancePreferences(connectionTime, latency, bandwidth); - + channel.channel.getConfig().setPerformancePreferences(connectionTime, latency, bandwidth); } public void setReceiveBufferSize(int receiveBufferSize) { - this.receiveBufferSize = receiveBufferSize; - try { - socket.setReceiveBufferSize(receiveBufferSize); - } - catch (SocketException e) { - throw new ChannelException(e); - } - + channel.channel.getConfig().setReceiveBufferSize(receiveBufferSize); } public void setReuseAddress(boolean reuseAddress) { - this.reuseAddress = reuseAddress; - try { - socket.setReuseAddress(reuseAddress); - } - catch (SocketException e) { - throw new ChannelException(e); - } - + channel.channel.getConfig().setReuseAddress(reuseAddress); } public void setSendBufferSize(int sendBufferSize) { - this.sendBufferSize = sendBufferSize; - if (socket != null) { - try { - socket.setSendBufferSize(sendBufferSize); - } - catch (SocketException e) { - throw new ChannelException(e); - } - } + channel.channel.getConfig().setSendBufferSize(sendBufferSize); + } public void setSoLinger(int soLinger) { - this.soLinger = soLinger; - try { - if (soLinger < 0) { - socket.setSoLinger(false, 0); - } - else { - socket.setSoLinger(true, soLinger); - } - } - catch (SocketException e) { - throw new ChannelException(e); - } - + channel.channel.getConfig().setSoLinger(soLinger); } public void setTcpNoDelay(boolean tcpNoDelay) { - this.tcpNoDelay = tcpNoDelay; - try { - socket.setTcpNoDelay(tcpNoDelay); - } - catch (SocketException e) { - throw new ChannelException(e); - } + channel.channel.getConfig().setTcpNoDelay(tcpNoDelay); } public void setTrafficClass(int trafficClass) { - this.trafficClass = trafficClass; - try { - socket.setTrafficClass(trafficClass); - } - catch (SocketException e) { - throw new ChannelException(e); - } - + channel.channel.getConfig().setTrafficClass(trafficClass); } - HttpTunnelingSocketChannelConfig copyConfig(Socket socket) { - HttpTunnelingSocketChannelConfig config = new HttpTunnelingSocketChannelConfig(socket); - config.setConnectTimeoutMillis(getConnectTimeoutMillis()); - if (trafficClass != null) { - config.setTrafficClass(trafficClass); - } - if (tcpNoDelay != null) { - config.setTcpNoDelay(tcpNoDelay); - } - if (soLinger != null) { - config.setSoLinger(soLinger); - } - if (sendBufferSize != null) { - config.setSendBufferSize(sendBufferSize); - } - if (reuseAddress != null) { - config.setReuseAddress(reuseAddress); - } - if (receiveBufferSize != null) { - config.setReceiveBufferSize(receiveBufferSize); - } - if (keepAlive != null) { - config.setKeepAlive(keepAlive); - } - if (connectionTime != null) { - config.setPerformancePreferences(connectionTime, latency, bandwidth); - } - return config; + public ChannelBufferFactory getBufferFactory() { + return channel.channel.getConfig().getBufferFactory(); + } + + public int getConnectTimeoutMillis() { + return channel.channel.getConfig().getConnectTimeoutMillis(); + } + + public ChannelPipelineFactory getPipelineFactory() { + return channel.channel.getConfig().getPipelineFactory(); + } + + @Deprecated + public int getWriteTimeoutMillis() { + return channel.channel.getConfig().getWriteTimeoutMillis(); + } + + public void setBufferFactory(ChannelBufferFactory bufferFactory) { + channel.channel.getConfig().setBufferFactory(bufferFactory); + } + + public void setConnectTimeoutMillis(int connectTimeoutMillis) { + channel.channel.getConfig().setConnectTimeoutMillis(connectTimeoutMillis); + } + + public void setPipelineFactory(ChannelPipelineFactory pipelineFactory) { + channel.channel.getConfig().setPipelineFactory(pipelineFactory); + } + + @Deprecated + public void setWriteTimeoutMillis(int writeTimeoutMillis) { + channel.channel.getConfig().setWriteTimeoutMillis(writeTimeoutMillis); } } \ No newline at end of file