Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,16 @@ public interface DataflowStreamingPipelineOptions extends PipelineOptions {

void setIsWindmillServiceDirectPathEnabled(boolean isWindmillServiceDirectPathEnabled);

/**
 * The maximum size of cached entries in bytes. Entries (e.g. values, bags) larger than this limit
 * will not be cached by the windmill state cache.
 *
 * <p>Defaults to {@link Long#MAX_VALUE}, which is effectively unlimited.
 *
 * <p>NOTE(review): the limit appears to be compared against the cache entry's weight rather than
 * the raw payload size, so per-entry bookkeeping overhead may count towards it -- confirm against
 * {@code WindmillStateCache} before choosing a very small value.
 */
@Description("The maximum size of cached entries in bytes.")
@Default.Long(Long.MAX_VALUE)
Long getMaxWindmillStateCacheEntryBytes();

/** Sets the value returned by {@link #getMaxWindmillStateCacheEntryBytes()}. */
void setMaxWindmillStateCacheEntryBytes(Long value);

/**
* Factory for creating local Windmill address. Reads from system property 'windmill.hostport' for
* backwards compatibility.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import org.apache.beam.runners.dataflow.worker.streaming.WeightedSemaphore;
import org.apache.beam.runners.dataflow.worker.streaming.WorkHeartbeatResponseProcessor;
import org.apache.beam.runners.dataflow.worker.streaming.config.ComputationConfig;
import org.apache.beam.runners.dataflow.worker.streaming.config.ComputationConfig.Fetcher;
import org.apache.beam.runners.dataflow.worker.streaming.config.FixedGlobalConfigHandle;
import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingApplianceComputationConfigFetcher;
import org.apache.beam.runners.dataflow.worker.streaming.config.StreamingEngineComputationConfigFetcher;
Expand Down Expand Up @@ -113,6 +114,7 @@
import org.apache.beam.sdk.io.FileSystems;
import org.apache.beam.sdk.io.gcp.bigquery.BigQuerySinkMetrics;
import org.apache.beam.sdk.metrics.MetricsEnvironment;
import org.apache.beam.sdk.options.ExperimentalOptions;
import org.apache.beam.sdk.util.construction.CoderTranslation;
import org.apache.beam.sdk.values.WindowedValues;
import org.apache.beam.vendor.grpc.v1p69p0.io.grpc.auth.MoreCallCredentials;
Expand Down Expand Up @@ -633,6 +635,10 @@ public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions o
WindmillStateCache.builder()
.setSizeMb(options.getWorkerCacheMb())
.setSupportMapViaMultimap(options.isEnableStreamingEngine())
.setMaxCachedEntryBytes(options.getMaxWindmillStateCacheEntryBytes())
.setEnableHistogram(
!ExperimentalOptions.hasExperiment(
options, "disable_windmill_user_state_cache_histogram"))
.build();

GrpcWindmillStreamFactory.Builder windmillStreamFactoryBuilder =
Expand All @@ -651,6 +657,15 @@ public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions o
windmillStateCache::forComputation,
ID_GENERATOR));

Fetcher configFetcher = configFetcherComputationStateCacheAndWindmillClient.configFetcher();
configFetcher
.getGlobalConfigHandle()
.registerConfigObserver(
config -> {
windmillStateCache.setMaxCachedEntryBytesOverride(
config.userWorkerJobSettings().getMaxCachedEntryBytes());
});
Comment thread
arunpandianp marked this conversation as resolved.
Comment thread
arunpandianp marked this conversation as resolved.

ComputationStateCache computationStateCache =
configFetcherComputationStateCacheAndWindmillClient.computationStateCache();
WindmillServerStub windmillServer =
Expand Down Expand Up @@ -689,7 +704,7 @@ public static StreamingDataflowWorker fromOptions(DataflowWorkerHarnessOptions o
return new StreamingDataflowWorker(
windmillServer,
clientId,
configFetcherComputationStateCacheAndWindmillClient.configFetcher(),
configFetcher,
computationStateCache,
windmillStateCache,
workExecutor,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow.worker.util;

/**
 * A fixed-bucket histogram of byte sizes.
 *
 * <p>Each recorded weight is counted in the first bucket whose exclusive upper bound is greater
 * than the weight; weights of 1MB and above share the final bucket. The class performs no
 * synchronization of its own.
 */
public class SimpleByteHistogram {
  /** Exclusive upper bounds, in bytes, of every bucket except the last (unbounded) one. */
  private static final long[] UPPER_BOUNDS = {128, 256, 512, 1024, 10 * 1024, 1024 * 1024};

  /** Number of samples recorded per bucket; the extra slot is the ">= 1MB" bucket. */
  private final long[] counts = new long[UPPER_BOUNDS.length + 1];

  /**
   * Records a single sample.
   *
   * @param weight the size in bytes to count
   */
  public void add(long weight) {
    counts[bucketIndex(weight)]++;
  }

  /** Returns the index of the bucket that {@code weight} belongs to. */
  private static int bucketIndex(long weight) {
    for (int i = 0; i < UPPER_BOUNDS.length; i++) {
      if (weight < UPPER_BOUNDS[i]) {
        return i;
      }
    }
    // Nothing smaller matched, so the sample lands in the open-ended last bucket.
    return UPPER_BOUNDS.length;
  }

  /**
   * Renders the bucket counts as a single bracketed line.
   *
   * @return the counts labelled by bucket, in ascending bucket order
   */
  public String format() {
    return String.format(
        "[<128B:%d, <256B:%d, <512B:%d, <1KB:%d, <10KB:%d, <1MB:%d, >=1MB:%d]",
        counts[0], counts[1], counts[2], counts[3], counts[4], counts[5], counts[6]);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.beam.runners.dataflow.worker.status.BaseStatusServlet;
import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey;
import org.apache.beam.runners.dataflow.worker.util.SimpleByteHistogram;
import org.apache.beam.runners.dataflow.worker.util.common.worker.InternedByteString;
import org.apache.beam.sdk.state.State;
import org.apache.beam.sdk.util.Weighted;
Expand Down Expand Up @@ -75,9 +76,18 @@ public class WindmillStateCache implements StatusDataProvider {
private final ConcurrentMap<WindmillComputationKey, ForKey> keyIndex;
private final long workerCacheBytes; // Copy workerCacheMb and convert to bytes.
private final boolean supportMapViaMultimap;

WindmillStateCache(long sizeMb, boolean supportMapViaMultimap) {
private final long defaultMaxCachedEntryBytes;
private final boolean enableHistogram;
private volatile long maxCachedEntryBytesOverride = -1L;

WindmillStateCache(
long sizeMb,
boolean supportMapViaMultimap,
long maxCachedEntryBytes,
boolean enableHistogram) {
this.workerCacheBytes = sizeMb * MEGABYTES;
this.defaultMaxCachedEntryBytes = maxCachedEntryBytes;
this.enableHistogram = enableHistogram;
int stateCacheConcurrencyLevel =
Math.max(STATE_CACHE_CONCURRENCY_LEVEL, Runtime.getRuntime().availableProcessors());
this.stateCache =
Expand All @@ -99,22 +109,48 @@ public interface Builder {

Builder setSupportMapViaMultimap(boolean supportMapViaMultimap);

Builder setMaxCachedEntryBytes(long maxCachedEntryBytes);

Builder setEnableHistogram(boolean enableHistogram);

WindmillStateCache build();
}

public static Builder builder() {
return new AutoBuilder_WindmillStateCache_Builder().setSupportMapViaMultimap(false);
return new AutoBuilder_WindmillStateCache_Builder()
.setSupportMapViaMultimap(false)
.setMaxCachedEntryBytes(Long.MAX_VALUE)
.setEnableHistogram(true);
}

/**
 * Installs a runtime override for the per-entry size limit.
 *
 * <p>Only non-negative overrides are honored when the effective limit is computed; a negative
 * {@code limit} leaves the limit supplied at construction time in effect.
 *
 * @param limit the override, in bytes
 */
public void setMaxCachedEntryBytesOverride(long limit) {
  maxCachedEntryBytesOverride = limit;
}

/**
 * Returns the per-entry size limit currently in effect: the override when it is non-negative,
 * otherwise the default supplied at construction time.
 */
private long getMaxCachedEntryBytesLimit() {
  // Read the volatile field once so the sign check and the returned value agree.
  final long configured = maxCachedEntryBytesOverride;
  if (configured < 0) {
    return defaultMaxCachedEntryBytes;
  }
  return configured;
}

private EntryStats calculateEntryStats() {
EntryStats stats = new EntryStats();
BiConsumer<StateId, StateCacheEntry> consumer =
(stateId, stateCacheEntry) -> {
stats.entries++;
stats.idWeight += stateId.getWeight();
stats.entryWeight += stateCacheEntry.getWeight();
long idWeight = stateId.getWeight();
stats.idWeight += idWeight;
long entryWeight = stateCacheEntry.getWeight();
stats.entryWeight += entryWeight;
stats.entryValues += stateCacheEntry.values.size();
stats.maxEntryValues = Math.max(stats.maxEntryValues, stateCacheEntry.values.size());
if (enableHistogram) {
stats.addKeyWeight(idWeight);
stats.addEntryWeight(entryWeight);
stateCacheEntry.values.forEach(
(encodedAddress, weightedValue) -> {
Comment thread
arunpandianp marked this conversation as resolved.
stats.addValueWeight(weightedValue.weight);
});
Comment on lines +149 to +152
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Iterating over all values in every cache entry to calculate the weight distribution increases the complexity of calculateEntryStats from $O(\text{entries})$ to $O(\text{total values})$. For large caches with many values per entry (e.g., large BagState or MapState), this could cause noticeable delays when accessing the worker's status page. Since this is for a debug page, it might be acceptable, but consider if the performance impact has been evaluated for very large states.

}
};
stateCache.asMap().forEach(consumer);
return stats;
Expand Down Expand Up @@ -142,23 +178,44 @@ public ForComputation forComputation(String computation) {
@Override
public void appendSummaryHtml(PrintWriter response) {
response.println("Cache Stats: <br><table>");
response.println(
"<tr><th>Hit Ratio</th><th>Evictions</th><th>Entries</th>"
+ "<th>Entry Values</th><th>Max Entry Values</th>"
+ "<th>Id Weight</th><th>Entry Weight</th><th>Max Weight</th><th>Keys</th>"
+ "</tr><tr>");
CacheStats cacheStats = stateCache.stats();
EntryStats entryStats = calculateEntryStats();
response.println("<td>" + cacheStats.hitRate() + "</td>");
response.println("<td>" + cacheStats.evictionCount() + "</td>");
response.println("<td>" + entryStats.entries + "(" + stateCache.size() + " inc. weak) </td>");
response.println("<td>" + entryStats.entryValues + "</td>");
response.println("<td>" + entryStats.maxEntryValues + "</td>");
response.println("<td>" + entryStats.idWeight / MEGABYTES + "MB</td>");
response.println("<td>" + entryStats.entryWeight / MEGABYTES + "MB</td>");
response.println("<td>" + getMaxWeight() / MEGABYTES + "MB</td>");
response.println("<td>" + keyIndex.size() + "</td>");
response.println("</tr></table><br>");

response.println("<tr><th>Hit Ratio</th><td>" + cacheStats.hitRate() + "</td></tr>");
response.println("<tr><th>Evictions</th><td>" + cacheStats.evictionCount() + "</td></tr>");
response.println(
"<tr><th>Entries</th><td>"
+ entryStats.entries
+ " ("
+ stateCache.size()
+ " inc. weak)</td></tr>");
response.println("<tr><th>Entry Values</th><td>" + entryStats.entryValues + "</td></tr>");
response.println(
"<tr><th>Max Entry Values</th><td>" + entryStats.maxEntryValues + "</td></tr>");
response.println(
"<tr><th>Id Weight</th><td>" + entryStats.idWeight / MEGABYTES + "MB</td></tr>");
response.println(
"<tr><th>Entry Weight</th><td>" + entryStats.entryWeight / MEGABYTES + "MB</td></tr>");
response.println("<tr><th>Max Weight</th><td>" + getMaxWeight() / MEGABYTES + "MB</td></tr>");
response.println("<tr><th>Keys</th><td>" + keyIndex.size() + "</td></tr>");
response.println(
"<tr><th>Entry Size Limit</th><td>" + getMaxCachedEntryBytesLimit() + " bytes</td></tr>");
if (enableHistogram) {
response.println(
"<tr><th>Entry Weight Dist</th><td>"
+ entryStats.entryWeightHistogram.format()
+ "</td></tr>");
response.println(
"<tr><th>Value Weight Dist</th><td>"
+ entryStats.valueWeightHistogram.format()
+ "</td></tr>");
response.println(
"<tr><th>Key Weight Dist</th><td>"
+ entryStats.keyWeightHistogram.format()
+ "</td></tr>");
}

response.println("</table><br>");
}

public BaseStatusServlet statusServlet() {
Expand All @@ -180,6 +237,21 @@ private static class EntryStats {
long entryWeight;
long entryValues;
long maxEntryValues;
SimpleByteHistogram entryWeightHistogram = new SimpleByteHistogram();
SimpleByteHistogram valueWeightHistogram = new SimpleByteHistogram();
SimpleByteHistogram keyWeightHistogram = new SimpleByteHistogram();

/** Records {@code weight} in the entry-weight histogram. */
void addEntryWeight(long weight) {
entryWeightHistogram.add(weight);
}

/** Records {@code weight} in the value-weight histogram. */
void addValueWeight(long weight) {
valueWeightHistogram.add(weight);
}

/** Records {@code weight} in the key-weight histogram. */
void addKeyWeight(long weight) {
keyWeightHistogram.add(weight);
}
}

/**
Expand Down Expand Up @@ -413,7 +485,15 @@ public <T extends State> void put(
}

public void persist() {
localCache.forEach(stateCache::put);
long limit = WindmillStateCache.this.getMaxCachedEntryBytesLimit();
localCache.forEach(
(id, entry) -> {
if (entry.getWeight() <= limit) {
stateCache.put(id, entry);
Comment on lines +491 to +492
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The limit is compared against entry.getWeight(), which includes the overhead of the StateCacheEntry and StateId (approximately 136 bytes). This means that if a user sets a small limit (e.g., 100 bytes), no values will ever be cached. The documentation in DataflowStreamingPipelineOptions says "maximum size of cached values", which might lead users to believe it only applies to the payload. Consider clarifying the documentation or adjusting the logic to only account for the value size if that was the intent.

} else {
stateCache.invalidate(id);
}
});
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow.worker.util;

import static org.junit.Assert.assertEquals;

import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Unit tests covering the bucket boundaries of {@link SimpleByteHistogram}. */
@RunWith(JUnit4.class)
public class SimpleByteHistogramTest {

  @Test
  public void testHistogram() {
    // Two samples per bucket, chosen at or next to each bucket's boundaries.
    long[] samples = {
      10, 127, // <128B
      128, 255, // <256B
      256, 511, // <512B
      512, 1023, // <1KB
      1024, 10240 - 1, // <10KB
      10240, 1048576 - 1, // <1MB
      1048576, 2000000 // >=1MB
    };

    SimpleByteHistogram histogram = new SimpleByteHistogram();
    for (long sample : samples) {
      histogram.add(sample);
    }

    assertEquals(
        "[<128B:2, <256B:2, <512B:2, <1KB:2, <10KB:2, <1MB:2, >=1MB:2]", histogram.format());
  }
}
Loading
Loading