Added SSL support to the HTTP tunneling client socket channel

This commit is contained in:
Trustin Lee 2009-06-25 09:47:45 +00:00
parent fb8cc156f3
commit dde115525b
2 changed files with 123 additions and 186 deletions

View File

@ -30,6 +30,9 @@ import java.net.URI;
import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock; import java.util.concurrent.locks.ReentrantLock;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import org.jboss.netty.buffer.ChannelBuffer; import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBuffers; import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.AbstractChannel; import org.jboss.netty.channel.AbstractChannel;
@ -39,14 +42,15 @@ import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelPipeline; import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.ChannelPipelineCoverage; import org.jboss.netty.channel.ChannelPipelineCoverage;
import org.jboss.netty.channel.ChannelSink; import org.jboss.netty.channel.ChannelSink;
import org.jboss.netty.channel.ChannelStateEvent;
import org.jboss.netty.channel.DefaultChannelPipeline; import org.jboss.netty.channel.DefaultChannelPipeline;
import org.jboss.netty.channel.ExceptionEvent; import org.jboss.netty.channel.ExceptionEvent;
import org.jboss.netty.channel.MessageEvent; import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelUpstreamHandler; import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
import org.jboss.netty.channel.socket.ClientSocketChannelFactory; import org.jboss.netty.channel.socket.ClientSocketChannelFactory;
import org.jboss.netty.channel.socket.SocketChannel; import org.jboss.netty.channel.socket.SocketChannel;
import org.jboss.netty.channel.socket.SocketChannelConfig;
import org.jboss.netty.handler.codec.frame.DelimiterBasedFrameDecoder; import org.jboss.netty.handler.codec.frame.DelimiterBasedFrameDecoder;
import org.jboss.netty.handler.ssl.SslHandler;
import org.jboss.netty.logging.InternalLogger; import org.jboss.netty.logging.InternalLogger;
import org.jboss.netty.logging.InternalLoggerFactory; import org.jboss.netty.logging.InternalLoggerFactory;
import org.jboss.netty.util.internal.LinkedTransferQueue; import org.jboss.netty.util.internal.LinkedTransferQueue;
@ -63,6 +67,7 @@ class HttpTunnelingClientSocketChannel extends AbstractChannel
static final InternalLogger logger = static final InternalLogger logger =
InternalLoggerFactory.getInstance(HttpTunnelingClientSocketChannel.class); InternalLoggerFactory.getInstance(HttpTunnelingClientSocketChannel.class);
private final HttpTunnelingSocketChannelConfig config;
private final Lock reconnectLock = new ReentrantLock(); private final Lock reconnectLock = new ReentrantLock();
volatile boolean awaitingInitialResponse = true; volatile boolean awaitingInitialResponse = true;
@ -82,9 +87,8 @@ class HttpTunnelingClientSocketChannel extends AbstractChannel
SocketChannel channel; SocketChannel channel;
private final DelimiterBasedFrameDecoder handler = new DelimiterBasedFrameDecoder(8092, ChannelBuffers.wrappedBuffer(new byte[] { '\r', '\n' })); private final DelimiterBasedFrameDecoder decoder = new DelimiterBasedFrameDecoder(8092, ChannelBuffers.wrappedBuffer(new byte[] { '\r', '\n' }));
private final HttpTunnelingClientSocketChannel.ServletChannelHandler handler = new ServletChannelHandler();
private final HttpTunnelingClientSocketChannel.ServletChannelHandler servletHandler = new ServletChannelHandler();
private HttpTunnelAddress remoteAddress; private HttpTunnelAddress remoteAddress;
@ -94,18 +98,16 @@ class HttpTunnelingClientSocketChannel extends AbstractChannel
ChannelSink sink, ClientSocketChannelFactory clientSocketChannelFactory) { ChannelSink sink, ClientSocketChannelFactory clientSocketChannelFactory) {
super(null, factory, pipeline, sink); super(null, factory, pipeline, sink);
this.clientSocketChannelFactory = clientSocketChannelFactory; this.clientSocketChannelFactory = clientSocketChannelFactory;
DefaultChannelPipeline channelPipeline = new DefaultChannelPipeline(); createSocketChannel();
channelPipeline.addLast("DelimiterBasedFrameDecoder", handler); config = new HttpTunnelingSocketChannelConfig(this);
channelPipeline.addLast("servletHandler", servletHandler);
channel = clientSocketChannelFactory.newChannel(channelPipeline);
fireChannelOpen(this); fireChannelOpen(this);
} }
public SocketChannelConfig getConfig() { public HttpTunnelingSocketChannelConfig getConfig() {
return channel.getConfig(); return config;
} }
public InetSocketAddress getLocalAddress() { public InetSocketAddress getLocalAddress() {
@ -117,7 +119,7 @@ class HttpTunnelingClientSocketChannel extends AbstractChannel
} }
public boolean isBound() { public boolean isBound() {
return channel.isOpen(); return channel.isBound();
} }
public boolean isConnected() { public boolean isConnected() {
@ -149,10 +151,7 @@ class HttpTunnelingClientSocketChannel extends AbstractChannel
URI url = remoteAddress.getUri(); URI url = remoteAddress.getUri();
if (reconnect) { if (reconnect) {
closeSocket(); closeSocket();
DefaultChannelPipeline channelPipeline = new DefaultChannelPipeline(); createSocketChannel();
channelPipeline.addLast("DelimiterBasedFrameDecoder", handler);
channelPipeline.addLast("servletHandler", servletHandler);
channel = clientSocketChannelFactory.newChannel(channelPipeline);
} }
SocketAddress connectAddress = new InetSocketAddress(url.getHost(), url.getPort()); SocketAddress connectAddress = new InetSocketAddress(url.getHost(), url.getPort());
channel.connect(connectAddress).awaitUninterruptibly(); channel.connect(connectAddress).awaitUninterruptibly();
@ -170,6 +169,16 @@ class HttpTunnelingClientSocketChannel extends AbstractChannel
channel.write(ChannelBuffers.copiedBuffer(msg, "ASCII")); channel.write(ChannelBuffers.copiedBuffer(msg, "ASCII"));
} }
/**
*
*/
private void createSocketChannel() {
DefaultChannelPipeline channelPipeline = new DefaultChannelPipeline();
channelPipeline.addLast("decoder", decoder);
channelPipeline.addLast("handler", handler);
channel = clientSocketChannelFactory.newChannel(channelPipeline);
}
int sendChunk(ChannelBuffer a) { int sendChunk(ChannelBuffer a) {
int size = a.readableBytes(); int size = a.readableBytes();
String hex = Integer.toHexString(size) + HttpTunnelingClientSocketPipelineSink.LINE_TERMINATOR; String hex = Integer.toHexString(size) + HttpTunnelingClientSocketPipelineSink.LINE_TERMINATOR;
@ -241,6 +250,23 @@ class HttpTunnelingClientSocketChannel extends AbstractChannel
class ServletChannelHandler extends SimpleChannelUpstreamHandler { class ServletChannelHandler extends SimpleChannelUpstreamHandler {
int nextChunkSize = -1; int nextChunkSize = -1;
@Override
public void channelConnected(ChannelHandlerContext ctx,
ChannelStateEvent e) throws Exception {
SSLContext sslContext = getConfig().getSslContext();
if (sslContext != null) {
// FIXME: specify peer host and port.
SSLEngine engine = sslContext.createSSLEngine();
engine.setUseClientMode(true);
SocketChannel ch = (SocketChannel) e.getChannel();
SslHandler sslHandler = new SslHandler(engine);
ch.getPipeline().addFirst("ssl", sslHandler);
sslHandler.handshake(channel);
}
}
@Override @Override
public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception { public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception {
ChannelBuffer buf = (ChannelBuffer) e.getMessage(); ChannelBuffer buf = (ChannelBuffer) e.getMessage();

View File

@ -22,243 +22,154 @@
*/ */
package org.jboss.netty.channel.socket.http; package org.jboss.netty.channel.socket.http;
import java.net.Socket; import java.util.Map;
import java.net.SocketException;
import org.jboss.netty.channel.ChannelException; import javax.net.ssl.SSLContext;
import org.jboss.netty.channel.DefaultChannelConfig;
import org.jboss.netty.buffer.ChannelBufferFactory;
import org.jboss.netty.channel.ChannelPipelineFactory;
import org.jboss.netty.channel.socket.SocketChannelConfig; import org.jboss.netty.channel.socket.SocketChannelConfig;
import org.jboss.netty.util.internal.ConversionUtil;
/** /**
* @author The Netty Project (netty-dev@lists.jboss.org) * @author The Netty Project (netty-dev@lists.jboss.org)
* @author Andy Taylor (andy.taylor@jboss.org) * @author Andy Taylor (andy.taylor@jboss.org)
* @version $Rev$, $Date$ * @version $Rev$, $Date$
*/ */
final class HttpTunnelingSocketChannelConfig extends DefaultChannelConfig final class HttpTunnelingSocketChannelConfig implements SocketChannelConfig {
implements SocketChannelConfig {
final Socket socket; private final HttpTunnelingClientSocketChannel channel;
private Integer trafficClass; private volatile SSLContext sslContext;
private Boolean tcpNoDelay;
private Integer soLinger;
private Integer sendBufferSize;
private Boolean reuseAddress;
private Integer receiveBufferSize;
private Integer connectionTime;
private Integer latency;
private Integer bandwidth;
private Boolean keepAlive;
/** /**
* Creates a new instance. * Creates a new instance.
*/ */
HttpTunnelingSocketChannelConfig(Socket socket) { HttpTunnelingSocketChannelConfig(HttpTunnelingClientSocketChannel channel) {
this.socket = socket; this.channel = channel;
}
public SSLContext getSslContext() {
return sslContext;
}
public void setSslContext(SSLContext sslContext) {
this.sslContext = sslContext;
}
public void setOptions(Map<String, Object> options) {
channel.channel.getConfig().setOptions(options);
SSLContext sslContext = (SSLContext) options.get("sslContext");
if (sslContext != null) {
setSslContext(sslContext);
}
} }
@Override
public boolean setOption(String key, Object value) { public boolean setOption(String key, Object value) {
if (key.equals("receiveBufferSize")) { if (channel.channel.getConfig().setOption(key, value)) {
setReceiveBufferSize(ConversionUtil.toInt(value)); return true;
} else if (key.equals("sendBufferSize")) { }
setSendBufferSize(ConversionUtil.toInt(value));
} else if (key.equals("tcpNoDelay")) { if (key.equals("sslContext")) {
setTcpNoDelay(ConversionUtil.toBoolean(value)); setSslContext((SSLContext) value);
} else if (key.equals("keepAlive")) {
setKeepAlive(ConversionUtil.toBoolean(value));
} else if (key.equals("reuseAddress")) {
setReuseAddress(ConversionUtil.toBoolean(value));
} else if (key.equals("soLinger")) {
setSoLinger(ConversionUtil.toInt(value));
} else if (key.equals("trafficClass")) {
setTrafficClass(ConversionUtil.toInt(value));
} else { } else {
return false; return false;
} }
return true; return true;
} }
public int getReceiveBufferSize() { public int getReceiveBufferSize() {
try { return channel.channel.getConfig().getReceiveBufferSize();
return socket.getReceiveBufferSize();
}
catch (SocketException e) {
throw new ChannelException(e);
}
} }
public int getSendBufferSize() { public int getSendBufferSize() {
try { return channel.channel.getConfig().getSendBufferSize();
return socket.getSendBufferSize();
}
catch (SocketException e) {
throw new ChannelException(e);
}
} }
public int getSoLinger() { public int getSoLinger() {
try { return channel.channel.getConfig().getSoLinger();
return socket.getSoLinger();
}
catch (SocketException e) {
throw new ChannelException(e);
}
} }
public int getTrafficClass() { public int getTrafficClass() {
try { return channel.channel.getConfig().getTrafficClass();
return socket.getTrafficClass();
}
catch (SocketException e) {
throw new ChannelException(e);
}
} }
public boolean isKeepAlive() { public boolean isKeepAlive() {
try { return channel.channel.getConfig().isKeepAlive();
return socket.getKeepAlive();
}
catch (SocketException e) {
throw new ChannelException(e);
}
} }
public boolean isReuseAddress() { public boolean isReuseAddress() {
try { return channel.channel.getConfig().isReuseAddress();
return socket.getReuseAddress();
}
catch (SocketException e) {
throw new ChannelException(e);
}
} }
public boolean isTcpNoDelay() { public boolean isTcpNoDelay() {
try { return channel.channel.getConfig().isTcpNoDelay();
return socket.getTcpNoDelay();
}
catch (SocketException e) {
throw new ChannelException(e);
}
} }
public void setKeepAlive(boolean keepAlive) { public void setKeepAlive(boolean keepAlive) {
this.keepAlive = keepAlive; channel.channel.getConfig().setKeepAlive(keepAlive);
try {
socket.setKeepAlive(keepAlive);
}
catch (SocketException e) {
throw new ChannelException(e);
}
} }
public void setPerformancePreferences( public void setPerformancePreferences(
int connectionTime, int latency, int bandwidth) { int connectionTime, int latency, int bandwidth) {
this.connectionTime = connectionTime; channel.channel.getConfig().setPerformancePreferences(connectionTime, latency, bandwidth);
this.latency = latency;
this.bandwidth = bandwidth;
socket.setPerformancePreferences(connectionTime, latency, bandwidth);
} }
public void setReceiveBufferSize(int receiveBufferSize) { public void setReceiveBufferSize(int receiveBufferSize) {
this.receiveBufferSize = receiveBufferSize; channel.channel.getConfig().setReceiveBufferSize(receiveBufferSize);
try {
socket.setReceiveBufferSize(receiveBufferSize);
}
catch (SocketException e) {
throw new ChannelException(e);
}
} }
public void setReuseAddress(boolean reuseAddress) { public void setReuseAddress(boolean reuseAddress) {
this.reuseAddress = reuseAddress; channel.channel.getConfig().setReuseAddress(reuseAddress);
try {
socket.setReuseAddress(reuseAddress);
}
catch (SocketException e) {
throw new ChannelException(e);
}
} }
public void setSendBufferSize(int sendBufferSize) { public void setSendBufferSize(int sendBufferSize) {
this.sendBufferSize = sendBufferSize; channel.channel.getConfig().setSendBufferSize(sendBufferSize);
if (socket != null) {
try {
socket.setSendBufferSize(sendBufferSize);
}
catch (SocketException e) {
throw new ChannelException(e);
}
}
} }
public void setSoLinger(int soLinger) { public void setSoLinger(int soLinger) {
this.soLinger = soLinger; channel.channel.getConfig().setSoLinger(soLinger);
try {
if (soLinger < 0) {
socket.setSoLinger(false, 0);
}
else {
socket.setSoLinger(true, soLinger);
}
}
catch (SocketException e) {
throw new ChannelException(e);
}
} }
public void setTcpNoDelay(boolean tcpNoDelay) { public void setTcpNoDelay(boolean tcpNoDelay) {
this.tcpNoDelay = tcpNoDelay; channel.channel.getConfig().setTcpNoDelay(tcpNoDelay);
try {
socket.setTcpNoDelay(tcpNoDelay);
}
catch (SocketException e) {
throw new ChannelException(e);
}
} }
public void setTrafficClass(int trafficClass) { public void setTrafficClass(int trafficClass) {
this.trafficClass = trafficClass; channel.channel.getConfig().setTrafficClass(trafficClass);
try {
socket.setTrafficClass(trafficClass);
}
catch (SocketException e) {
throw new ChannelException(e);
}
} }
HttpTunnelingSocketChannelConfig copyConfig(Socket socket) { public ChannelBufferFactory getBufferFactory() {
HttpTunnelingSocketChannelConfig config = new HttpTunnelingSocketChannelConfig(socket); return channel.channel.getConfig().getBufferFactory();
config.setConnectTimeoutMillis(getConnectTimeoutMillis()); }
if (trafficClass != null) {
config.setTrafficClass(trafficClass); public int getConnectTimeoutMillis() {
} return channel.channel.getConfig().getConnectTimeoutMillis();
if (tcpNoDelay != null) { }
config.setTcpNoDelay(tcpNoDelay);
} public ChannelPipelineFactory getPipelineFactory() {
if (soLinger != null) { return channel.channel.getConfig().getPipelineFactory();
config.setSoLinger(soLinger); }
}
if (sendBufferSize != null) { @Deprecated
config.setSendBufferSize(sendBufferSize); public int getWriteTimeoutMillis() {
} return channel.channel.getConfig().getWriteTimeoutMillis();
if (reuseAddress != null) { }
config.setReuseAddress(reuseAddress);
} public void setBufferFactory(ChannelBufferFactory bufferFactory) {
if (receiveBufferSize != null) { channel.channel.getConfig().setBufferFactory(bufferFactory);
config.setReceiveBufferSize(receiveBufferSize); }
}
if (keepAlive != null) { public void setConnectTimeoutMillis(int connectTimeoutMillis) {
config.setKeepAlive(keepAlive); channel.channel.getConfig().setConnectTimeoutMillis(connectTimeoutMillis);
} }
if (connectionTime != null) {
config.setPerformancePreferences(connectionTime, latency, bandwidth); public void setPipelineFactory(ChannelPipelineFactory pipelineFactory) {
} channel.channel.getConfig().setPipelineFactory(pipelineFactory);
return config; }
@Deprecated
public void setWriteTimeoutMillis(int writeTimeoutMillis) {
channel.channel.getConfig().setWriteTimeoutMillis(writeTimeoutMillis);
} }
} }