2525
2626import com .datastax .oss .driver .api .core .ConsistencyLevel ;
2727import com .datastax .oss .driver .api .core .CqlIdentifier ;
28+ import com .datastax .oss .driver .api .core .RequestRoutingType ;
2829import com .datastax .oss .driver .api .core .config .DefaultDriverOption ;
2930import com .datastax .oss .driver .api .core .config .DriverExecutionProfile ;
3031import com .datastax .oss .driver .api .core .context .DriverContext ;
6364import edu .umd .cs .findbugs .annotations .NonNull ;
6465import edu .umd .cs .findbugs .annotations .Nullable ;
6566import java .nio .ByteBuffer ;
67+ import java .util .ArrayList ;
68+ import java .util .Collections ;
69+ import java .util .HashMap ;
6670import java .util .LinkedHashSet ;
6771import java .util .List ;
6872import java .util .Map ;
7175import java .util .Queue ;
7276import java .util .Set ;
7377import java .util .UUID ;
78+ import java .util .concurrent .ThreadLocalRandom ;
7479import java .util .concurrent .atomic .AtomicInteger ;
7580import java .util .function .IntUnaryOperator ;
81+ import java .util .stream .Collectors ;
7682import net .jcip .annotations .ThreadSafe ;
7783import org .slf4j .Logger ;
7884import org .slf4j .LoggerFactory ;
113119@ ThreadSafe
114120public class BasicLoadBalancingPolicy implements LoadBalancingPolicy {
115121
/**
 * Strategies for ordering the nodes of a query plan.
 *
 * <p>{@code newQueryPlan} dispatches on this value to choose which plan builder to run.
 */
public enum RequestRoutingMethod {
  /** Build the plan with {@code newQueryPlanRegular}. */
  REGULAR,

  /**
   * Build the plan with {@code newQueryPlanPreserveReplicas}, which keeps replicas in the order
   * they were returned and places them ahead of non-replicas.
   */
  PRESERVE_REPLICA_ORDER
}
126+
116127 private static final Logger LOG = LoggerFactory .getLogger (BasicLoadBalancingPolicy .class );
117128
118129 protected static final IntUnaryOperator INCREMENT = i -> (i == Integer .MAX_VALUE ) ? 0 : i + 1 ;
@@ -127,6 +138,7 @@ public class BasicLoadBalancingPolicy implements LoadBalancingPolicy {
127138 private final int maxNodesPerRemoteDc ;
128139 private final boolean allowDcFailoverForLocalCl ;
129140 private final ConsistencyLevel defaultConsistencyLevel ;
141+ private final RequestRoutingMethod lwtRequestRoutingMethod ;
130142
131143 // private because they should be set in init() and never be modified after
132144 private volatile DistanceReporter distanceReporter ;
@@ -154,6 +166,34 @@ public BasicLoadBalancingPolicy(@NonNull DriverContext context, @NonNull String
154166 new LinkedHashSet <>(
155167 profile .getStringList (
156168 DefaultDriverOption .LOAD_BALANCING_DC_FAILOVER_PREFERRED_REMOTE_DCS ));
169+ this .lwtRequestRoutingMethod = parseLwtRequestRoutingMethod ();
170+ }
171+
172+ @ NonNull
173+ private RequestRoutingMethod parseLwtRequestRoutingMethod () {
174+ String methodString =
175+ profile .getString (DefaultDriverOption .LOAD_BALANCING_DEFAULT_LWT_REQUEST_ROUTING_METHOD );
176+ try {
177+ return RequestRoutingMethod .valueOf (methodString .toUpperCase ());
178+ } catch (IllegalArgumentException e ) {
179+ LOG .warn (
180+ "[{}] Unknown request routing method '{}', defaulting to PRESERVE_REPLICA_ORDER" ,
181+ logPrefix ,
182+ methodString );
183+ return RequestRoutingMethod .PRESERVE_REPLICA_ORDER ;
184+ }
185+ }
186+
187+ @ NonNull
188+ public RequestRoutingMethod getRequestRoutingMethod (@ Nullable Request request ) {
189+ if (request == null ) {
190+ return RequestRoutingMethod .REGULAR ;
191+ }
192+ if (request .getRequestRoutingType () == RequestRoutingType .LWT ) {
193+ return lwtRequestRoutingMethod ;
194+ } else {
195+ return RequestRoutingMethod .REGULAR ;
196+ }
157197 }
158198
159199 /**
@@ -260,6 +300,17 @@ protected NodeDistanceEvaluator createNodeDistanceEvaluator(
260300 @ NonNull
261301 @ Override
262302 public Queue <Node > newQueryPlan (@ Nullable Request request , @ Nullable Session session ) {
303+ switch (getRequestRoutingMethod (request )) {
304+ case PRESERVE_REPLICA_ORDER :
305+ return newQueryPlanPreserveReplicas (request , session );
306+ case REGULAR :
307+ default :
308+ return newQueryPlanRegular (request , session );
309+ }
310+ }
311+
312+ @ NonNull
313+ protected Queue <Node > newQueryPlanRegular (@ Nullable Request request , @ Nullable Session session ) {
263314 // Take a snapshot since the set is concurrent:
264315 Object [] currentNodes = liveNodes .dc (localDc ).toArray ();
265316
@@ -294,6 +345,101 @@ public Queue<Node> newQueryPlan(@Nullable Request request, @Nullable Session ses
294345 return maybeAddDcFailover (request , plan );
295346 }
296347
348+ /**
349+ * Builds a query plan that preserves replica order: local replicas, remote replicas, local
350+ * non-replicas (rotated), remote non-replicas (rotated).
351+ */
352+ @ NonNull
353+ protected Queue <Node > newQueryPlanPreserveReplicas (
354+ @ Nullable Request request , @ Nullable Session session ) {
355+ List <Node > replicas = getReplicas (request , session );
356+ String localDc = getLocalDatacenter ();
357+ List <Node > queryPlan = new ArrayList <>();
358+
359+ if (localDc == null ) {
360+ // No local DC: all replicas first, then rotated non-replicas
361+ List <Node > allNodes = new ArrayList <>();
362+ for (Object obj : getLiveNodes ().dc (null ).toArray ()) {
363+ allNodes .add ((Node ) obj );
364+ }
365+ queryPlan .addAll (replicas );
366+ addRotatedNonReplicas (queryPlan , allNodes , replicas , request );
367+ } else {
368+ // With local DC: prioritize local, then remote
369+ Map <String , List <Node >> nodesByDc = getAllNodesByDc ();
370+ addReplicasByDc (queryPlan , replicas , localDc );
371+ addNonReplicasByDc (queryPlan , nodesByDc , replicas , localDc , request );
372+ }
373+
374+ return new SimpleQueryPlan (queryPlan .toArray ());
375+ }
376+
377+ /** Collect all live nodes grouped by DC. */
378+ private Map <String , List <Node >> getAllNodesByDc () {
379+ Map <String , List <Node >> nodesByDc = new HashMap <>();
380+ for (String dc : getLiveNodes ().dcs ()) {
381+ List <Node > dcNodes = new ArrayList <>();
382+ for (Object obj : getLiveNodes ().dc (dc ).toArray ()) {
383+ dcNodes .add ((Node ) obj );
384+ }
385+ nodesByDc .put (dc , dcNodes );
386+ }
387+ return nodesByDc ;
388+ }
389+
390+ /** Add replicas with local DC first, then remote DCs. */
391+ private void addReplicasByDc (List <Node > queryPlan , List <Node > replicas , String localDc ) {
392+ replicas .stream ()
393+ .filter (r -> Objects .equals (r .getDatacenter (), localDc ))
394+ .forEach (queryPlan ::add );
395+ replicas .stream ()
396+ .filter (r -> !Objects .equals (r .getDatacenter (), localDc ))
397+ .forEach (queryPlan ::add );
398+ }
399+
400+ /** Add non-replicas with local DC first, then remote DCs (all rotated). */
401+ private void addNonReplicasByDc (
402+ List <Node > queryPlan ,
403+ Map <String , List <Node >> nodesByDc ,
404+ List <Node > replicas ,
405+ String localDc ,
406+ Request request ) {
407+ // Local DC non-replicas first
408+ addRotatedNonReplicas (
409+ queryPlan , nodesByDc .getOrDefault (localDc , new ArrayList <>()), replicas , request );
410+ // Remote DC non-replicas
411+ for (Map .Entry <String , List <Node >> entry : nodesByDc .entrySet ()) {
412+ if (!Objects .equals (entry .getKey (), localDc )) {
413+ addRotatedNonReplicas (queryPlan , entry .getValue (), replicas , request );
414+ }
415+ }
416+ }
417+
418+ /** Add non-replica nodes from given list with rotation. */
419+ private void addRotatedNonReplicas (
420+ List <Node > queryPlan , List <Node > nodes , List <Node > replicas , Request request ) {
421+ List <Node > nonReplicas =
422+ nodes .stream ().filter (n -> !replicas .contains (n )).collect (Collectors .toList ());
423+ if (!nonReplicas .isEmpty ()) {
424+ rotateNonReplicas (nonReplicas , request );
425+ queryPlan .addAll (nonReplicas );
426+ }
427+ }
428+
429+ /** Rotates nodes based on routing key (consistent) or randomly. */
430+ private void rotateNonReplicas (List <Node > nodes , @ Nullable Request request ) {
431+ if (nodes .size () <= 1 ) return ;
432+
433+ int rotationAmount =
434+ (request != null && request .getRoutingKey () != null )
435+ ? Math .abs (request .getRoutingKey ().hashCode ()) % nodes .size ()
436+ : randomNextInt (nodes .size ());
437+
438+ if (rotationAmount > 0 ) {
439+ Collections .rotate (nodes , -rotationAmount );
440+ }
441+ }
442+
297443 @ NonNull
298444 protected List <Node > getReplicas (@ Nullable Request request , @ Nullable Session session ) {
299445 if (request == null || session == null ) {
@@ -441,6 +587,11 @@ protected Object[] computeNodes() {
441587 return new CompositeQueryPlan (queryPlans );
442588 }
443589
590+ /** Exposed as a protected method so that it can be accessed by tests */
591+ protected int randomNextInt (int bound ) {
592+ return ThreadLocalRandom .current ().nextInt (bound );
593+ }
594+
444595 /** Exposed as a protected method so that it can be accessed by tests */
445596 protected void shuffleHead (Object [] currentNodes , int headLength ) {
446597 ArrayUtils .shuffleHead (currentNodes , headLength );
0 commit comments