Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import java.io.IOException;
import java.io.InputStream;
import java.lang.ref.SoftReference;
import java.util.Arrays;
import javax.annotation.Nullable;
import org.apache.beam.sdk.annotations.Internal;

/** Utility functions for stream operations. */
Expand All @@ -35,34 +37,67 @@ private StreamUtils() {}

private static final int BUF_SIZE = 8192;

private static ThreadLocal<SoftReference<byte[]>> threadLocalBuffer = new ThreadLocal<>();
private static final ThreadLocal<SoftReference<byte[]>> threadLocalBuffer = new ThreadLocal<>();

/** Efficient converting stream to bytes. */
public static byte[] getBytesWithoutClosing(InputStream stream) throws IOException {
// Unwrap the stream so the below optimizations based upon class type function properly.
// We don't use mark or reset in this function.
while (stream instanceof UnownedInputStream) {
stream = ((UnownedInputStream) stream).getWrappedStream();
}

if (stream instanceof ExposedByteArrayInputStream) {
// Fast path for the exposed version.
return ((ExposedByteArrayInputStream) stream).readAll();
} else if (stream instanceof ByteArrayInputStream) {
}
if (stream instanceof ByteArrayInputStream) {
// Fast path for ByteArrayInputStream.
byte[] ret = new byte[stream.available()];
stream.read(ret);
return ret;
}
// Falls back to normal stream copying.

// Most inputs are fully available so we attempt to first read directly
// into a buffer of the right size, assuming available reflects all the bytes.
int available = stream.available();
@Nullable ByteArrayOutputStream outputStream = null;
if (available > 0 && available < 1024 * 1024) {
byte[] initialBuffer = new byte[available];
int initialReadSize = stream.read(initialBuffer);
if (initialReadSize == -1) {
return new byte[0];
}
int nextChar = stream.read();
if (nextChar == -1) {
if (initialReadSize == available) {
// Available reflected the full buffer and we copied directly to the
// right size.
return initialBuffer;
}
return Arrays.copyOf(initialBuffer, initialReadSize);
}
outputStream = new ByteArrayOutputStream();
outputStream.write(initialBuffer, 0, initialReadSize);
outputStream.write(nextChar);
} else {
outputStream = new ByteArrayOutputStream();
}

// Normal stream copying using the thread-local buffer.
SoftReference<byte[]> refBuffer = threadLocalBuffer.get();
byte[] buffer = refBuffer == null ? null : refBuffer.get();
if (buffer == null) {
buffer = new byte[BUF_SIZE];
threadLocalBuffer.set(new SoftReference<>(buffer));
}
ByteArrayOutputStream outStream = new ByteArrayOutputStream();
while (true) {
int r = stream.read(buffer);
if (r == -1) {
break;
}
outStream.write(buffer, 0, r);
outputStream.write(buffer, 0, r);
}
return outStream.toByteArray();
return outputStream.toByteArray();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ public UnownedInputStream(InputStream delegate) {
super(delegate);
}

InputStream getWrappedStream() {
return in;
}

@Override
public void close() throws IOException {
throw new UnsupportedOperationException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.beam.sdk.util;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
Expand All @@ -26,7 +27,9 @@

import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import org.apache.beam.sdk.coders.AtomicCoder;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.Coder.Context;
import org.apache.beam.sdk.coders.CoderException;
Expand Down Expand Up @@ -142,4 +145,11 @@ public void testDecodeFromByteStringWithExtraDataThrows() throws Exception {
CoderException.class,
() -> CoderUtils.decodeFromByteString(StringUtf8Coder.of(), byteString, Context.NESTED));
}

@Test
public void testDecodeByteArrayWithoutCopy() throws Exception {
byte[] data = "test data".getBytes(StandardCharsets.UTF_8);
byte[] result = CoderUtils.decodeFromByteArray(ByteArrayCoder.of(), data);
assertSame(data, result);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@

import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -68,4 +70,97 @@ public void testGetBytesFromInputStream() throws IOException {
assertArrayEquals(testData, bytes);
assertEquals(0, stream.available());
}

@Test
public void testGetBytesFromUnownedInputStreamAroundExposed() throws IOException {
InputStream stream = new UnownedInputStream(new ExposedByteArrayInputStream(testData));
byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
assertArrayEquals(testData, bytes);
assertSame(testData, bytes);
assertEquals(0, stream.available());
}

@Test
public void testGetBytesFromUnownedInputStreamAroundArray() throws IOException {
InputStream stream = new UnownedInputStream(new ByteArrayInputStream(testData));
byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
assertArrayEquals(testData, bytes);
assertEquals(0, stream.available());
}

@Test
public void testGetBytesFromLimitedInputStream() throws IOException {
InputStream stream = ByteStreams.limit(new ByteArrayInputStream(testData), Integer.MAX_VALUE);
byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
assertArrayEquals(testData, bytes);
assertEquals(0, stream.available());
}

@Test
public void testGetBytesFromEmptyLimitedInputStream() throws IOException {
InputStream stream = ByteStreams.limit(new ByteArrayInputStream(testData), 0);
byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
assertArrayEquals(new byte[0], bytes);
assertEquals(0, stream.available());
}

@Test
public void testGetBytesFromRepeatedInputStream() throws IOException {
byte[] largeBytes = new byte[2 * 1024 * 1024];
Arrays.fill(largeBytes, (byte) 1);
InputStream stream = ByteStreams.limit(new ByteArrayInputStream(largeBytes), Integer.MAX_VALUE);
byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
assertArrayEquals(largeBytes, bytes);
assertEquals(0, stream.available());
}

public static class LyingInputStream extends FilterInputStream {
private final int availableLie;

public LyingInputStream(InputStream in, int availableLie) {
super(in);
this.availableLie = availableLie;
}

@Override
public int available() throws IOException {
return availableLie;
}
}

@Test
public void testGetBytesFromHugeAvailable() throws IOException {
InputStream wrappedStream = new ByteArrayInputStream(testData);
InputStream stream = new LyingInputStream(wrappedStream, Integer.MAX_VALUE - 1);
byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
assertArrayEquals(testData, bytes);
assertEquals(0, wrappedStream.available());
}

@Test
public void testGetBytesFromZeroAvailable() throws IOException {
InputStream wrappedStream = new ByteArrayInputStream(testData);
InputStream stream = new LyingInputStream(wrappedStream, 0);
byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
assertArrayEquals(testData, bytes);
assertEquals(0, wrappedStream.available());
}

@Test
public void testGetBytesFromOneExtraAvailable() throws IOException {
InputStream wrappedStream = new ByteArrayInputStream(testData);
InputStream stream = new LyingInputStream(wrappedStream, wrappedStream.available() + 1);
byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
assertArrayEquals(testData, bytes);
assertEquals(0, wrappedStream.available());
}

@Test
public void testGetBytesFromOneLessAvailable() throws IOException {
InputStream wrappedStream = new ByteArrayInputStream(testData);
InputStream stream = new LyingInputStream(wrappedStream, wrappedStream.available() - 1);
byte[] bytes = StreamUtils.getBytesWithoutClosing(stream);
assertArrayEquals(testData, bytes);
assertEquals(0, wrappedStream.available());
}
}
Loading