Skip to content

Commit 2c340aa

Browse files
dkropachevclaude
andcommitted
Move LWT preserve-replica-order routing logic to BasicLoadBalancingPolicy
Move newQueryPlanPreserveReplicas, the RequestRoutingMethod enum, LWT routing dispatch, and randomNextInt from DefaultLoadBalancingPolicy to BasicLoadBalancingPolicy so both policies support preserve-replica-order routing for LWT requests. DefaultLoadBalancingPolicy now only overrides newQueryPlanRegular with its rack-aware shuffling and slow replica avoidance logic. Refactor LWT routing tests into an abstract LwtRoutingTestBase that runs against both BasicLoadBalancingPolicy and DefaultLoadBalancingPolicy, fixing pre-existing test failures caused by incorrect maxNodesPerRemoteDc configuration and missing routing token setup. Co-Authored-By: Claude Opus 4.6 <[email protected]>
1 parent 284121f commit 2c340aa

6 files changed

Lines changed: 588 additions & 538 deletions

File tree

core/src/main/java/com/datastax/oss/driver/internal/core/loadbalancing/BasicLoadBalancingPolicy.java

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import com.datastax.oss.driver.api.core.ConsistencyLevel;
2727
import com.datastax.oss.driver.api.core.CqlIdentifier;
28+
import com.datastax.oss.driver.api.core.RequestRoutingType;
2829
import com.datastax.oss.driver.api.core.config.DefaultDriverOption;
2930
import com.datastax.oss.driver.api.core.config.DriverExecutionProfile;
3031
import com.datastax.oss.driver.api.core.context.DriverContext;
@@ -63,6 +64,9 @@
6364
import edu.umd.cs.findbugs.annotations.NonNull;
6465
import edu.umd.cs.findbugs.annotations.Nullable;
6566
import java.nio.ByteBuffer;
67+
import java.util.ArrayList;
68+
import java.util.Collections;
69+
import java.util.HashMap;
6670
import java.util.LinkedHashSet;
6771
import java.util.List;
6872
import java.util.Map;
@@ -71,8 +75,10 @@
7175
import java.util.Queue;
7276
import java.util.Set;
7377
import java.util.UUID;
78+
import java.util.concurrent.ThreadLocalRandom;
7479
import java.util.concurrent.atomic.AtomicInteger;
7580
import java.util.function.IntUnaryOperator;
81+
import java.util.stream.Collectors;
7682
import net.jcip.annotations.ThreadSafe;
7783
import org.slf4j.Logger;
7884
import org.slf4j.LoggerFactory;
@@ -113,6 +119,11 @@
113119
@ThreadSafe
114120
public class BasicLoadBalancingPolicy implements LoadBalancingPolicy {
115121

122+
public enum RequestRoutingMethod {
123+
REGULAR,
124+
PRESERVE_REPLICA_ORDER
125+
}
126+
116127
private static final Logger LOG = LoggerFactory.getLogger(BasicLoadBalancingPolicy.class);
117128

118129
protected static final IntUnaryOperator INCREMENT = i -> (i == Integer.MAX_VALUE) ? 0 : i + 1;
@@ -127,6 +138,7 @@ public class BasicLoadBalancingPolicy implements LoadBalancingPolicy {
127138
private final int maxNodesPerRemoteDc;
128139
private final boolean allowDcFailoverForLocalCl;
129140
private final ConsistencyLevel defaultConsistencyLevel;
141+
private final RequestRoutingMethod lwtRequestRoutingMethod;
130142

131143
// private because they should be set in init() and never be modified after
132144
private volatile DistanceReporter distanceReporter;
@@ -154,6 +166,34 @@ public BasicLoadBalancingPolicy(@NonNull DriverContext context, @NonNull String
154166
new LinkedHashSet<>(
155167
profile.getStringList(
156168
DefaultDriverOption.LOAD_BALANCING_DC_FAILOVER_PREFERRED_REMOTE_DCS));
169+
this.lwtRequestRoutingMethod = parseLwtRequestRoutingMethod();
170+
}
171+
172+
@NonNull
173+
private RequestRoutingMethod parseLwtRequestRoutingMethod() {
174+
String methodString =
175+
profile.getString(DefaultDriverOption.LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD);
176+
try {
177+
return RequestRoutingMethod.valueOf(methodString.toUpperCase());
178+
} catch (IllegalArgumentException e) {
179+
LOG.warn(
180+
"[{}] Unknown request routing method '{}', defaulting to PRESERVE_REPLICA_ORDER",
181+
logPrefix,
182+
methodString);
183+
return RequestRoutingMethod.PRESERVE_REPLICA_ORDER;
184+
}
185+
}
186+
187+
@NonNull
188+
public RequestRoutingMethod getRequestRoutingMethod(@Nullable Request request) {
189+
if (request == null) {
190+
return RequestRoutingMethod.REGULAR;
191+
}
192+
if (request.getRequestRoutingType() == RequestRoutingType.LWT) {
193+
return lwtRequestRoutingMethod;
194+
} else {
195+
return RequestRoutingMethod.REGULAR;
196+
}
157197
}
158198

159199
/**
@@ -260,6 +300,17 @@ protected NodeDistanceEvaluator createNodeDistanceEvaluator(
260300
@NonNull
261301
@Override
262302
public Queue<Node> newQueryPlan(@Nullable Request request, @Nullable Session session) {
303+
switch (getRequestRoutingMethod(request)) {
304+
case PRESERVE_REPLICA_ORDER:
305+
return newQueryPlanPreserveReplicas(request, session);
306+
case REGULAR:
307+
default:
308+
return newQueryPlanRegular(request, session);
309+
}
310+
}
311+
312+
@NonNull
313+
protected Queue<Node> newQueryPlanRegular(@Nullable Request request, @Nullable Session session) {
263314
// Take a snapshot since the set is concurrent:
264315
Object[] currentNodes = liveNodes.dc(localDc).toArray();
265316

@@ -294,6 +345,101 @@ public Queue<Node> newQueryPlan(@Nullable Request request, @Nullable Session ses
294345
return maybeAddDcFailover(request, plan);
295346
}
296347

348+
/**
349+
* Builds a query plan that preserves replica order: local replicas, remote replicas, local
350+
* non-replicas (rotated), remote non-replicas (rotated).
351+
*/
352+
@NonNull
353+
protected Queue<Node> newQueryPlanPreserveReplicas(
354+
@Nullable Request request, @Nullable Session session) {
355+
List<Node> replicas = getReplicas(request, session);
356+
String localDc = getLocalDatacenter();
357+
List<Node> queryPlan = new ArrayList<>();
358+
359+
if (localDc == null) {
360+
// No local DC: all replicas first, then rotated non-replicas
361+
List<Node> allNodes = new ArrayList<>();
362+
for (Object obj : getLiveNodes().dc(null).toArray()) {
363+
allNodes.add((Node) obj);
364+
}
365+
queryPlan.addAll(replicas);
366+
addRotatedNonReplicas(queryPlan, allNodes, replicas, request);
367+
} else {
368+
// With local DC: prioritize local, then remote
369+
Map<String, List<Node>> nodesByDc = getAllNodesByDc();
370+
addReplicasByDc(queryPlan, replicas, localDc);
371+
addNonReplicasByDc(queryPlan, nodesByDc, replicas, localDc, request);
372+
}
373+
374+
return new SimpleQueryPlan(queryPlan.toArray());
375+
}
376+
377+
/** Collect all live nodes grouped by DC. */
378+
private Map<String, List<Node>> getAllNodesByDc() {
379+
Map<String, List<Node>> nodesByDc = new HashMap<>();
380+
for (String dc : getLiveNodes().dcs()) {
381+
List<Node> dcNodes = new ArrayList<>();
382+
for (Object obj : getLiveNodes().dc(dc).toArray()) {
383+
dcNodes.add((Node) obj);
384+
}
385+
nodesByDc.put(dc, dcNodes);
386+
}
387+
return nodesByDc;
388+
}
389+
390+
/** Add replicas with local DC first, then remote DCs. */
391+
private void addReplicasByDc(List<Node> queryPlan, List<Node> replicas, String localDc) {
392+
replicas.stream()
393+
.filter(r -> Objects.equals(r.getDatacenter(), localDc))
394+
.forEach(queryPlan::add);
395+
replicas.stream()
396+
.filter(r -> !Objects.equals(r.getDatacenter(), localDc))
397+
.forEach(queryPlan::add);
398+
}
399+
400+
/** Add non-replicas with local DC first, then remote DCs (all rotated). */
401+
private void addNonReplicasByDc(
402+
List<Node> queryPlan,
403+
Map<String, List<Node>> nodesByDc,
404+
List<Node> replicas,
405+
String localDc,
406+
Request request) {
407+
// Local DC non-replicas first
408+
addRotatedNonReplicas(
409+
queryPlan, nodesByDc.getOrDefault(localDc, new ArrayList<>()), replicas, request);
410+
// Remote DC non-replicas
411+
for (Map.Entry<String, List<Node>> entry : nodesByDc.entrySet()) {
412+
if (!Objects.equals(entry.getKey(), localDc)) {
413+
addRotatedNonReplicas(queryPlan, entry.getValue(), replicas, request);
414+
}
415+
}
416+
}
417+
418+
/** Add non-replica nodes from given list with rotation. */
419+
private void addRotatedNonReplicas(
420+
List<Node> queryPlan, List<Node> nodes, List<Node> replicas, Request request) {
421+
List<Node> nonReplicas =
422+
nodes.stream().filter(n -> !replicas.contains(n)).collect(Collectors.toList());
423+
if (!nonReplicas.isEmpty()) {
424+
rotateNonReplicas(nonReplicas, request);
425+
queryPlan.addAll(nonReplicas);
426+
}
427+
}
428+
429+
/** Rotates nodes based on routing key (consistent) or randomly. */
430+
private void rotateNonReplicas(List<Node> nodes, @Nullable Request request) {
431+
if (nodes.size() <= 1) return;
432+
433+
int rotationAmount =
434+
(request != null && request.getRoutingKey() != null)
435+
? Math.abs(request.getRoutingKey().hashCode()) % nodes.size()
436+
: randomNextInt(nodes.size());
437+
438+
if (rotationAmount > 0) {
439+
Collections.rotate(nodes, -rotationAmount);
440+
}
441+
}
442+
297443
@NonNull
298444
protected List<Node> getReplicas(@Nullable Request request, @Nullable Session session) {
299445
if (request == null || session == null) {
@@ -441,6 +587,11 @@ protected Object[] computeNodes() {
441587
return new CompositeQueryPlan(queryPlans);
442588
}
443589

590+
/** Exposed as a protected method so that it can be accessed by tests */
591+
protected int randomNextInt(int bound) {
592+
return ThreadLocalRandom.current().nextInt(bound);
593+
}
594+
444595
/** Exposed as a protected method so that it can be accessed by tests */
445596
protected void shuffleHead(Object[] currentNodes, int headLength) {
446597
ArrayUtils.shuffleHead(currentNodes, headLength);

0 commit comments

Comments
 (0)