diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow.json index e9d869cc508c..8144784f5f02 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow.json @@ -1,5 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run!", - "modification": 1, + "modification": 2, } diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json index e623d3373a93..e1b083e439cc 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run!", - "modification": 1, + "modification": 3, } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PubsubReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PubsubReader.java index b60cb84415ff..0c80909a0f3b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PubsubReader.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PubsubReader.java @@ -117,20 +117,12 @@ public NativeReader create( @Override public NativeReaderIterator> iterator() throws IOException { - return new PubsubReaderIterator(context.getWorkItem()); + return new PubsubReaderIterator(); } class PubsubReaderIterator extends WindmillReaderIteratorBase { - protected PubsubReaderIterator(Windmill.WorkItem work) { - super(work, skipUndecodableElements); - } - - @Override - public boolean advance() throws IOException { - if (context.workIsFailed()) { - return false; - } - return super.advance(); + protected PubsubReaderIterator() { + super(context, skipUndecodableElements); } @Override diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index f75d452b211b..1ff2c1bc4a1c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -56,6 +56,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInput; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputState; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; +import org.apache.beam.runners.dataflow.worker.util.common.worker.WorkExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataId; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; @@ -157,6 +158,9 @@ public class StreamingModeExecutionContext extends DataflowExecutionContext activeReader; + private @Nullable WorkExecutor workExecutor; + private boolean finishKeyCalled = false; + public StreamingModeExecutionContext( CounterFactory counterFactory, String computationId, @@ -240,9 +244,12 @@ public void start( Work work, WindmillStateReader stateReader, SideInputStateFetcher sideInputStateFetcher, - Windmill.WorkItemCommitRequest.Builder outputBuilder) { + Windmill.WorkItemCommitRequest.Builder outputBuilder, + WorkExecutor workExecutor) { this.key = key; this.work = work; + this.workExecutor = workExecutor; + this.finishKeyCalled = false; this.computationKey = WindmillComputationKey.create(computationId, work.getShardedKey()); this.sideInputStateFetcher = sideInputStateFetcher; StreamingGlobalConfig config = globalConfigHandle.getConfig(); @@ -270,6 +277,17 @@ public void start( } } + public void finishKey() { + checkState(!finishKeyCalled, "finishKey was already called"); + checkNotNull(workExecutor, "workExecutor must be set before calling finishKey()"); + try { + workExecutor.finishKey(); + } catch (Exception e) { + throw new RuntimeException(e); + } + this.finishKeyCalled = true; + } + /** * Ensure that the processing time is greater than any fired processing time timers. Otherwise, a * trigger could ignore the timer and orphan the window. @@ -451,6 +469,7 @@ public void invalidateCache() { } public Map> flushState() { + checkState(finishKeyCalled, "finishKey must be called before flushState"); Map> callbacks = new HashMap<>(); for (StepContext stepContext : getAllStepContexts()) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UngroupedWindmillReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UngroupedWindmillReader.java index 8ef0bf80323a..f6924493f190 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UngroupedWindmillReader.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UngroupedWindmillReader.java @@ -109,20 +109,12 @@ public NativeReader create( @Override public NativeReaderIterator> iterator() throws IOException { - return new UngroupedWindmillReaderIterator(context.getWorkItem()); + return new UngroupedWindmillReaderIterator(); } class UngroupedWindmillReaderIterator extends WindmillReaderIteratorBase { - UngroupedWindmillReaderIterator(Windmill.WorkItem work) { - super(work, skipUndecodableElements); - } - - @Override - public boolean advance() throws IOException { - if (context.workIsFailed()) { - return false; - } - return super.advance(); + UngroupedWindmillReaderIterator() { + super(context, skipUndecodableElements); } @Override diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java index 7e6508a4788c..075a1a8a4250 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBase.java @@ -34,6 +34,7 @@ */ public abstract class WindmillReaderIteratorBase extends NativeReader.NativeReaderIterator> { + private final StreamingModeExecutionContext context; private final Windmill.WorkItem work; private int bundleIndex = 0; private int messageIndex = -1; @@ -42,9 +43,10 @@ public abstract class WindmillReaderIteratorBase private static final Logger LOG = LoggerFactory.getLogger(WindmillReaderIteratorBase.class); protected WindmillReaderIteratorBase( - Windmill.WorkItem work, ValueProvider skipUndecodableElements) { + StreamingModeExecutionContext context, ValueProvider skipUndecodableElements) { + this.context = context; this.skipUndecodableElements = skipUndecodableElements; - this.work = work; + this.work = context.getWorkItem(); } @Override @@ -54,9 +56,14 @@ public boolean start() throws IOException { @Override public boolean advance() throws IOException { + if (context.workIsFailed()) { + throw new WorkItemCancelledException(context.getWorkItem().getShardingKey()); + } + while (true) { if (bundleIndex >= work.getMessageBundlesCount()) { current = null; + context.finishKey(); return false; } Windmill.InputMessageBundle bundle = work.getMessageBundles(bundleIndex); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java index 173b254f6395..488684769bd9 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindowingWindmillReader.java @@ -156,6 +156,7 @@ public NativeReaderIterator>> iterator() throw return new NativeReaderIterator>>() { @Override public boolean start() throws IOException { + context.finishKey(); return false; } @@ -182,6 +183,7 @@ public boolean start() throws IOException { @Override public boolean advance() throws IOException { current = null; + context.finishKey(); return false; } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourceOperationExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourceOperationExecutor.java index 31528e96e07f..a1321d57ebb6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourceOperationExecutor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourceOperationExecutor.java @@ -89,6 +89,9 @@ public void execute() throws Exception { LOG.debug("Source operation execution complete"); } + @Override + public void finishKey() throws Exception {} + @Override public SourceOperationResponse getResponse() { return response; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSources.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSources.java index c00edffeaf95..29d5fb3561a1 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSources.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSources.java @@ -824,6 +824,7 @@ public boolean start() throws IOException { } try { if (!reader.start()) { + context.finishKey(); return false; } } catch (Exception e) { @@ -841,10 +842,13 @@ public boolean advance() throws IOException { // that there are regular checkpoints and that state does not become too large. BackOff backoff = backoffFactory.backoff(); while (true) { + if (context.workIsFailed()) { + throw new WorkItemCancelledException(context.getWorkItem().getShardingKey()); + } if (elemsRead >= maxElems || Instant.now().isAfter(endTime) - || context.isSinkFullHintSet() - || context.workIsFailed()) { + || context.isSinkFullHintSet()) { + context.finishKey(); return false; } try { @@ -857,6 +861,7 @@ public boolean advance() throws IOException { } long nextBackoff = backoff.nextBackOffMillis(); if (nextBackoff == BackOff.STOP) { + context.finishKey(); return false; } Uninterruptibles.sleepUninterruptibly(nextBackoff, TimeUnit.MILLISECONDS); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java index 8dc681fc640c..b4f3a22a7f52 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutor.java @@ -74,7 +74,7 @@ public final void executeWork( SideInputStateFetcher sideInputStateFetcher, Windmill.WorkItemCommitRequest.Builder outputBuilder) throws Exception { - context().start(key, work, stateReader, sideInputStateFetcher, outputBuilder); + context().start(key, work, stateReader, sideInputStateFetcher, outputBuilder, workExecutor()); workExecutor().execute(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/FlattenOperation.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/FlattenOperation.java index 4847e9f2ea9c..af1b2b9c48bd 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/FlattenOperation.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/FlattenOperation.java @@ -43,6 +43,9 @@ public void process(Object elem) throws Exception { } } + @Override + public void finishKey() throws Exception {} + @Override public boolean supportsRestart() { return true; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java index 58b95f286d55..3c33e1904069 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutor.java @@ -112,6 +112,13 @@ public void execute() throws Exception { // TODO: support for success / failure ports? } + @Override + public void finishKey() throws Exception { + for (Operation op : operations) { + op.finishKey(); + } + } + @Override public NativeReader.Progress getWorkerProgress() throws Exception { return getReadOperation().getProgress(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/Operation.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/Operation.java index b7b4e255cfa5..b630da33cfad 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/Operation.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/Operation.java @@ -137,6 +137,9 @@ public void finish() throws Exception { } } + /** Called when all elements for a specific key have been processed. */ + public abstract void finishKey() throws Exception; + /** Aborts this Operation's execution. */ public void abort() throws Exception { synchronized (initializationStateLock) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ParDoOperation.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ParDoOperation.java index 27b6e9d1fb35..5aec82073366 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ParDoOperation.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ParDoOperation.java @@ -46,12 +46,17 @@ public void process(Object elem) throws Exception { } @Override - public void finish() throws Exception { + // Batch mode does not use this method and instead relies on BatchModeUngroupingParDoFn + // to process timers per key. + public void finishKey() throws Exception { try (Closeable scope = context.enterProcessTimers()) { checkStarted(); fn.processTimers(); } + } + @Override + public void finish() throws Exception { try (Closeable scope = context.enterFinish()) { fn.finishBundle(); super.finish(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ReadOperation.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ReadOperation.java index d6b020483d4c..fabc8d6af25b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ReadOperation.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ReadOperation.java @@ -271,6 +271,9 @@ public void finish() throws Exception { } } + @Override + public void finishKey() throws Exception {} + @Override public void abort() throws Exception { if (readerIterator != null) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/WorkExecutor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/WorkExecutor.java index b7170c80ced9..1083fdbb9c42 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/WorkExecutor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/WorkExecutor.java @@ -34,6 +34,9 @@ public interface WorkExecutor extends AutoCloseable { /** Executes the task. */ public abstract void execute() throws Exception; + /** Called when all elements for a specific key have been processed. */ + void finishKey() throws Exception; + /** * Returns the worker's current progress. * diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/WriteOperation.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/WriteOperation.java index 673140d58d89..d28e7f3e5d3d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/WriteOperation.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/WriteOperation.java @@ -105,6 +105,9 @@ public void finish() throws Exception { } } + @Override + public void finishKey() throws Exception {} + @Override public void abort() throws Exception { if (writer == null) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorTest.java index c519efd4172c..396c8db87e6b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorTest.java @@ -104,6 +104,9 @@ public void abort() throws Exception { aborted = true; super.abort(); } + + @Override + public void finishKey() throws Exception {} } // A mock ReadOperation fed to a MapTaskExecutor in test. @@ -312,6 +315,9 @@ public void start() throws Exception { Metrics.counter("TestMetric", "MetricCounter").inc(1L); } } + + @Override + public void finishKey() throws Exception {} }, new Operation(new OutputReceiver[] {}, context2) { @Override @@ -321,6 +327,9 @@ public void start() throws Exception { Metrics.counter("TestMetric", "MetricCounter").inc(2L); } } + + @Override + public void finishKey() throws Exception {} }, new Operation(new OutputReceiver[] {}, context3) { @Override @@ -330,6 +339,9 @@ public void start() throws Exception { Metrics.counter("TestMetric", "MetricCounter").inc(3L); } } + + @Override + public void finishKey() throws Exception {} }); try (IntrinsicMapTaskExecutor executor = diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index d8a1d1b90d47..ff82a1ab5c4c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -3657,8 +3657,8 @@ public void testActiveWorkFailure() throws Exception { server.waitForAndGetCommitsWithTimeout(1, Duration.standardSeconds(5)); assertEquals(1, commits.size()); - assertEquals(0, BlockingFn.teardownCounter.get()); - assertEquals(1, BlockingFn.setupCounter.get()); + assertEquals(1, BlockingFn.teardownCounter.get()); + assertEquals(2, BlockingFn.setupCounter.get()); worker.stop(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java index 4bfa6efc8880..850c0988ac8a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java @@ -62,6 +62,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.config.FakeGlobalConfigHandle; import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingGlobalConfig; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; +import org.apache.beam.runners.dataflow.worker.util.common.worker.WorkExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; @@ -99,6 +100,7 @@ public class StreamingModeExecutionContextTest { @Rule public transient Timeout globalTimeout = Timeout.seconds(600); @Mock private SideInputStateFetcher sideInputStateFetcher; @Mock private WindmillStateReader stateReader; + @Mock private WorkExecutor workExecutor; private static final String COMPUTATION_ID = "computationId"; @@ -152,7 +154,7 @@ COMPUTATION_ID, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.cla } @Test - public void testTimerInternalsSetTimer() { + public void testTimerInternalsSetTimer() throws Exception { Windmill.WorkItemCommitRequest.Builder outputBuilder = Windmill.WorkItemCommitRequest.newBuilder(); NameContext nameContext = NameContextsForTests.nameContextForTest(); @@ -168,7 +170,8 @@ public void testTimerInternalsSetTimer() { Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()), stateReader, sideInputStateFetcher, - outputBuilder); + outputBuilder, + workExecutor); TimerInternals timerInternals = stepContext.timerInternals(); @@ -179,6 +182,7 @@ public void testTimerInternalsSetTimer() { new Instant(5000), TimeDomain.EVENT_TIME, CausedByDrain.NORMAL)); + executionContext.finishKey(); executionContext.flushState(); Windmill.Timer timer = outputBuilder.buildPartial().getOutputTimers(0); @@ -218,7 +222,8 @@ public void testTimerInternalsProcessingTimeSkew() { Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()), stateReader, sideInputStateFetcher, - outputBuilder); + outputBuilder, + workExecutor); TimerInternals timerInternals = stepContext.timerInternals(); assertTrue(timerTimestamp.isBefore(timerInternals.currentProcessingTime())); } @@ -427,7 +432,8 @@ public void testStateTagEncodingBasedOnConfig() { Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()), stateReader, sideInputStateFetcher, - outputBuilder); + outputBuilder, + workExecutor); assertEquals(expectedEncoding, executionContext.getWindmillTagEncoding().getClass()); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java index 61e2f4250d06..539c38eeb1da 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillReaderIteratorBaseTest.java @@ -19,6 +19,11 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.io.IOException; import java.util.ArrayList; @@ -40,8 +45,8 @@ public class WindmillReaderIteratorBaseTest { private static class TestWindmillReaderIterator extends WindmillReaderIteratorBase { protected TestWindmillReaderIterator( - Windmill.WorkItem work, ValueProvider skipUndecodableElements) { - super(work, skipUndecodableElements); + StreamingModeExecutionContext context, ValueProvider skipUndecodableElements) { + super(context, skipUndecodableElements); } @Override @@ -81,6 +86,51 @@ public void testSkipErrors() throws IOException { testForMessageBundleCounts(true, 0, 0, 1, 3, 0, 1, 0, 0, 0, 0); } + @Test + public void testWorkItemCancelledException() throws IOException { + StreamingModeExecutionContext mockContext = mock(StreamingModeExecutionContext.class); + when(mockContext.workIsFailed()).thenReturn(true); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(0L).build(); + when(mockContext.getWorkItem()).thenReturn(workItem); + + try (TestWindmillReaderIterator iter = + new TestWindmillReaderIterator(mockContext, ValueProvider.StaticValueProvider.of(false))) { + iter.start(); + fail("Expected WorkItemCancelledException"); + } catch (WorkItemCancelledException e) { + // Expected + } + } + + @Test + public void testFinishKeyCalled() throws Exception { + StreamingModeExecutionContext mockContext = mock(StreamingModeExecutionContext.class); + when(mockContext.workIsFailed()).thenReturn(false); + Windmill.WorkItem workItem = + Windmill.WorkItem.newBuilder() + .setKey(ByteString.EMPTY) + .setWorkToken(0L) + .addMessageBundles( + Windmill.InputMessageBundle.newBuilder() + .setSourceComputationId("foo") + .addMessages( + Windmill.Message.newBuilder() + .setTimestamp(0) + .setData(ByteString.EMPTY) + .build()) + .build()) + .build(); + when(mockContext.getWorkItem()).thenReturn(workItem); + + try (TestWindmillReaderIterator iter = + new TestWindmillReaderIterator(mockContext, ValueProvider.StaticValueProvider.of(false))) { + assertTrue(iter.start()); + assertFalse(iter.advance()); // This should trigger finishKey + verify(mockContext).finishKey(); + } + } + private void testForMessageBundleCounts(int... messageBundleCounts) throws IOException { testForMessageBundleCounts(false, messageBundleCounts); } @@ -111,9 +161,13 @@ private void testForMessageBundleCounts(boolean skipErrors, int... messageBundle .setWorkToken(0L) .addAllMessageBundles(bundles) .build(); + + StreamingModeExecutionContext mockContext = mock(StreamingModeExecutionContext.class); + when(mockContext.getWorkItem()).thenReturn(workItem); + try (TestWindmillReaderIterator iter = new TestWindmillReaderIterator( - workItem, ValueProvider.StaticValueProvider.of(skipErrors))) { + mockContext, ValueProvider.StaticValueProvider.of(skipErrors))) { List actual = ReaderTestUtils.windowedValuesToValues( ReaderUtils.readRemainingFromIterator(iter, false)); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java index ce8ad32f71aa..d5cf2948d928 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java @@ -97,6 +97,7 @@ import org.apache.beam.runners.dataflow.worker.testing.TestCountingSource; import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader; import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader.NativeReaderIterator; +import org.apache.beam.runners.dataflow.worker.util.common.worker.WorkExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; @@ -643,7 +644,8 @@ public void testReadUnboundedReader() throws Exception { Watermarks.builder().setInputDataWatermark(new Instant(0)).build()), mock(WindmillStateReader.class), mock(SideInputStateFetcher.class), - Windmill.WorkItemCommitRequest.newBuilder()); + Windmill.WorkItemCommitRequest.newBuilder(), + mock(WorkExecutor.class)); @SuppressWarnings({"unchecked", "rawtypes"}) NativeReader>>> reader = @@ -1023,7 +1025,8 @@ public void testFailedWorkItemsAbort() throws Exception { dummyWork, mock(WindmillStateReader.class), mock(SideInputStateFetcher.class), - Windmill.WorkItemCommitRequest.newBuilder()); + Windmill.WorkItemCommitRequest.newBuilder(), + mock(WorkExecutor.class)); @SuppressWarnings({"unchecked", "rawtypes"}) NativeReader>>> reader = @@ -1038,14 +1041,19 @@ public void testFailedWorkItemsAbort() throws Exception { NativeReaderIterator>>> readerIterator = reader.iterator(); int numReads = 0; - while ((numReads == 0) ? readerIterator.start() : readerIterator.advance()) { - WindowedValue>> value = readerIterator.getCurrent(); - assertEquals(KV.of(0, numReads), value.getValue().getValue()); - numReads++; - // Fail the work item after reading two elements. - if (numReads == 2) { - dummyWork.setFailed(); + try { + while ((numReads == 0) ? readerIterator.start() : readerIterator.advance()) { + WindowedValue>> value = readerIterator.getCurrent(); + assertEquals(KV.of(0, numReads), value.getValue().getValue()); + numReads++; + // Fail the work item after reading two elements. + if (numReads == 2) { + dummyWork.setFailed(); + } } + fail("Expected WorkItemCancelledException"); + } catch (WorkItemCancelledException e) { + // Expected } assertThat(numReads, equalTo(2)); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ExecutorTestUtils.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ExecutorTestUtils.java index 2c35f4bf99db..d5e3b9c87139 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ExecutorTestUtils.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ExecutorTestUtils.java @@ -59,6 +59,9 @@ private static OutputReceiver[] createOutputReceivers(int numOutputs, CounterSet } return receivers; } + + @Override + public void finishKey() throws Exception {} } /** A {@code Reader} that yields a specified set of values. */ diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutorTest.java index 188466a50572..5d8f8eebb6f6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/MapTaskExecutorTest.java @@ -100,6 +100,9 @@ public void abort() throws Exception { aborted = true; super.abort(); } + + @Override + public void finishKey() throws Exception {} } // A mock ReadOperation fed to a MapTaskExecutor in test. @@ -309,6 +312,9 @@ public void start() throws Exception { Metrics.counter("TestMetric", "MetricCounter").inc(1L); } } + + @Override + public void finishKey() throws Exception {} }, new Operation(new OutputReceiver[] {}, context2) { @Override @@ -318,6 +324,9 @@ public void start() throws Exception { Metrics.counter("TestMetric", "MetricCounter").inc(2L); } } + + @Override + public void finishKey() throws Exception {} }, new Operation(new OutputReceiver[] {}, context3) { @Override @@ -327,6 +336,9 @@ public void start() throws Exception { Metrics.counter("TestMetric", "MetricCounter").inc(3L); } } + + @Override + public void finishKey() throws Exception {} }); assertEquals(TimeUnit.MINUTES.toMillis(10), stateTracker.getNextBundleLullDurationReportMs()); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ParDoOperationTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ParDoOperationTest.java index ba327f92cc44..5d058b1968cb 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ParDoOperationTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ParDoOperationTest.java @@ -104,6 +104,7 @@ public void testRunParDoOperation() throws Exception { parDoOperation.process(""); parDoOperation.process("bob"); + parDoOperation.finishKey(); parDoOperation.finish(); parDoOperation.abort(); @@ -147,6 +148,7 @@ public void testParDoOperationContext() throws Exception { operation.start(); operation.process("hello"); + operation.finishKey(); operation.finish(); InOrder inOrder =