From 213d195909e2a49487160219e539807dcff730db Mon Sep 17 00:00:00 2001 From: Norman Maurer Date: Fri, 20 Sep 2013 16:24:46 +0200 Subject: [PATCH] [#1865] Only use internalNioBuffer when one of the read* or write* methods are used. This is neccessary to prevent races as those can happen when a slice or duplicate is shared between different Channels that are not assigned to the same EventLoop. In general get* operations should always be safe to be used from different Threads. This aslo include unit tests that show the issue --- .../java/io/netty/buffer/AbstractByteBuf.java | 2 +- .../io/netty/buffer/PooledDirectByteBuf.java | 91 ++++++- .../io/netty/buffer/PooledHeapByteBuf.java | 20 +- .../buffer/PooledUnsafeDirectByteBuf.java | 46 +++- .../netty/buffer/UnpooledDirectByteBuf.java | 91 ++++++- .../io/netty/buffer/UnpooledHeapByteBuf.java | 21 +- .../buffer/UnpooledUnsafeDirectByteBuf.java | 45 +++- .../io/netty/buffer/AbstractByteBufTest.java | 227 ++++++++++++++++++ .../buffer/AbstractCompositeByteBufTest.java | 60 ----- .../io/netty/buffer/SlicedByteBufTest.java | 36 +++ 10 files changed, 544 insertions(+), 95 deletions(-) diff --git a/buffer/src/main/java/io/netty/buffer/AbstractByteBuf.java b/buffer/src/main/java/io/netty/buffer/AbstractByteBuf.java index abc5e71d33..e44a61a6fc 100644 --- a/buffer/src/main/java/io/netty/buffer/AbstractByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/AbstractByteBuf.java @@ -36,7 +36,7 @@ public abstract class AbstractByteBuf extends ByteBuf { static final ResourceLeakDetector leakDetector = new ResourceLeakDetector(ByteBuf.class); - private int readerIndex; + int readerIndex; private int writerIndex; private int markedReaderIndex; private int markedWriterIndex; diff --git a/buffer/src/main/java/io/netty/buffer/PooledDirectByteBuf.java b/buffer/src/main/java/io/netty/buffer/PooledDirectByteBuf.java index 2a95b7e024..26ab1ecfc4 100644 --- a/buffer/src/main/java/io/netty/buffer/PooledDirectByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/PooledDirectByteBuf.java @@ -101,53 +101,122 @@ final class PooledDirectByteBuf extends PooledByteBuf { @Override public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { + getBytes(index, dst, dstIndex, length, false); + return this; + } + + private void getBytes(int index, byte[] dst, int dstIndex, int length, boolean internal) { checkDstIndex(index, length, dstIndex, dst.length); - ByteBuffer tmpBuf = internalNioBuffer(); + ByteBuffer tmpBuf; + if (internal) { + tmpBuf = internalNioBuffer(); + } else { + tmpBuf = memory.duplicate(); + } index = idx(index); tmpBuf.clear().position(index).limit(index + length); tmpBuf.get(dst, dstIndex, length); + } + + @Override + public ByteBuf readBytes(byte[] dst, int dstIndex, int length) { + checkReadableBytes(length); + getBytes(readerIndex, dst, dstIndex, length, true); + readerIndex += length; return this; } @Override public ByteBuf getBytes(int index, ByteBuffer dst) { + getBytes(index, dst, false); + return this; + } + + private void getBytes(int index, ByteBuffer dst, boolean internal) { checkIndex(index); int bytesToCopy = Math.min(capacity() - index, dst.remaining()); - ByteBuffer tmpBuf = internalNioBuffer(); + ByteBuffer tmpBuf; + if (internal) { + tmpBuf = internalNioBuffer(); + } else { + tmpBuf = memory.duplicate(); + } index = idx(index); tmpBuf.clear().position(index).limit(index + bytesToCopy); dst.put(tmpBuf); + } + + @Override + public ByteBuf readBytes(ByteBuffer dst) { + int length = dst.remaining(); + checkReadableBytes(length); + getBytes(readerIndex, dst, true); + readerIndex += length; return this; } @Override public ByteBuf getBytes(int index, OutputStream out, int length) throws IOException { + getBytes(index, out, length, false); + return this; + } + + private void getBytes(int index, OutputStream out, int length, boolean internal) throws IOException { checkIndex(index, length); if (length == 0) { - return this; + return; } byte[] tmp = new byte[length]; - ByteBuffer tmpBuf = internalNioBuffer(); + ByteBuffer tmpBuf; + if (internal) { + tmpBuf = internalNioBuffer(); + } else { + tmpBuf = memory.duplicate(); + } tmpBuf.clear().position(idx(index)); tmpBuf.get(tmp); out.write(tmp); + } + + @Override + public ByteBuf readBytes(OutputStream out, int length) throws IOException { + checkReadableBytes(length); + getBytes(readerIndex, out, length, true); + readerIndex += length; return this; } @Override public int getBytes(int index, GatheringByteChannel out, int length) throws IOException { + return getBytes(index, out, length, false); + } + + private int getBytes(int index, GatheringByteChannel out, int length, boolean internal) throws IOException { checkIndex(index, length); if (length == 0) { return 0; } - ByteBuffer tmpBuf = internalNioBuffer(); + ByteBuffer tmpBuf; + if (internal) { + tmpBuf = internalNioBuffer(); + } else { + tmpBuf = memory.duplicate(); + } index = idx(index); tmpBuf.clear().position(index).limit(index + length); return out.write(tmpBuf); } + @Override + public int readBytes(GatheringByteChannel out, int length) throws IOException { + checkReadableBytes(length); + int readBytes = getBytes(readerIndex, out, length, true); + readerIndex += readBytes; + return readBytes; + } + @Override protected void _setByte(int index, int value) { memory.put(idx(index), (byte) value); @@ -225,20 +294,20 @@ final class PooledDirectByteBuf extends PooledByteBuf { if (readBytes <= 0) { return readBytes; } - ByteBuffer tmpNioBuf = internalNioBuffer(); - tmpNioBuf.clear().position(idx(index)); - tmpNioBuf.put(tmp, 0, readBytes); + ByteBuffer tmpBuf = internalNioBuffer(); + tmpBuf.clear().position(idx(index)); + tmpBuf.put(tmp, 0, readBytes); return readBytes; } @Override public int setBytes(int index, ScatteringByteChannel in, int length) throws IOException { checkIndex(index, length); - ByteBuffer tmpNioBuf = internalNioBuffer(); + ByteBuffer tmpBuf = internalNioBuffer(); index = idx(index); - tmpNioBuf.clear().position(index).limit(index + length); + tmpBuf.clear().position(index).limit(index + length); try { - return in.read(tmpNioBuf); + return in.read(tmpBuf); } catch (ClosedChannelException e) { return -1; } diff --git a/buffer/src/main/java/io/netty/buffer/PooledHeapByteBuf.java b/buffer/src/main/java/io/netty/buffer/PooledHeapByteBuf.java index 5638cc5607..74831e9f17 100644 --- a/buffer/src/main/java/io/netty/buffer/PooledHeapByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/PooledHeapByteBuf.java @@ -127,9 +127,27 @@ final class PooledHeapByteBuf extends PooledByteBuf { @Override public int getBytes(int index, GatheringByteChannel out, int length) throws IOException { + return getBytes(index, out, length, false); + } + + private int getBytes(int index, GatheringByteChannel out, int length, boolean internal) throws IOException { checkIndex(index, length); index = idx(index); - return out.write((ByteBuffer) internalNioBuffer().clear().position(index).limit(index + length)); + ByteBuffer tmpBuf; + if (internal) { + tmpBuf = internalNioBuffer(); + } else { + tmpBuf = ByteBuffer.wrap(memory); + } + return out.write((ByteBuffer) tmpBuf.clear().position(index).limit(index + length)); + } + + @Override + public int readBytes(GatheringByteChannel out, int length) throws IOException { + checkReadableBytes(length); + int readBytes = getBytes(readerIndex, out, length, true); + readerIndex += readBytes; + return readBytes; } @Override diff --git a/buffer/src/main/java/io/netty/buffer/PooledUnsafeDirectByteBuf.java b/buffer/src/main/java/io/netty/buffer/PooledUnsafeDirectByteBuf.java index 21226c9188..6f547e9255 100644 --- a/buffer/src/main/java/io/netty/buffer/PooledUnsafeDirectByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/PooledUnsafeDirectByteBuf.java @@ -148,12 +148,30 @@ final class PooledUnsafeDirectByteBuf extends PooledByteBuf { @Override public ByteBuf getBytes(int index, ByteBuffer dst) { + getBytes(index, dst, false); + return this; + } + + private void getBytes(int index, ByteBuffer dst, boolean internal) { checkIndex(index); int bytesToCopy = Math.min(capacity() - index, dst.remaining()); - ByteBuffer tmpBuf = internalNioBuffer(); + ByteBuffer tmpBuf; + if (internal) { + tmpBuf = internalNioBuffer(); + } else { + tmpBuf = memory.duplicate(); + } index = idx(index); tmpBuf.clear().position(index).limit(index + bytesToCopy); dst.put(tmpBuf); + } + + @Override + public ByteBuf readBytes(ByteBuffer dst) { + int length = dst.remaining(); + checkReadableBytes(length); + getBytes(readerIndex, dst, true); + readerIndex += length; return this; } @@ -170,17 +188,35 @@ final class PooledUnsafeDirectByteBuf extends PooledByteBuf { @Override public int getBytes(int index, GatheringByteChannel out, int length) throws IOException { + return getBytes(index, out, length, false); + } + + private int getBytes(int index, GatheringByteChannel out, int length, boolean internal) throws IOException { checkIndex(index, length); if (length == 0) { return 0; } - ByteBuffer tmpBuf = internalNioBuffer(); + ByteBuffer tmpBuf; + if (internal) { + tmpBuf = internalNioBuffer(); + } else { + tmpBuf = memory.duplicate(); + } index = idx(index); tmpBuf.clear().position(index).limit(index + length); return out.write(tmpBuf); } + @Override + public int readBytes(GatheringByteChannel out, int length) + throws IOException { + checkReadableBytes(length); + int readBytes = getBytes(readerIndex, out, length, true); + readerIndex += readBytes; + return readBytes; + } + @Override protected void _setByte(int index, int value) { PlatformDependent.putByte(addr(index), (byte) value); @@ -268,11 +304,11 @@ final class PooledUnsafeDirectByteBuf extends PooledByteBuf { @Override public int setBytes(int index, ScatteringByteChannel in, int length) throws IOException { checkIndex(index, length); - ByteBuffer tmpNioBuf = internalNioBuffer(); + ByteBuffer tmpBuf = internalNioBuffer(); index = idx(index); - tmpNioBuf.clear().position(index).limit(index + length); + tmpBuf.clear().position(index).limit(index + length); try { - return in.read(tmpNioBuf); + return in.read(tmpBuf); } catch (ClosedChannelException e) { return -1; } diff --git a/buffer/src/main/java/io/netty/buffer/UnpooledDirectByteBuf.java b/buffer/src/main/java/io/netty/buffer/UnpooledDirectByteBuf.java index 721e28fddf..1ecacbeb6f 100644 --- a/buffer/src/main/java/io/netty/buffer/UnpooledDirectByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/UnpooledDirectByteBuf.java @@ -274,6 +274,11 @@ public class UnpooledDirectByteBuf extends AbstractReferenceCountedByteBuf { @Override public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { + getBytes(index, dst, dstIndex, length, false); + return this; + } + + private void getBytes(int index, byte[] dst, int dstIndex, int length, boolean internal) { checkDstIndex(index, length, dstIndex, dst.length); if (dstIndex < 0 || dstIndex > dst.length - length) { @@ -281,23 +286,53 @@ public class UnpooledDirectByteBuf extends AbstractReferenceCountedByteBuf { "dstIndex: %d, length: %d (expected: range(0, %d))", dstIndex, length, dst.length)); } - ByteBuffer tmpBuf = internalNioBuffer(); + ByteBuffer tmpBuf; + if (internal) { + tmpBuf = internalNioBuffer(); + } else { + tmpBuf = buffer.duplicate(); + } tmpBuf.clear().position(index).limit(index + length); tmpBuf.get(dst, dstIndex, length); + } + + @Override + public ByteBuf readBytes(byte[] dst, int dstIndex, int length) { + checkReadableBytes(length); + getBytes(readerIndex, dst, dstIndex, length, true); + readerIndex += length; return this; } @Override public ByteBuf getBytes(int index, ByteBuffer dst) { + getBytes(index, dst, false); + return this; + } + + private void getBytes(int index, ByteBuffer dst, boolean internal) { checkIndex(index); if (dst == null) { throw new NullPointerException("dst"); } int bytesToCopy = Math.min(capacity() - index, dst.remaining()); - ByteBuffer tmpBuf = internalNioBuffer(); + ByteBuffer tmpBuf; + if (internal) { + tmpBuf = internalNioBuffer(); + } else { + tmpBuf = buffer.duplicate(); + } tmpBuf.clear().position(index).limit(index + bytesToCopy); dst.put(tmpBuf); + } + + @Override + public ByteBuf readBytes(ByteBuffer dst) { + int length = dst.remaining(); + checkReadableBytes(length); + getBytes(readerIndex, dst, true); + readerIndex += length; return this; } @@ -404,35 +439,69 @@ public class UnpooledDirectByteBuf extends AbstractReferenceCountedByteBuf { @Override public ByteBuf getBytes(int index, OutputStream out, int length) throws IOException { + getBytes(index, out, length, false); + return this; + } + + private void getBytes(int index, OutputStream out, int length, boolean internal) throws IOException { ensureAccessible(); if (length == 0) { - return this; + return; } if (buffer.hasArray()) { out.write(buffer.array(), index + buffer.arrayOffset(), length); } else { byte[] tmp = new byte[length]; - ByteBuffer tmpBuf = internalNioBuffer(); + ByteBuffer tmpBuf; + if (internal) { + tmpBuf = internalNioBuffer(); + } else { + tmpBuf = buffer.duplicate(); + } tmpBuf.clear().position(index); tmpBuf.get(tmp); out.write(tmp); } + } + + @Override + public ByteBuf readBytes(OutputStream out, int length) throws IOException { + checkReadableBytes(length); + getBytes(readerIndex, out, length, true); + readerIndex += length; return this; } @Override public int getBytes(int index, GatheringByteChannel out, int length) throws IOException { + return getBytes(index, out, length, false); + } + + private int getBytes(int index, GatheringByteChannel out, int length, boolean internal) throws IOException { ensureAccessible(); if (length == 0) { return 0; } - ByteBuffer tmpBuf = internalNioBuffer(); + ByteBuffer tmpBuf; + if (internal) { + tmpBuf = internalNioBuffer(); + } else { + tmpBuf = buffer.duplicate(); + } tmpBuf.clear().position(index).limit(index + length); return out.write(tmpBuf); } + @Override + public int readBytes(GatheringByteChannel out, int length) throws IOException { + checkReadableBytes(length); + int readBytes = getBytes(readerIndex, out, length, true); + readerIndex += readBytes; + return readBytes; + } + @Override public int setBytes(int index, InputStream in, int length) throws IOException { ensureAccessible(); @@ -444,9 +513,9 @@ public class UnpooledDirectByteBuf extends AbstractReferenceCountedByteBuf { if (readBytes <= 0) { return readBytes; } - ByteBuffer tmpNioBuf = internalNioBuffer(); - tmpNioBuf.clear().position(index); - tmpNioBuf.put(tmp, 0, readBytes); + ByteBuffer tmpBuf = internalNioBuffer(); + tmpBuf.clear().position(index); + tmpBuf.put(tmp, 0, readBytes); return readBytes; } } @@ -454,8 +523,8 @@ public class UnpooledDirectByteBuf extends AbstractReferenceCountedByteBuf { @Override public int setBytes(int index, ScatteringByteChannel in, int length) throws IOException { ensureAccessible(); - ByteBuffer tmpNioBuf = internalNioBuffer(); - tmpNioBuf.clear().position(index).limit(index + length); + ByteBuffer tmpBuf = internalNioBuffer(); + tmpBuf.clear().position(index).limit(index + length); try { return in.read(tmpNioBuf); } catch (ClosedChannelException e) { @@ -478,7 +547,7 @@ public class UnpooledDirectByteBuf extends AbstractReferenceCountedByteBuf { ensureAccessible(); ByteBuffer src; try { - src = (ByteBuffer) internalNioBuffer().clear().position(index).limit(index + length); + src = (ByteBuffer) buffer.duplicate().clear().position(index).limit(index + length); } catch (IllegalArgumentException e) { throw new IndexOutOfBoundsException("Too many bytes to read - Need " + (index + length)); } diff --git a/buffer/src/main/java/io/netty/buffer/UnpooledHeapByteBuf.java b/buffer/src/main/java/io/netty/buffer/UnpooledHeapByteBuf.java index 2f3b5a6d11..4c80c30317 100644 --- a/buffer/src/main/java/io/netty/buffer/UnpooledHeapByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/UnpooledHeapByteBuf.java @@ -194,7 +194,26 @@ public class UnpooledHeapByteBuf extends AbstractReferenceCountedByteBuf { @Override public int getBytes(int index, GatheringByteChannel out, int length) throws IOException { ensureAccessible(); - return out.write((ByteBuffer) internalNioBuffer().clear().position(index).limit(index + length)); + return getBytes(index, out, length, false); + } + + private int getBytes(int index, GatheringByteChannel out, int length, boolean internal) throws IOException { + ensureAccessible(); + ByteBuffer tmpBuf; + if (internal) { + tmpBuf = internalNioBuffer(); + } else { + tmpBuf = ByteBuffer.wrap(array); + } + return out.write((ByteBuffer) tmpBuf.clear().position(index).limit(index + length)); + } + + @Override + public int readBytes(GatheringByteChannel out, int length) throws IOException { + checkReadableBytes(length); + int readBytes = getBytes(readerIndex, out, length, true); + readerIndex += readBytes; + return readBytes; } @Override diff --git a/buffer/src/main/java/io/netty/buffer/UnpooledUnsafeDirectByteBuf.java b/buffer/src/main/java/io/netty/buffer/UnpooledUnsafeDirectByteBuf.java index 6e00b51ba2..46fc3729ab 100644 --- a/buffer/src/main/java/io/netty/buffer/UnpooledUnsafeDirectByteBuf.java +++ b/buffer/src/main/java/io/netty/buffer/UnpooledUnsafeDirectByteBuf.java @@ -274,15 +274,33 @@ public class UnpooledUnsafeDirectByteBuf extends AbstractReferenceCountedByteBuf @Override public ByteBuf getBytes(int index, ByteBuffer dst) { + getBytes(index, dst, false); + return this; + } + + private void getBytes(int index, ByteBuffer dst, boolean internal) { checkIndex(index); if (dst == null) { throw new NullPointerException("dst"); } int bytesToCopy = Math.min(capacity() - index, dst.remaining()); - ByteBuffer tmpBuf = internalNioBuffer(); + ByteBuffer tmpBuf; + if (internal) { + tmpBuf = internalNioBuffer(); + } else { + tmpBuf = buffer.duplicate(); + } tmpBuf.clear().position(index).limit(index + bytesToCopy); dst.put(tmpBuf); + } + + @Override + public ByteBuf readBytes(ByteBuffer dst) { + int length = dst.remaining(); + checkReadableBytes(length); + getBytes(readerIndex, dst, true); + readerIndex += length; return this; } @@ -371,16 +389,33 @@ public class UnpooledUnsafeDirectByteBuf extends AbstractReferenceCountedByteBuf @Override public int getBytes(int index, GatheringByteChannel out, int length) throws IOException { + return getBytes(index, out, length, false); + } + + private int getBytes(int index, GatheringByteChannel out, int length, boolean internal) throws IOException { ensureAccessible(); if (length == 0) { return 0; } - ByteBuffer tmpBuf = internalNioBuffer(); + ByteBuffer tmpBuf; + if (internal) { + tmpBuf = internalNioBuffer(); + } else { + tmpBuf = buffer.duplicate(); + } tmpBuf.clear().position(index).limit(index + length); return out.write(tmpBuf); } + @Override + public int readBytes(GatheringByteChannel out, int length) throws IOException { + checkReadableBytes(length); + int readBytes = getBytes(readerIndex, out, length, true); + readerIndex += readBytes; + return readBytes; + } + @Override public int setBytes(int index, InputStream in, int length) throws IOException { checkIndex(index, length); @@ -395,10 +430,10 @@ public class UnpooledUnsafeDirectByteBuf extends AbstractReferenceCountedByteBuf @Override public int setBytes(int index, ScatteringByteChannel in, int length) throws IOException { ensureAccessible(); - ByteBuffer tmpNioBuf = internalNioBuffer(); - tmpNioBuf.clear().position(index).limit(index + length); + ByteBuffer tmpBuf = internalNioBuffer(); + tmpBuf.clear().position(index).limit(index + length); try { - return in.read(tmpNioBuf); + return in.read(tmpBuf); } catch (ClosedChannelException e) { return -1; } diff --git a/buffer/src/test/java/io/netty/buffer/AbstractByteBufTest.java b/buffer/src/test/java/io/netty/buffer/AbstractByteBufTest.java index 62ab85ae5b..4d33483ce2 100644 --- a/buffer/src/test/java/io/netty/buffer/AbstractByteBufTest.java +++ b/buffer/src/test/java/io/netty/buffer/AbstractByteBufTest.java @@ -17,6 +17,7 @@ package io.netty.buffer; import io.netty.util.CharsetUtil; import org.junit.After; +import org.junit.Assert; import org.junit.Assume; import org.junit.Before; import org.junit.Ignore; @@ -24,14 +25,22 @@ import org.junit.Test; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; +import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.WritableByteChannel; import java.util.ArrayDeque; import java.util.Arrays; import java.util.HashSet; import java.util.Queue; import java.util.Random; import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import static io.netty.buffer.Unpooled.*; import static io.netty.util.internal.EmptyArrays.*; @@ -1762,4 +1771,222 @@ public abstract class AbstractByteBufTest { } assertFalse(buf.hasRemaining()); } + + @Test + public void testDuplicateReadGatheringByteChannelMultipleThreads() throws Exception { + testReadGatheringByteChannelMultipleThreads(false); + } + + @Test + public void testSliceReadGatheringByteChannelMultipleThreads() throws Exception { + testReadGatheringByteChannelMultipleThreads(true); + } + + private void testReadGatheringByteChannelMultipleThreads(final boolean slice) throws Exception { + final byte[] bytes = new byte[8]; + random.nextBytes(bytes); + + final ByteBuf buffer = freeLater(newBuffer(8)); + buffer.writeBytes(bytes); + final CountDownLatch latch = new CountDownLatch(60000); + final CyclicBarrier barrier = new CyclicBarrier(11); + for (int i = 0; i < 10; i++) { + new Thread(new Runnable() { + @Override + public void run() { + while (latch.getCount() > 0) { + ByteBuf buf; + if (slice) { + buf = buffer.slice(); + } else { + buf = buffer.duplicate(); + } + TestGatheringByteChannel channel = new TestGatheringByteChannel(); + + while (buf.isReadable()) { + try { + buf.readBytes(channel, buf.readableBytes()); + } catch (IOException e) { + // Never happens + return; + } + } + Assert.assertArrayEquals(bytes, channel.writtenBytes()); + latch.countDown(); + } + try { + barrier.await(); + } catch (Exception e) { + // ignore + } + } + }).start(); + } + latch.await(10, TimeUnit.SECONDS); + barrier.await(5, TimeUnit.SECONDS); + } + + @Test + public void testDuplicateReadOutputStreamMultipleThreads() throws Exception { + testReadOutputStreamMultipleThreads(false); + } + + @Test + public void testSliceReadOutputStreamMultipleThreads() throws Exception { + testReadOutputStreamMultipleThreads(true); + } + + private void testReadOutputStreamMultipleThreads(final boolean slice) throws Exception { + final byte[] bytes = new byte[8]; + random.nextBytes(bytes); + + final ByteBuf buffer = freeLater(newBuffer(8)); + buffer.writeBytes(bytes); + final CountDownLatch latch = new CountDownLatch(60000); + final CyclicBarrier barrier = new CyclicBarrier(11); + for (int i = 0; i < 10; i++) { + new Thread(new Runnable() { + @Override + public void run() { + while (latch.getCount() > 0) { + ByteBuf buf; + if (slice) { + buf = buffer.slice(); + } else { + buf = buffer.duplicate(); + } + ByteArrayOutputStream out = new ByteArrayOutputStream(); + + while (buf.isReadable()) { + try { + buf.readBytes(out, buf.readableBytes()); + } catch (IOException e) { + // Never happens + return; + } + } + Assert.assertArrayEquals(bytes, out.toByteArray()); + latch.countDown(); + } + try { + barrier.await(); + } catch (Exception e) { + // ignore + } + } + }).start(); + } + latch.await(10, TimeUnit.SECONDS); + barrier.await(5, TimeUnit.SECONDS); + } + + @Test + public void testDuplicateBytesInArrayMultipleThreads() throws Exception { + testBytesInArrayMultipleThreads(false); + } + + @Test + public void testSliceBytesInArrayMultipleThreads() throws Exception { + testBytesInArrayMultipleThreads(true); + } + + private void testBytesInArrayMultipleThreads(final boolean slice) throws Exception { + final byte[] bytes = new byte[8]; + random.nextBytes(bytes); + + final ByteBuf buffer = freeLater(newBuffer(8)); + buffer.writeBytes(bytes); + final AtomicReference cause = new AtomicReference(); + final CountDownLatch latch = new CountDownLatch(60000); + final CyclicBarrier barrier = new CyclicBarrier(11); + for (int i = 0; i < 10; i++) { + new Thread(new Runnable() { + @Override + public void run() { + while (cause.get() == null && latch.getCount() > 0) { + ByteBuf buf; + if (slice) { + buf = buffer.slice(); + } else { + buf = buffer.duplicate(); + } + + byte[] array = new byte[8]; + buf.readBytes(array); + + Assert.assertArrayEquals(bytes, array); + + Arrays.fill(array, (byte) 0); + buf.getBytes(0, array); + Assert.assertArrayEquals(bytes, array); + + latch.countDown(); + } + try { + barrier.await(); + } catch (Exception e) { + // ignore + } + } + }).start(); + } + latch.await(10, TimeUnit.SECONDS); + barrier.await(5, TimeUnit.SECONDS); + assertNull(cause.get()); + } + + static final class TestGatheringByteChannel implements GatheringByteChannel { + private final ByteArrayOutputStream out = new ByteArrayOutputStream(); + private final WritableByteChannel channel = Channels.newChannel(out); + private final int limit; + TestGatheringByteChannel(int limit) { + this.limit = limit; + } + + TestGatheringByteChannel() { + this(Integer.MAX_VALUE); + } + + @Override + public long write(ByteBuffer[] srcs, int offset, int length) throws IOException { + long written = 0; + for (; offset < length; offset++) { + written += write(srcs[offset]); + if (written >= limit) { + break; + } + } + return written; + } + + @Override + public long write(ByteBuffer[] srcs) throws IOException { + return write(srcs, 0, srcs.length); + } + + @Override + public int write(ByteBuffer src) throws IOException { + int oldLimit = src.limit(); + if (limit < src.remaining()) { + src.limit(src.position() + limit); + } + int w = channel.write(src); + src.limit(oldLimit); + return w; + } + + @Override + public boolean isOpen() { + return channel.isOpen(); + } + + @Override + public void close() throws IOException { + channel.close(); + } + + public byte[] writtenBytes() { + return out.toByteArray(); + } + } } diff --git a/buffer/src/test/java/io/netty/buffer/AbstractCompositeByteBufTest.java b/buffer/src/test/java/io/netty/buffer/AbstractCompositeByteBufTest.java index f5cb599ec5..a03f201113 100644 --- a/buffer/src/test/java/io/netty/buffer/AbstractCompositeByteBufTest.java +++ b/buffer/src/test/java/io/netty/buffer/AbstractCompositeByteBufTest.java @@ -17,13 +17,8 @@ package io.netty.buffer; import org.junit.Test; -import java.io.ByteArrayOutputStream; -import java.io.IOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; -import java.nio.channels.Channels; -import java.nio.channels.GatheringByteChannel; -import java.nio.channels.WritableByteChannel; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -781,59 +776,4 @@ public abstract class AbstractCompositeByteBufTest extends AbstractByteBufTest { public void testInternalNioBuffer() { // ignore } - - private static final class TestGatheringByteChannel implements GatheringByteChannel { - private final ByteArrayOutputStream out = new ByteArrayOutputStream(); - private final WritableByteChannel channel = Channels.newChannel(out); - private final int limit; - TestGatheringByteChannel(int limit) { - this.limit = limit; - } - - TestGatheringByteChannel() { - this(Integer.MAX_VALUE); - } - - @Override - public long write(ByteBuffer[] srcs, int offset, int length) throws IOException { - long written = 0; - for (; offset < length; offset++) { - written += write(srcs[offset]); - if (written >= limit) { - break; - } - } - return written; - } - - @Override - public long write(ByteBuffer[] srcs) throws IOException { - return write(srcs, 0, srcs.length); - } - - @Override - public int write(ByteBuffer src) throws IOException { - int oldLimit = src.limit(); - if (limit < src.remaining()) { - src.limit(src.position() + limit); - } - int w = channel.write(src); - src.limit(oldLimit); - return w; - } - - @Override - public boolean isOpen() { - return channel.isOpen(); - } - - @Override - public void close() throws IOException { - channel.close(); - } - - public byte[] writtenBytes() { - return out.toByteArray(); - } - } } diff --git a/buffer/src/test/java/io/netty/buffer/SlicedByteBufTest.java b/buffer/src/test/java/io/netty/buffer/SlicedByteBufTest.java index 90be26e570..5a3c891ab5 100644 --- a/buffer/src/test/java/io/netty/buffer/SlicedByteBufTest.java +++ b/buffer/src/test/java/io/netty/buffer/SlicedByteBufTest.java @@ -52,4 +52,40 @@ public class SlicedByteBufTest extends AbstractByteBufTest { public void testInternalNioBuffer() { super.testInternalNioBuffer(); } + + @Test(expected = IndexOutOfBoundsException.class) + @Override + public void testDuplicateReadGatheringByteChannelMultipleThreads() throws Exception { + super.testDuplicateReadGatheringByteChannelMultipleThreads(); + } + + @Test(expected = IndexOutOfBoundsException.class) + @Override + public void testSliceReadGatheringByteChannelMultipleThreads() throws Exception { + super.testSliceReadGatheringByteChannelMultipleThreads(); + } + + @Test(expected = IndexOutOfBoundsException.class) + @Override + public void testDuplicateReadOutputStreamMultipleThreads() throws Exception { + super.testDuplicateReadOutputStreamMultipleThreads(); + } + + @Test(expected = IndexOutOfBoundsException.class) + @Override + public void testSliceReadOutputStreamMultipleThreads() throws Exception { + super.testSliceReadOutputStreamMultipleThreads(); + } + + @Test(expected = IndexOutOfBoundsException.class) + @Override + public void testDuplicateBytesInArrayMultipleThreads() throws Exception { + super.testDuplicateBytesInArrayMultipleThreads(); + } + + @Test(expected = IndexOutOfBoundsException.class) + @Override + public void testSliceBytesInArrayMultipleThreads() throws Exception { + super.testSliceBytesInArrayMultipleThreads(); + } }