Add WebSocketFrameAggregator which takes care to aggregate fragmented websocket frames

This commit is contained in:
Norman Maurer 2013-03-25 10:50:59 +01:00
parent 71727e42de
commit 4a9ab4f57c
2 changed files with 199 additions and 0 deletions

View File

@ -0,0 +1,88 @@
/*
* Copyright 2013 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.http.websocketx;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.handler.codec.TooLongFrameException;
/**
* Handler that aggregate fragmented WebSocketFrame's.
*
* Be aware if PING/PONG/CLOSE frames are send in the middle of a fragmented {@link WebSocketFrame} they will
* just get forwarded to the next handler in the pipeline.
*/
public class WebSocketFrameAggregator extends MessageToMessageDecoder<WebSocketFrame> {
private final int maxFrameSize;
private WebSocketFrame currentFrame;
/**
* Construct a new instance
*
* @param maxFrameSize If the size of the aggregated frame exceeds this value,
* a {@link TooLongFrameException} is thrown.
*/
public WebSocketFrameAggregator(int maxFrameSize) {
if (maxFrameSize < 1) {
throw new IllegalArgumentException("maxFrameSize must be > 0");
}
this.maxFrameSize = maxFrameSize;
}
@Override
protected Object decode(ChannelHandlerContext ctx, WebSocketFrame msg) throws Exception {
if (currentFrame == null) {
if (msg.isFinalFragment()) {
return msg.retain();
}
ByteBuf buf = ctx.alloc().compositeBuffer().addComponent(msg.data().retain());
buf.writerIndex(buf.writerIndex() + msg.data().readableBytes());
if (msg instanceof TextWebSocketFrame) {
currentFrame = new TextWebSocketFrame(true, msg.rsv(), buf);
} else if (msg instanceof BinaryWebSocketFrame) {
currentFrame = new BinaryWebSocketFrame(true, msg.rsv(), buf);
} else {
throw new IllegalStateException(
"WebSocket frame was not of type TextWebSocketFrame or BinaryWebSocketFrame");
}
return null;
}
if (msg instanceof ContinuationWebSocketFrame) {
CompositeByteBuf content = (CompositeByteBuf) currentFrame.data();
if (content.readableBytes() > maxFrameSize - msg.data().readableBytes()) {
throw new TooLongFrameException(
"WebSocketFrame length exceeded " + content +
" bytes.");
}
content.addComponent(msg.data().retain());
content.writerIndex(content.writerIndex() + msg.data().readableBytes());
if (msg.isFinalFragment()) {
WebSocketFrame frame = this.currentFrame;
this.currentFrame = null;
return frame;
} else {
return null;
}
}
// It is possible to receive CLOSE/PING/PONG frames during fragmented frames so just pass them to the next
// handler in the chain
return msg.retain();
}
}

View File

@ -0,0 +1,111 @@
/*
* Copyright 2013 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.handler.codec.http.websocketx;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedMessageChannel;
import io.netty.handler.codec.TooLongFrameException;
import io.netty.util.CharsetUtil;
import org.junit.Assert;
import org.junit.Test;
public class WebSocketFrameAggregatorTest {
private final ByteBuf content1 = Unpooled.copiedBuffer("Content1", CharsetUtil.UTF_8);
private final ByteBuf content2 = Unpooled.copiedBuffer("Content2", CharsetUtil.UTF_8);
private final ByteBuf content3 = Unpooled.copiedBuffer("Content3", CharsetUtil.UTF_8);
private final ByteBuf aggregatedContent = Unpooled.buffer().writeBytes(content1.duplicate())
.writeBytes(content2.duplicate()).writeBytes(content3.duplicate());
@Test
public void testAggregationBinary() {
EmbeddedMessageChannel channel = new EmbeddedMessageChannel(new WebSocketFrameAggregator(Integer.MAX_VALUE));
channel.writeInbound(new BinaryWebSocketFrame(true, 1, content1.copy()));
channel.writeInbound(new BinaryWebSocketFrame(false, 0, content1.copy()));
channel.writeInbound(new ContinuationWebSocketFrame(false, 0, content2.copy()));
channel.writeInbound(new PingWebSocketFrame(content1.copy()));
channel.writeInbound(new PongWebSocketFrame(content1.copy()));
channel.writeInbound(new ContinuationWebSocketFrame(true, 0, content3.copy()));
Assert.assertTrue(channel.finish());
System.out.println(channel.lastInboundMessageBuffer().size());
BinaryWebSocketFrame frame = (BinaryWebSocketFrame) channel.readInbound();
Assert.assertTrue(frame.isFinalFragment());
Assert.assertEquals(1, frame.rsv());
Assert.assertEquals(content1, frame.data());
PingWebSocketFrame frame2 = (PingWebSocketFrame) channel.readInbound();
Assert.assertTrue(frame2.isFinalFragment());
Assert.assertEquals(0, frame2.rsv());
Assert.assertEquals(content1, frame2.data());
PongWebSocketFrame frame3 = (PongWebSocketFrame) channel.readInbound();
Assert.assertTrue(frame3.isFinalFragment());
Assert.assertEquals(0, frame3.rsv());
Assert.assertEquals(content1, frame3.data());
BinaryWebSocketFrame frame4 = (BinaryWebSocketFrame) channel.readInbound();
Assert.assertTrue(frame4.isFinalFragment());
Assert.assertEquals(0, frame4.rsv());
Assert.assertEquals(aggregatedContent, frame4.data());
Assert.assertNull(channel.readInbound());
}
@Test
public void testAggregationText() {
EmbeddedMessageChannel channel = new EmbeddedMessageChannel(new WebSocketFrameAggregator(Integer.MAX_VALUE));
channel.writeInbound(new TextWebSocketFrame(true, 1, content1.copy()));
channel.writeInbound(new TextWebSocketFrame(false, 0, content1.copy()));
channel.writeInbound(new ContinuationWebSocketFrame(false, 0, content2.copy()));
channel.writeInbound(new PingWebSocketFrame(content1.copy()));
channel.writeInbound(new PongWebSocketFrame(content1.copy()));
channel.writeInbound(new ContinuationWebSocketFrame(true, 0, content3.copy()));
Assert.assertTrue(channel.finish());
TextWebSocketFrame frame = (TextWebSocketFrame) channel.readInbound();
Assert.assertTrue(frame.isFinalFragment());
Assert.assertEquals(1, frame.rsv());
Assert.assertEquals(content1.duplicate(), frame.data());
PingWebSocketFrame frame2 = (PingWebSocketFrame) channel.readInbound();
Assert.assertTrue(frame2.isFinalFragment());
Assert.assertEquals(0, frame2.rsv());
Assert.assertEquals(content1, frame2.data());
PongWebSocketFrame frame3 = (PongWebSocketFrame) channel.readInbound();
Assert.assertTrue(frame3.isFinalFragment());
Assert.assertEquals(0, frame3.rsv());
Assert.assertEquals(content1, frame3.data());
TextWebSocketFrame frame4 = (TextWebSocketFrame) channel.readInbound();
Assert.assertTrue(frame4.isFinalFragment());
Assert.assertEquals(0, frame4.rsv());
Assert.assertEquals(aggregatedContent, frame4.data());
Assert.assertNull(channel.readInbound());
}
@Test(expected = TooLongFrameException.class)
public void textFrameTooBig() {
EmbeddedMessageChannel channel = new EmbeddedMessageChannel(new WebSocketFrameAggregator(8));
channel.writeInbound(new BinaryWebSocketFrame(true, 1, content1.copy()));
channel.writeInbound(new BinaryWebSocketFrame(false, 0, content1.copy()));
channel.writeInbound(new ContinuationWebSocketFrame(false, 0, content2.copy()));
}
}