Remove ContinuationWebSocketFrame.aggregatedText()

Motivation:
Before we aggregated the full text in the WebSocket08FrameDecoder just to fill in the ContinuationWebSocketFrame.aggregatedText(). The problem was that there was no upper-limit and so it would be possible to see an OOME if the remote peer sends a TextWebSocketFrame + a never ending stream of ContinuationWebSocketFrames. Furthermore the aggregation does not really belong in the WebSocket08FrameDecoder, as we provide an extra ChannelHandler for this anyway (WebSocketFrameAggregator).

Modification:
Remove the ContinuationWebSocketFrame.aggregatedText() method and corresponding constructor. Also refactored WebSocket08FrameDecoder a bit to me more efficient which is now possible as we not need to aggregate here.

Result:
No more risk of OOME because of frames.
This commit is contained in:
Norman Maurer 2014-04-30 07:38:42 +02:00
parent 76355a28b0
commit 787a85f9f1
4 changed files with 45 additions and 127 deletions

View File

@ -25,13 +25,11 @@ import io.netty.util.CharsetUtil;
*/ */
public class ContinuationWebSocketFrame extends WebSocketFrame { public class ContinuationWebSocketFrame extends WebSocketFrame {
private String aggregatedText;
/** /**
* Creates a new empty continuation frame. * Creates a new empty continuation frame.
*/ */
public ContinuationWebSocketFrame() { public ContinuationWebSocketFrame() {
super(Unpooled.buffer(0)); this(Unpooled.buffer(0));
} }
/** /**
@ -58,25 +56,6 @@ public class ContinuationWebSocketFrame extends WebSocketFrame {
super(finalFragment, rsv, binaryData); super(finalFragment, rsv, binaryData);
} }
/**
* Creates a new continuation frame with the specified binary data
*
* @param finalFragment
* flag indicating if this frame is the final fragment
* @param rsv
* reserved bits used for protocol extensions
* @param binaryData
* the content of the frame.
* @param aggregatedText
* Aggregated text set by decoder on the final continuation frame of a fragmented
* text message
*/
public ContinuationWebSocketFrame(
boolean finalFragment, int rsv, ByteBuf binaryData, String aggregatedText) {
super(finalFragment, rsv, binaryData);
this.aggregatedText = aggregatedText;
}
/** /**
* Creates a new continuation frame with the specified text data * Creates a new continuation frame with the specified text data
* *
@ -88,7 +67,7 @@ public class ContinuationWebSocketFrame extends WebSocketFrame {
* text content of the frame. * text content of the frame.
*/ */
public ContinuationWebSocketFrame(boolean finalFragment, int rsv, String text) { public ContinuationWebSocketFrame(boolean finalFragment, int rsv, String text) {
this(finalFragment, rsv, fromText(text), null); this(finalFragment, rsv, fromText(text));
} }
/** /**
@ -112,21 +91,14 @@ public class ContinuationWebSocketFrame extends WebSocketFrame {
} }
} }
/**
* Aggregated text returned by decoder on the final continuation frame of a fragmented text message
*/
public String aggregatedText() {
return aggregatedText;
}
@Override @Override
public ContinuationWebSocketFrame copy() { public ContinuationWebSocketFrame copy() {
return new ContinuationWebSocketFrame(isFinalFragment(), rsv(), content().copy(), aggregatedText()); return new ContinuationWebSocketFrame(isFinalFragment(), rsv(), content().copy());
} }
@Override @Override
public ContinuationWebSocketFrame duplicate() { public ContinuationWebSocketFrame duplicate() {
return new ContinuationWebSocketFrame(isFinalFragment(), rsv(), content().duplicate(), aggregatedText()); return new ContinuationWebSocketFrame(isFinalFragment(), rsv(), content().duplicate());
} }
@Override @Override

View File

@ -1,47 +0,0 @@
/*
* Copyright 2012 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
/*
* Adaptation of http://bjoern.hoehrmann.de/utf-8/decoder/dfa/
*
* Copyright (c) 2008-2009 Bjoern Hoehrmann <bjoern@hoehrmann.de>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software
* and associated documentation files (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge, publish, distribute,
* sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all copies or
* substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
package io.netty.handler.codec.http.websocketx;
/**
* Invalid UTF8 bytes encountered
*/
final class UTF8Exception extends RuntimeException {
private static final long serialVersionUID = 1L;
UTF8Exception(String reason) {
super(reason);
}
}

View File

@ -36,11 +36,13 @@
package io.netty.handler.codec.http.websocketx; package io.netty.handler.codec.http.websocketx;
import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufProcessor;
import io.netty.handler.codec.CorruptedFrameException;
/** /**
* Checks UTF8 bytes for validity before converting it into a string * Checks UTF8 bytes for validity
*/ */
final class UTF8Output { final class Utf8Validator implements ByteBufProcessor {
private static final int UTF8_ACCEPT = 0; private static final int UTF8_ACCEPT = 0;
private static final int UTF8_REJECT = 12; private static final int UTF8_REJECT = 12;
@ -65,45 +67,38 @@ final class UTF8Output {
@SuppressWarnings("RedundantFieldInitialization") @SuppressWarnings("RedundantFieldInitialization")
private int state = UTF8_ACCEPT; private int state = UTF8_ACCEPT;
private int codep; private int codep;
private boolean checking;
private final StringBuilder stringBuilder; public void check(ByteBuf buffer) {
checking = true;
UTF8Output(ByteBuf buffer) { buffer.forEachByte(this);
stringBuilder = new StringBuilder(buffer.readableBytes());
write(buffer);
} }
public void write(ByteBuf buffer) { public void finish() {
for (int i = buffer.readerIndex(); i < buffer.writerIndex(); i++) { checking = false;
write(buffer.getByte(i)); codep = 0;
if (state != UTF8_ACCEPT) {
state = UTF8_ACCEPT;
throw new CorruptedFrameException("bytes are not UTF-8");
} }
} }
public void write(byte[] bytes) { @Override
for (byte b : bytes) { public boolean process(byte b) throws Exception {
write(b);
}
}
public void write(int b) {
byte type = TYPES[b & 0xFF]; byte type = TYPES[b & 0xFF];
codep = state != UTF8_ACCEPT ? b & 0x3f | codep << 6 : 0xff >> type & b; codep = state != UTF8_ACCEPT ? b & 0x3f | codep << 6 : 0xff >> type & b;
state = STATES[state + type]; state = STATES[state + type];
if (state == UTF8_ACCEPT) { if (state == UTF8_REJECT) {
stringBuilder.append((char) codep); checking = false;
} else if (state == UTF8_REJECT) { throw new CorruptedFrameException("bytes are not UTF-8");
throw new UTF8Exception("bytes are not UTF-8");
} }
return true;
} }
@Override public boolean isChecking() {
public String toString() { return checking;
if (state != UTF8_ACCEPT) {
throw new UTF8Exception("bytes are not UTF-8");
}
return stringBuilder.toString();
} }
} }

View File

@ -81,9 +81,7 @@ public class WebSocket08FrameDecoder extends ReplayingDecoder<WebSocket08FrameDe
private static final byte OPCODE_PING = 0x9; private static final byte OPCODE_PING = 0x9;
private static final byte OPCODE_PONG = 0xA; private static final byte OPCODE_PONG = 0xA;
private UTF8Output fragmentedFramesText;
private int fragmentedFramesCount; private int fragmentedFramesCount;
private final long maxFramePayloadLength; private final long maxFramePayloadLength;
private boolean frameFinalFlag; private boolean frameFinalFlag;
private int frameRsv; private int frameRsv;
@ -93,10 +91,10 @@ public class WebSocket08FrameDecoder extends ReplayingDecoder<WebSocket08FrameDe
private int framePayloadBytesRead; private int framePayloadBytesRead;
private byte[] maskingKey; private byte[] maskingKey;
private ByteBuf payloadBuffer; private ByteBuf payloadBuffer;
private final boolean allowExtensions; private final boolean allowExtensions;
private final boolean maskedPayload; private final boolean maskedPayload;
private boolean receivedClosingHandshake; private boolean receivedClosingHandshake;
private Utf8Validator utf8Validator;
enum State { enum State {
FRAME_START, MASKING_KEY, PAYLOAD, CORRUPT FRAME_START, MASKING_KEY, PAYLOAD, CORRUPT
@ -325,7 +323,6 @@ public class WebSocket08FrameDecoder extends ReplayingDecoder<WebSocket08FrameDe
// Processing for possible fragmented messages for text and binary // Processing for possible fragmented messages for text and binary
// frames // frames
String aggregatedText = null;
if (frameFinalFlag) { if (frameFinalFlag) {
// Final frame of the sequence. Apparently ping frames are // Final frame of the sequence. Apparently ping frames are
// allowed in the middle of a fragmented message // allowed in the middle of a fragmented message
@ -333,15 +330,14 @@ public class WebSocket08FrameDecoder extends ReplayingDecoder<WebSocket08FrameDe
fragmentedFramesCount = 0; fragmentedFramesCount = 0;
// Check text for UTF8 correctness // Check text for UTF8 correctness
if (frameOpcode == OPCODE_TEXT || fragmentedFramesText != null) { if (frameOpcode == OPCODE_TEXT ||
(utf8Validator != null && utf8Validator.isChecking())) {
// Check UTF-8 correctness for this payload // Check UTF-8 correctness for this payload
checkUTF8String(ctx, framePayload); checkUTF8String(ctx, framePayload);
// This does a second check to make sure UTF-8 // This does a second check to make sure UTF-8
// correctness for entire text message // correctness for entire text message
aggregatedText = fragmentedFramesText.toString(); utf8Validator.finish();
fragmentedFramesText = null;
} }
} }
} else { } else {
@ -349,13 +345,12 @@ public class WebSocket08FrameDecoder extends ReplayingDecoder<WebSocket08FrameDe
// fragmented sequence // fragmented sequence
if (fragmentedFramesCount == 0) { if (fragmentedFramesCount == 0) {
// First text or binary frame for a fragmented set // First text or binary frame for a fragmented set
fragmentedFramesText = null;
if (frameOpcode == OPCODE_TEXT) { if (frameOpcode == OPCODE_TEXT) {
checkUTF8String(ctx, framePayload); checkUTF8String(ctx, framePayload);
} }
} else { } else {
// Subsequent frames - only check if init frame is text // Subsequent frames - only check if init frame is text
if (fragmentedFramesText != null) { if (utf8Validator != null && utf8Validator.isChecking()) {
checkUTF8String(ctx, framePayload); checkUTF8String(ctx, framePayload);
} }
} }
@ -374,7 +369,7 @@ public class WebSocket08FrameDecoder extends ReplayingDecoder<WebSocket08FrameDe
framePayload = null; framePayload = null;
return; return;
} else if (frameOpcode == OPCODE_CONT) { } else if (frameOpcode == OPCODE_CONT) {
out.add(new ContinuationWebSocketFrame(frameFinalFlag, frameRsv, framePayload, aggregatedText)); out.add(new ContinuationWebSocketFrame(frameFinalFlag, frameRsv, framePayload));
framePayload = null; framePayload = null;
return; return;
} else { } else {
@ -413,11 +408,15 @@ public class WebSocket08FrameDecoder extends ReplayingDecoder<WebSocket08FrameDe
} }
private void protocolViolation(ChannelHandlerContext ctx, String reason) { private void protocolViolation(ChannelHandlerContext ctx, String reason) {
protocolViolation(ctx, new CorruptedFrameException(reason));
}
private void protocolViolation(ChannelHandlerContext ctx, CorruptedFrameException ex) {
checkpoint(State.CORRUPT); checkpoint(State.CORRUPT);
if (ctx.channel().isActive()) { if (ctx.channel().isActive()) {
ctx.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE); ctx.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
} }
throw new CorruptedFrameException(reason); throw ex;
} }
private static int toFrameLength(long l) { private static int toFrameLength(long l) {
@ -430,13 +429,12 @@ public class WebSocket08FrameDecoder extends ReplayingDecoder<WebSocket08FrameDe
private void checkUTF8String(ChannelHandlerContext ctx, ByteBuf buffer) { private void checkUTF8String(ChannelHandlerContext ctx, ByteBuf buffer) {
try { try {
if (fragmentedFramesText == null) { if (utf8Validator == null) {
fragmentedFramesText = new UTF8Output(buffer); utf8Validator = new Utf8Validator();
} else {
fragmentedFramesText.write(buffer);
} }
} catch (UTF8Exception ex) { utf8Validator.check(buffer);
protocolViolation(ctx, "invalid UTF-8 bytes"); } catch (CorruptedFrameException ex) {
protocolViolation(ctx, ex);
} }
} }
@ -464,9 +462,9 @@ public class WebSocket08FrameDecoder extends ReplayingDecoder<WebSocket08FrameDe
// May have UTF-8 message // May have UTF-8 message
if (buffer.isReadable()) { if (buffer.isReadable()) {
try { try {
new UTF8Output(buffer); new Utf8Validator().check(buffer);
} catch (UTF8Exception ex) { } catch (CorruptedFrameException ex) {
protocolViolation(ctx, "Invalid close frame reason text. Invalid UTF-8 bytes"); protocolViolation(ctx, ex);
} }
} }