Added SSL support to the HTTP tunneling client socket channel

This commit is contained in:
Trustin Lee 2009-06-25 09:47:45 +00:00
parent fb8cc156f3
commit dde115525b
2 changed files with 123 additions and 186 deletions

View File

@ -30,6 +30,9 @@ import java.net.URI;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.AbstractChannel;
@ -39,14 +42,15 @@ import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.ChannelPipelineCoverage;
import org.jboss.netty.channel.ChannelSink;
import org.jboss.netty.channel.ChannelStateEvent;
import org.jboss.netty.channel.DefaultChannelPipeline;
import org.jboss.netty.channel.ExceptionEvent;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelUpstreamHandler;
import org.jboss.netty.channel.socket.ClientSocketChannelFactory;
import org.jboss.netty.channel.socket.SocketChannel;
import org.jboss.netty.channel.socket.SocketChannelConfig;
import org.jboss.netty.handler.codec.frame.DelimiterBasedFrameDecoder;
import org.jboss.netty.handler.ssl.SslHandler;
import org.jboss.netty.logging.InternalLogger;
import org.jboss.netty.logging.InternalLoggerFactory;
import org.jboss.netty.util.internal.LinkedTransferQueue;
@ -63,6 +67,7 @@ class HttpTunnelingClientSocketChannel extends AbstractChannel
static final InternalLogger logger =
InternalLoggerFactory.getInstance(HttpTunnelingClientSocketChannel.class);
private final HttpTunnelingSocketChannelConfig config;
private final Lock reconnectLock = new ReentrantLock();
volatile boolean awaitingInitialResponse = true;
@ -82,9 +87,8 @@ class HttpTunnelingClientSocketChannel extends AbstractChannel
SocketChannel channel;
private final DelimiterBasedFrameDecoder handler = new DelimiterBasedFrameDecoder(8092, ChannelBuffers.wrappedBuffer(new byte[] { '\r', '\n' }));
private final HttpTunnelingClientSocketChannel.ServletChannelHandler servletHandler = new ServletChannelHandler();
private final DelimiterBasedFrameDecoder decoder = new DelimiterBasedFrameDecoder(8092, ChannelBuffers.wrappedBuffer(new byte[] { '\r', '\n' }));
private final HttpTunnelingClientSocketChannel.ServletChannelHandler handler = new ServletChannelHandler();
private HttpTunnelAddress remoteAddress;
@ -94,18 +98,16 @@ class HttpTunnelingClientSocketChannel extends AbstractChannel
ChannelSink sink, ClientSocketChannelFactory clientSocketChannelFactory) {
super(null, factory, pipeline, sink);
this.clientSocketChannelFactory = clientSocketChannelFactory;
DefaultChannelPipeline channelPipeline = new DefaultChannelPipeline();
channelPipeline.addLast("DelimiterBasedFrameDecoder", handler);
channelPipeline.addLast("servletHandler", servletHandler);
channel = clientSocketChannelFactory.newChannel(channelPipeline);
createSocketChannel();
config = new HttpTunnelingSocketChannelConfig(this);
fireChannelOpen(this);
}
public SocketChannelConfig getConfig() {
return channel.getConfig();
public HttpTunnelingSocketChannelConfig getConfig() {
return config;
}
public InetSocketAddress getLocalAddress() {
@ -117,7 +119,7 @@ class HttpTunnelingClientSocketChannel extends AbstractChannel
}
public boolean isBound() {
return channel.isOpen();
return channel.isBound();
}
public boolean isConnected() {
@ -149,10 +151,7 @@ class HttpTunnelingClientSocketChannel extends AbstractChannel
URI url = remoteAddress.getUri();
if (reconnect) {
closeSocket();
DefaultChannelPipeline channelPipeline = new DefaultChannelPipeline();
channelPipeline.addLast("DelimiterBasedFrameDecoder", handler);
channelPipeline.addLast("servletHandler", servletHandler);
channel = clientSocketChannelFactory.newChannel(channelPipeline);
createSocketChannel();
}
SocketAddress connectAddress = new InetSocketAddress(url.getHost(), url.getPort());
channel.connect(connectAddress).awaitUninterruptibly();
@ -170,6 +169,16 @@ class HttpTunnelingClientSocketChannel extends AbstractChannel
channel.write(ChannelBuffers.copiedBuffer(msg, "ASCII"));
}
/**
*
*/
private void createSocketChannel() {
DefaultChannelPipeline channelPipeline = new DefaultChannelPipeline();
channelPipeline.addLast("decoder", decoder);
channelPipeline.addLast("handler", handler);
channel = clientSocketChannelFactory.newChannel(channelPipeline);
}
int sendChunk(ChannelBuffer a) {
int size = a.readableBytes();
String hex = Integer.toHexString(size) + HttpTunnelingClientSocketPipelineSink.LINE_TERMINATOR;
@ -241,6 +250,23 @@ class HttpTunnelingClientSocketChannel extends AbstractChannel
class ServletChannelHandler extends SimpleChannelUpstreamHandler {
int nextChunkSize = -1;
@Override
public void channelConnected(ChannelHandlerContext ctx,
ChannelStateEvent e) throws Exception {
SSLContext sslContext = getConfig().getSslContext();
if (sslContext != null) {
// FIXME: specify peer host and port.
SSLEngine engine = sslContext.createSSLEngine();
engine.setUseClientMode(true);
SocketChannel ch = (SocketChannel) e.getChannel();
SslHandler sslHandler = new SslHandler(engine);
ch.getPipeline().addFirst("ssl", sslHandler);
sslHandler.handshake(channel);
}
}
@Override
public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception {
ChannelBuffer buf = (ChannelBuffer) e.getMessage();

View File

@ -22,243 +22,154 @@
*/
package org.jboss.netty.channel.socket.http;
import java.net.Socket;
import java.net.SocketException;
import java.util.Map;
import org.jboss.netty.channel.ChannelException;
import org.jboss.netty.channel.DefaultChannelConfig;
import javax.net.ssl.SSLContext;
import org.jboss.netty.buffer.ChannelBufferFactory;
import org.jboss.netty.channel.ChannelPipelineFactory;
import org.jboss.netty.channel.socket.SocketChannelConfig;
import org.jboss.netty.util.internal.ConversionUtil;
/**
* @author The Netty Project (netty-dev@lists.jboss.org)
* @author Andy Taylor (andy.taylor@jboss.org)
* @version $Rev$, $Date$
*/
final class HttpTunnelingSocketChannelConfig extends DefaultChannelConfig
implements SocketChannelConfig {
final class HttpTunnelingSocketChannelConfig implements SocketChannelConfig {
final Socket socket;
private Integer trafficClass;
private Boolean tcpNoDelay;
private Integer soLinger;
private Integer sendBufferSize;
private Boolean reuseAddress;
private Integer receiveBufferSize;
private Integer connectionTime;
private Integer latency;
private Integer bandwidth;
private Boolean keepAlive;
private final HttpTunnelingClientSocketChannel channel;
private volatile SSLContext sslContext;
/**
* Creates a new instance.
*/
HttpTunnelingSocketChannelConfig(Socket socket) {
this.socket = socket;
HttpTunnelingSocketChannelConfig(HttpTunnelingClientSocketChannel channel) {
this.channel = channel;
}
public SSLContext getSslContext() {
return sslContext;
}
public void setSslContext(SSLContext sslContext) {
this.sslContext = sslContext;
}
public void setOptions(Map<String, Object> options) {
channel.channel.getConfig().setOptions(options);
SSLContext sslContext = (SSLContext) options.get("sslContext");
if (sslContext != null) {
setSslContext(sslContext);
}
}
@Override
public boolean setOption(String key, Object value) {
if (key.equals("receiveBufferSize")) {
setReceiveBufferSize(ConversionUtil.toInt(value));
} else if (key.equals("sendBufferSize")) {
setSendBufferSize(ConversionUtil.toInt(value));
} else if (key.equals("tcpNoDelay")) {
setTcpNoDelay(ConversionUtil.toBoolean(value));
} else if (key.equals("keepAlive")) {
setKeepAlive(ConversionUtil.toBoolean(value));
} else if (key.equals("reuseAddress")) {
setReuseAddress(ConversionUtil.toBoolean(value));
} else if (key.equals("soLinger")) {
setSoLinger(ConversionUtil.toInt(value));
} else if (key.equals("trafficClass")) {
setTrafficClass(ConversionUtil.toInt(value));
if (channel.channel.getConfig().setOption(key, value)) {
return true;
}
if (key.equals("sslContext")) {
setSslContext((SSLContext) value);
} else {
return false;
}
return true;
}
public int getReceiveBufferSize() {
try {
return socket.getReceiveBufferSize();
}
catch (SocketException e) {
throw new ChannelException(e);
}
return channel.channel.getConfig().getReceiveBufferSize();
}
public int getSendBufferSize() {
try {
return socket.getSendBufferSize();
}
catch (SocketException e) {
throw new ChannelException(e);
}
return channel.channel.getConfig().getSendBufferSize();
}
public int getSoLinger() {
try {
return socket.getSoLinger();
}
catch (SocketException e) {
throw new ChannelException(e);
}
return channel.channel.getConfig().getSoLinger();
}
public int getTrafficClass() {
try {
return socket.getTrafficClass();
}
catch (SocketException e) {
throw new ChannelException(e);
}
return channel.channel.getConfig().getTrafficClass();
}
public boolean isKeepAlive() {
try {
return socket.getKeepAlive();
}
catch (SocketException e) {
throw new ChannelException(e);
}
return channel.channel.getConfig().isKeepAlive();
}
public boolean isReuseAddress() {
try {
return socket.getReuseAddress();
}
catch (SocketException e) {
throw new ChannelException(e);
}
return channel.channel.getConfig().isReuseAddress();
}
public boolean isTcpNoDelay() {
try {
return socket.getTcpNoDelay();
}
catch (SocketException e) {
throw new ChannelException(e);
}
return channel.channel.getConfig().isTcpNoDelay();
}
public void setKeepAlive(boolean keepAlive) {
this.keepAlive = keepAlive;
try {
socket.setKeepAlive(keepAlive);
}
catch (SocketException e) {
throw new ChannelException(e);
}
channel.channel.getConfig().setKeepAlive(keepAlive);
}
public void setPerformancePreferences(
int connectionTime, int latency, int bandwidth) {
this.connectionTime = connectionTime;
this.latency = latency;
this.bandwidth = bandwidth;
socket.setPerformancePreferences(connectionTime, latency, bandwidth);
channel.channel.getConfig().setPerformancePreferences(connectionTime, latency, bandwidth);
}
public void setReceiveBufferSize(int receiveBufferSize) {
this.receiveBufferSize = receiveBufferSize;
try {
socket.setReceiveBufferSize(receiveBufferSize);
}
catch (SocketException e) {
throw new ChannelException(e);
}
channel.channel.getConfig().setReceiveBufferSize(receiveBufferSize);
}
public void setReuseAddress(boolean reuseAddress) {
this.reuseAddress = reuseAddress;
try {
socket.setReuseAddress(reuseAddress);
}
catch (SocketException e) {
throw new ChannelException(e);
}
channel.channel.getConfig().setReuseAddress(reuseAddress);
}
public void setSendBufferSize(int sendBufferSize) {
this.sendBufferSize = sendBufferSize;
if (socket != null) {
try {
socket.setSendBufferSize(sendBufferSize);
}
catch (SocketException e) {
throw new ChannelException(e);
}
}
channel.channel.getConfig().setSendBufferSize(sendBufferSize);
}
public void setSoLinger(int soLinger) {
this.soLinger = soLinger;
try {
if (soLinger < 0) {
socket.setSoLinger(false, 0);
}
else {
socket.setSoLinger(true, soLinger);
}
}
catch (SocketException e) {
throw new ChannelException(e);
}
channel.channel.getConfig().setSoLinger(soLinger);
}
public void setTcpNoDelay(boolean tcpNoDelay) {
this.tcpNoDelay = tcpNoDelay;
try {
socket.setTcpNoDelay(tcpNoDelay);
}
catch (SocketException e) {
throw new ChannelException(e);
}
channel.channel.getConfig().setTcpNoDelay(tcpNoDelay);
}
public void setTrafficClass(int trafficClass) {
this.trafficClass = trafficClass;
try {
socket.setTrafficClass(trafficClass);
}
catch (SocketException e) {
throw new ChannelException(e);
}
channel.channel.getConfig().setTrafficClass(trafficClass);
}
HttpTunnelingSocketChannelConfig copyConfig(Socket socket) {
HttpTunnelingSocketChannelConfig config = new HttpTunnelingSocketChannelConfig(socket);
config.setConnectTimeoutMillis(getConnectTimeoutMillis());
if (trafficClass != null) {
config.setTrafficClass(trafficClass);
}
if (tcpNoDelay != null) {
config.setTcpNoDelay(tcpNoDelay);
}
if (soLinger != null) {
config.setSoLinger(soLinger);
}
if (sendBufferSize != null) {
config.setSendBufferSize(sendBufferSize);
}
if (reuseAddress != null) {
config.setReuseAddress(reuseAddress);
}
if (receiveBufferSize != null) {
config.setReceiveBufferSize(receiveBufferSize);
}
if (keepAlive != null) {
config.setKeepAlive(keepAlive);
}
if (connectionTime != null) {
config.setPerformancePreferences(connectionTime, latency, bandwidth);
}
return config;
public ChannelBufferFactory getBufferFactory() {
return channel.channel.getConfig().getBufferFactory();
}
public int getConnectTimeoutMillis() {
return channel.channel.getConfig().getConnectTimeoutMillis();
}
public ChannelPipelineFactory getPipelineFactory() {
return channel.channel.getConfig().getPipelineFactory();
}
@Deprecated
public int getWriteTimeoutMillis() {
return channel.channel.getConfig().getWriteTimeoutMillis();
}
public void setBufferFactory(ChannelBufferFactory bufferFactory) {
channel.channel.getConfig().setBufferFactory(bufferFactory);
}
public void setConnectTimeoutMillis(int connectTimeoutMillis) {
channel.channel.getConfig().setConnectTimeoutMillis(connectTimeoutMillis);
}
public void setPipelineFactory(ChannelPipelineFactory pipelineFactory) {
channel.channel.getConfig().setPipelineFactory(pipelineFactory);
}
@Deprecated
public void setWriteTimeoutMillis(int writeTimeoutMillis) {
channel.channel.getConfig().setWriteTimeoutMillis(writeTimeoutMillis);
}
}