diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameDecoder.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameDecoder.java index ea911d3f78..0c0528c22e 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameDecoder.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameDecoder.java @@ -63,6 +63,7 @@ import io.netty.handler.codec.TooLongFrameException; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; +import java.nio.ByteOrder; import java.util.List; import static io.netty.buffer.ByteBufUtil.readBytes; @@ -356,7 +357,29 @@ public class WebSocket08FrameDecoder extends ByteToMessageDecoder } private void unmask(ByteBuf frame) { - for (int i = frame.readerIndex(); i < frame.writerIndex(); i++) { + int i = frame.readerIndex(); + int end = frame.writerIndex(); + + ByteOrder order = frame.order(); + + // Remark: & 0xFF is necessary because Java will do signed expansion from + // byte to int which we don't want. + int intMask = ((maskingKey[0] & 0xFF) << 24) + | ((maskingKey[1] & 0xFF) << 16) + | ((maskingKey[2] & 0xFF) << 8) + | (maskingKey[3] & 0xFF); + + // If the byte order of our buffers it little endian we have to bring our mask + // into the same format, because getInt() and writeInt() will use a reversed byte order + if (order == ByteOrder.LITTLE_ENDIAN) { + intMask = Integer.reverseBytes(intMask); + } + + for (; i + 3 < end; i += 4) { + int unmasked = frame.getInt(i) ^ intMask; + frame.setInt(i, unmasked); + } + for (; i < end; i++) { frame.setByte(i, frame.getByte(i) ^ maskingKey[i % 4]); } } diff --git a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameEncoder.java b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameEncoder.java index 1589a97502..2bcc87442d 100644 --- a/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameEncoder.java +++ b/codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameEncoder.java @@ -61,6 +61,7 @@ import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.List; /** @@ -173,8 +174,34 @@ public class WebSocket08FrameEncoder extends MessageToMessageEncoder + * Checks whether the combination of encoding and decoding yields the original data.
+ * Thereby also the masking behavior is checked. + */ +public class WebSocket08EncoderDecoderTest { + + private ByteBuf binTestData; + private String strTestData; + + private static final int MAX_TESTDATA_LENGTH = 100 * 1024; + + private void initTestData() { + binTestData = Unpooled.buffer(MAX_TESTDATA_LENGTH); + byte j = 0; + for (int i = 0; i < MAX_TESTDATA_LENGTH; i++) { + binTestData.array()[i] = j; + j++; + } + + StringBuilder s = new StringBuilder(); + char c = 'A'; + for (int i = 0; i < MAX_TESTDATA_LENGTH; i++) { + s.append(c); + c++; + if (c == 'Z') { + c = 'A'; + } + } + strTestData = s.toString(); + } + + @Test + public void testWebSocketEncodingAndDecoding() { + initTestData(); + + // Test without masking + EmbeddedChannel outChannel = new EmbeddedChannel(new WebSocket08FrameEncoder(false)); + EmbeddedChannel inChannel = new EmbeddedChannel(new WebSocket08FrameDecoder(false, false, 1024 * 1024)); + executeTests(outChannel, inChannel); + + // Test with activated masking + outChannel = new EmbeddedChannel(new WebSocket08FrameEncoder(true)); + inChannel = new EmbeddedChannel(new WebSocket08FrameDecoder(true, false, 1024 * 1024)); + executeTests(outChannel, inChannel); + + // Release test data + binTestData.release(); + } + + private void executeTests(EmbeddedChannel outChannel, EmbeddedChannel inChannel) { + // Test at the boundaries of each message type, because this shifts the position of the mask field + // Test min. 4 lengths to check for problems related to an uneven frame length + executeTests(outChannel, inChannel, 0); + executeTests(outChannel, inChannel, 1); + executeTests(outChannel, inChannel, 2); + executeTests(outChannel, inChannel, 3); + executeTests(outChannel, inChannel, 4); + executeTests(outChannel, inChannel, 5); + + executeTests(outChannel, inChannel, 125); + executeTests(outChannel, inChannel, 126); + executeTests(outChannel, inChannel, 127); + executeTests(outChannel, inChannel, 128); + executeTests(outChannel, inChannel, 129); + + executeTests(outChannel, inChannel, 65535); + executeTests(outChannel, inChannel, 65536); + executeTests(outChannel, inChannel, 65537); + executeTests(outChannel, inChannel, 65538); + executeTests(outChannel, inChannel, 65539); + } + + private void executeTests(EmbeddedChannel outChannel, EmbeddedChannel inChannel, int testDataLength) { + testTextWithLen(outChannel, inChannel, testDataLength); + testBinaryWithLen(outChannel, inChannel, testDataLength); + } + + private void testTextWithLen(EmbeddedChannel outChannel, EmbeddedChannel inChannel, int testDataLength) { + String testStr = strTestData.substring(0, testDataLength); + outChannel.writeOutbound(new TextWebSocketFrame(testStr)); + + // Transfer encoded data into decoder + // Loop because there might be multiple frames (gathering write) + while (true) { + ByteBuf encoded = outChannel.readOutbound(); + if (encoded != null) { + inChannel.writeInbound(encoded); + } else { + break; + } + } + + Object decoded = inChannel.readInbound(); + Assert.assertNotNull(decoded); + Assert.assertTrue(decoded instanceof TextWebSocketFrame); + TextWebSocketFrame txt = (TextWebSocketFrame) decoded; + Assert.assertEquals(txt.text(), testStr); + txt.release(); + } + + private void testBinaryWithLen(EmbeddedChannel outChannel, EmbeddedChannel inChannel, int testDataLength) { + binTestData.retain(); // need to retain for sending and still keeping it + binTestData.setIndex(0, testDataLength); // Send only len bytes + outChannel.writeOutbound(new BinaryWebSocketFrame(binTestData)); + + // Transfer encoded data into decoder + // Loop because there might be multiple frames (gathering write) + while (true) { + ByteBuf encoded = outChannel.readOutbound(); + if (encoded != null) { + inChannel.writeInbound(encoded); + } else { + break; + } + } + + Object decoded = inChannel.readInbound(); + Assert.assertNotNull(decoded); + Assert.assertTrue(decoded instanceof BinaryWebSocketFrame); + BinaryWebSocketFrame binFrame = (BinaryWebSocketFrame) decoded; + int readable = binFrame.content().readableBytes(); + Assert.assertEquals(readable, testDataLength); + for (int i = 0; i < testDataLength; i++) { + Assert.assertEquals(binTestData.getByte(i), binFrame.content().getByte(i)); + } + binFrame.release(); + } +}