diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/StreamUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/StreamUtils.java index 28c604361fd0..338779c30dca 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/StreamUtils.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/StreamUtils.java @@ -22,6 +22,8 @@ import java.io.IOException; import java.io.InputStream; import java.lang.ref.SoftReference; +import java.util.Arrays; +import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Internal; /** Utility functions for stream operations. */ @@ -35,34 +37,67 @@ private StreamUtils() {} private static final int BUF_SIZE = 8192; - private static ThreadLocal> threadLocalBuffer = new ThreadLocal<>(); + private static final ThreadLocal> threadLocalBuffer = new ThreadLocal<>(); /** Efficient converting stream to bytes. */ public static byte[] getBytesWithoutClosing(InputStream stream) throws IOException { + // Unwrap the stream so the below optimizations based upon class type function properly. + // We don't use mark or reset in this function. + while (stream instanceof UnownedInputStream) { + stream = ((UnownedInputStream) stream).getWrappedStream(); + } + if (stream instanceof ExposedByteArrayInputStream) { // Fast path for the exposed version. return ((ExposedByteArrayInputStream) stream).readAll(); - } else if (stream instanceof ByteArrayInputStream) { + } + if (stream instanceof ByteArrayInputStream) { // Fast path for ByteArrayInputStream. byte[] ret = new byte[stream.available()]; stream.read(ret); return ret; } - // Falls back to normal stream copying. + + // Most inputs are fully available so we attempt to first read directly + // into a buffer of the right size, assuming available reflects all the bytes. + int available = stream.available(); + @Nullable ByteArrayOutputStream outputStream = null; + if (available > 0 && available < 1024 * 1024) { + byte[] initialBuffer = new byte[available]; + int initialReadSize = stream.read(initialBuffer); + if (initialReadSize == -1) { + return new byte[0]; + } + int nextChar = stream.read(); + if (nextChar == -1) { + if (initialReadSize == available) { + // Available reflected the full buffer and we copied directly to the + // right size. + return initialBuffer; + } + return Arrays.copyOf(initialBuffer, initialReadSize); + } + outputStream = new ByteArrayOutputStream(); + outputStream.write(initialBuffer, 0, initialReadSize); + outputStream.write(nextChar); + } else { + outputStream = new ByteArrayOutputStream(); + } + + // Normal stream copying using the thread-local buffer. SoftReference refBuffer = threadLocalBuffer.get(); byte[] buffer = refBuffer == null ? null : refBuffer.get(); if (buffer == null) { buffer = new byte[BUF_SIZE]; threadLocalBuffer.set(new SoftReference<>(buffer)); } - ByteArrayOutputStream outStream = new ByteArrayOutputStream(); while (true) { int r = stream.read(buffer); if (r == -1) { break; } - outStream.write(buffer, 0, r); + outputStream.write(buffer, 0, r); } - return outStream.toByteArray(); + return outputStream.toByteArray(); } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/UnownedInputStream.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/UnownedInputStream.java index acf70ed6b00d..345e6a8763b2 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/UnownedInputStream.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/UnownedInputStream.java @@ -35,6 +35,10 @@ public UnownedInputStream(InputStream delegate) { super(delegate); } + InputStream getWrappedStream() { + return in; + } + @Override public void close() throws IOException { throw new UnsupportedOperationException( diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/CoderUtilsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/CoderUtilsTest.java index 132081601814..943ac5ddc8cf 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/CoderUtilsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/CoderUtilsTest.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.util; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertSame; import static org.junit.Assert.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; @@ -26,7 +27,9 @@ import java.io.InputStream; import java.io.OutputStream; +import java.nio.charset.StandardCharsets; import org.apache.beam.sdk.coders.AtomicCoder; +import org.apache.beam.sdk.coders.ByteArrayCoder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.Coder.Context; import org.apache.beam.sdk.coders.CoderException; @@ -142,4 +145,11 @@ public void testDecodeFromByteStringWithExtraDataThrows() throws Exception { CoderException.class, () -> CoderUtils.decodeFromByteString(StringUtf8Coder.of(), byteString, Context.NESTED)); } + + @Test + public void testDecodeByteArrayWithoutCopy() throws Exception { + byte[] data = "test data".getBytes(StandardCharsets.UTF_8); + byte[] result = CoderUtils.decodeFromByteArray(ByteArrayCoder.of(), data); + assertSame(data, result); + } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/StreamUtilsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/StreamUtilsTest.java index 68f87d737631..c081c0a33e66 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/StreamUtilsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/StreamUtilsTest.java @@ -23,9 +23,11 @@ import java.io.BufferedInputStream; import java.io.ByteArrayInputStream; +import java.io.FilterInputStream; import java.io.IOException; import java.io.InputStream; import java.util.Arrays; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -68,4 +70,97 @@ public void testGetBytesFromInputStream() throws IOException { assertArrayEquals(testData, bytes); assertEquals(0, stream.available()); } + + @Test + public void testGetBytesFromUnownedInputStreamAroundExposed() throws IOException { + InputStream stream = new UnownedInputStream(new ExposedByteArrayInputStream(testData)); + byte[] bytes = StreamUtils.getBytesWithoutClosing(stream); + assertArrayEquals(testData, bytes); + assertSame(testData, bytes); + assertEquals(0, stream.available()); + } + + @Test + public void testGetBytesFromUnownedInputStreamAroundArray() throws IOException { + InputStream stream = new UnownedInputStream(new ByteArrayInputStream(testData)); + byte[] bytes = StreamUtils.getBytesWithoutClosing(stream); + assertArrayEquals(testData, bytes); + assertEquals(0, stream.available()); + } + + @Test + public void testGetBytesFromLimitedInputStream() throws IOException { + InputStream stream = ByteStreams.limit(new ByteArrayInputStream(testData), Integer.MAX_VALUE); + byte[] bytes = StreamUtils.getBytesWithoutClosing(stream); + assertArrayEquals(testData, bytes); + assertEquals(0, stream.available()); + } + + @Test + public void testGetBytesFromEmptyLimitedInputStream() throws IOException { + InputStream stream = ByteStreams.limit(new ByteArrayInputStream(testData), 0); + byte[] bytes = StreamUtils.getBytesWithoutClosing(stream); + assertArrayEquals(new byte[0], bytes); + assertEquals(0, stream.available()); + } + + @Test + public void testGetBytesFromRepeatedInputStream() throws IOException { + byte[] largeBytes = new byte[2 * 1024 * 1024]; + Arrays.fill(largeBytes, (byte) 1); + InputStream stream = ByteStreams.limit(new ByteArrayInputStream(largeBytes), Integer.MAX_VALUE); + byte[] bytes = StreamUtils.getBytesWithoutClosing(stream); + assertArrayEquals(largeBytes, bytes); + assertEquals(0, stream.available()); + } + + public static class LyingInputStream extends FilterInputStream { + private final int availableLie; + + public LyingInputStream(InputStream in, int availableLie) { + super(in); + this.availableLie = availableLie; + } + + @Override + public int available() throws IOException { + return availableLie; + } + } + + @Test + public void testGetBytesFromHugeAvailable() throws IOException { + InputStream wrappedStream = new ByteArrayInputStream(testData); + InputStream stream = new LyingInputStream(wrappedStream, Integer.MAX_VALUE - 1); + byte[] bytes = StreamUtils.getBytesWithoutClosing(stream); + assertArrayEquals(testData, bytes); + assertEquals(0, wrappedStream.available()); + } + + @Test + public void testGetBytesFromZeroAvailable() throws IOException { + InputStream wrappedStream = new ByteArrayInputStream(testData); + InputStream stream = new LyingInputStream(wrappedStream, 0); + byte[] bytes = StreamUtils.getBytesWithoutClosing(stream); + assertArrayEquals(testData, bytes); + assertEquals(0, wrappedStream.available()); + } + + @Test + public void testGetBytesFromOneExtraAvailable() throws IOException { + InputStream wrappedStream = new ByteArrayInputStream(testData); + InputStream stream = new LyingInputStream(wrappedStream, wrappedStream.available() + 1); + byte[] bytes = StreamUtils.getBytesWithoutClosing(stream); + assertArrayEquals(testData, bytes); + assertEquals(0, wrappedStream.available()); + } + + @Test + public void testGetBytesFromOneLessAvailable() throws IOException { + InputStream wrappedStream = new ByteArrayInputStream(testData); + InputStream stream = new LyingInputStream(wrappedStream, wrappedStream.available() - 1); + byte[] bytes = StreamUtils.getBytesWithoutClosing(stream); + assertArrayEquals(testData, bytes); + assertEquals(0, wrappedStream.available()); + } }