diff --git a/handler/src/main/java/io/netty/handler/ssl/SslUtils.java b/handler/src/main/java/io/netty/handler/ssl/SslUtils.java index 4cb289dc61..79bd2fc257 100644 --- a/handler/src/main/java/io/netty/handler/ssl/SslUtils.java +++ b/handler/src/main/java/io/netty/handler/ssl/SslUtils.java @@ -64,6 +64,7 @@ final class SslUtils { static final String PROTOCOL_TLS_V1_1 = "TLSv1.1"; static final String PROTOCOL_TLS_V1_2 = "TLSv1.2"; static final String PROTOCOL_TLS_V1_3 = "TLSv1.3"; + static final int GMSSL_PROTOCOL_VERSION = 0x101; static final String INVALID_CIPHER = "SSL_NULL_WITH_NULL_NULL"; @@ -295,10 +296,10 @@ final class SslUtils { } if (tls) { - // SSLv3 or TLS - Check ProtocolVersion + // SSLv3 or TLS or GMSSLv1.0 or GMSSLv1.1 - Check ProtocolVersion int majorVersion = buffer.getUnsignedByte(offset + 1); - if (majorVersion == 3) { - // SSLv3 or TLS + if (majorVersion == 3 || buffer.getShort(offset + 1) == GMSSL_PROTOCOL_VERSION) { + // SSLv3 or TLS or GMSSLv1.0 or GMSSLv1.1 packetLength = unsignedShortBE(buffer, offset + 3) + SSL_RECORD_HEADER_LENGTH; if (packetLength <= SSL_RECORD_HEADER_LENGTH) { // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data) @@ -400,10 +401,10 @@ final class SslUtils { } if (tls) { - // SSLv3 or TLS - Check ProtocolVersion + // SSLv3 or TLS or GMSSLv1.0 or GMSSLv1.1 - Check ProtocolVersion int majorVersion = unsignedByte(buffer.get(pos + 1)); - if (majorVersion == 3) { - // SSLv3 or TLS + if (majorVersion == 3 || buffer.getShort(pos + 1) == GMSSL_PROTOCOL_VERSION) { + // SSLv3 or TLS or GMSSLv1.0 or GMSSLv1.1 packetLength = unsignedShortBE(buffer, pos + 3) + SSL_RECORD_HEADER_LENGTH; if (packetLength <= SSL_RECORD_HEADER_LENGTH) { // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data) diff --git a/handler/src/test/java/io/netty/handler/ssl/SslUtilsTest.java b/handler/src/test/java/io/netty/handler/ssl/SslUtilsTest.java index f84274c97c..02ee72bdf4 100644 --- a/handler/src/test/java/io/netty/handler/ssl/SslUtilsTest.java +++ b/handler/src/test/java/io/netty/handler/ssl/SslUtilsTest.java @@ -75,4 +75,29 @@ public class SslUtilsTest { assertFalse(SslUtils.isTLSv13Cipher("TLS_DHE_RSA_WITH_AES_128_GCM_SHA256")); } + @Test + public void shouldGetPacketLengthOfGmsslProtocolFromByteBuf() { + int bodyLength = 65; + ByteBuf buf = Unpooled.buffer() + .writeByte(SslUtils.SSL_CONTENT_TYPE_HANDSHAKE) + .writeShort(SslUtils.GMSSL_PROTOCOL_VERSION) + .writeShort(bodyLength); + + int packetLength = getEncryptedPacketLength(buf, 0); + assertEquals(bodyLength + SslUtils.SSL_RECORD_HEADER_LENGTH, packetLength); + buf.release(); + } + + @Test + public void shouldGetPacketLengthOfGmsslProtocolFromByteBuffer() { + int bodyLength = 65; + ByteBuf buf = Unpooled.buffer() + .writeByte(SslUtils.SSL_CONTENT_TYPE_HANDSHAKE) + .writeShort(SslUtils.GMSSL_PROTOCOL_VERSION) + .writeShort(bodyLength); + + int packetLength = getEncryptedPacketLength(new ByteBuffer[] { buf.nioBuffer() }, 0); + assertEquals(bodyLength + SslUtils.SSL_RECORD_HEADER_LENGTH, packetLength); + buf.release(); + } }