30 changes: 30 additions & 0 deletions native/shuffle/src/partitioners/multi_partition.rs
@@ -203,6 +203,36 @@ impl MultiPartitionShuffleRepartitioner
return Ok(());
}

// For zero-column schemas (e.g. COUNT queries), assign all rows to partition 0.
// No hashing or expression evaluation needed — just route through normal buffering.
if input.num_columns() == 0 {
let num_rows = input.num_rows();
self.metrics.baseline.record_output(num_rows);
// All rows go to partition 0: partition_starts = [0, num_rows, num_rows, ...]
// partition_row_indices = [0, 1, 2, ..., num_rows-1]
let mut scratch = std::mem::take(&mut self.scratch);
scratch
    .partition_starts
    .resize(self.partition_indices.len() + 1, 0);
scratch.partition_starts.fill(num_rows as u32);
scratch.partition_starts[0] = 0;
scratch.partition_row_indices.resize(num_rows, 0);
for (i, v) in scratch.partition_row_indices[..num_rows]
    .iter_mut()
    .enumerate()
{
    *v = i as u32;
}
self.buffer_partitioned_batch_may_spill(
    input,
    &scratch.partition_row_indices[..num_rows],
    &scratch.partition_starts,
)
.await?;
self.scratch = scratch;
return Ok(());
}

Contributor:

This still looks much more complicated than what I would expect. Why do we need scratch space and to write num_rows partition_row_indices? Why are we "partitioning" rows that don't exist?

Contributor Author:

Just trying CI to see whether the single-partition approach breaks anything.

Contributor Author:

It's fine; I shortened the PR. The shuffle steps for count batches are:

  • partitioning_batch sees num_columns() == 0, buffers the batch, and pushes all row indices into partition_indices[0], skipping hashing
  • The IPC stream encodes the schema (no fields) and a single record batch message carrying just the row count in the metadata

if input.num_rows() > self.batch_size {
return Err(DataFusionError::Internal(
"Input batch size exceeds configured batch size. Call `insert_batch` instead."
27 changes: 16 additions & 11 deletions native/shuffle/src/partitioners/partitioned_batch_iterator.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

-use arrow::array::RecordBatch;
+use arrow::array::{RecordBatch, RecordBatchOptions};
use arrow::compute::interleave_record_batch;
use datafusion::common::DataFusionError;

@@ -97,15 +97,20 @@ impl Iterator for PartitionedBatchIterator<'_> {

let indices_end = std::cmp::min(self.pos + self.batch_size, self.indices.len());
let indices = &self.indices[self.pos..indices_end];
-match interleave_record_batch(&self.record_batches, indices) {
-    Ok(batch) => {
-        self.pos = indices_end;
-        Some(Ok(batch))
-    }
-    Err(e) => Some(Err(DataFusionError::ArrowError(
-        Box::from(e),
-        Some(DataFusionError::get_back_trace()),
-    ))),
-}
+
+// interleave_record_batch requires at least one column or an explicit row count.
+// For zero-column batches (e.g. COUNT queries), create the batch directly.
+let schema = self.record_batches[0].schema();
+let result = if schema.fields().is_empty() {
+    let options = RecordBatchOptions::new().with_row_count(Some(indices.len()));
+    RecordBatch::try_new_with_options(schema, vec![], &options)
+} else {
+    interleave_record_batch(&self.record_batches, indices)
+};
+
+self.pos = indices_end;
+Some(result.map_err(|e| {
+    DataFusionError::ArrowError(Box::from(e), Some(DataFusionError::get_back_trace()))
+}))
}
}
@@ -474,4 +474,34 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper
}
}
}

test("native datafusion scan - repartition count") {
withTempPath { dir =>
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
spark
.range(1000)
.selectExpr("id", "concat('name_', id) as name")
.repartition(100)
.write
.parquet(dir.toString)
}
withSQLConf(
CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_DATAFUSION,
CometConf.COMET_EXEC_SHUFFLE_WITH_ROUND_ROBIN_PARTITIONING_ENABLED.key -> "true") {
Comment on lines +488 to +490

Member @andygrove, Mar 31, 2026:

Is the issue specific to this combination of scan and shuffle?

interleave_record_batch is used in other parts of the shuffle codebase so those may also need updating?

Member:

It looks like native_datafusion is used here just to easily force native shuffle.

I am confused by the comment "For zero-column batches (e.g. COUNT queries)" when the test isn't using a count.

Contributor Author:

I was able to reproduce the crash with both native_datafusion and native_iceberg_compat in combination with native shuffle. The sample query for the repro and the test case is:

spark.read.parquet("hdfs://location").repartition(50).count()

Perhaps the test can be slightly improved if it is confusing.

val testDF = spark.read.parquet(dir.toString).repartition(10)
// Verify CometShuffleExchangeExec is in the plan
assert(
find(testDF.queryExecution.executedPlan) {
case _: CometShuffleExchangeExec => true
case _ => false
}.isDefined,
"Expected CometShuffleExchangeExec in the plan")
// Actual validation, no crash
val count = testDF.count()
assert(count == 1000)
// Ensure test df evaluated by Comet
checkSparkAnswerAndOperator(testDF)
Member:

There is no usage of count() here. Is this intentional?
Another way could be something like:

val testDF = spark.read.parquet(dir.toString).repartition(10)
val countDF = testDF.selectExpr("count(*) as cnt")
val count = countDF.collect().head.getLong(0)
assert(count == 1000)
checkSparkAnswerAndOperator(countDF)

Contributor Author:

It is intentional, yes. count() returns just a Long, so I can't really inject anything in the middle to check the native plan. Instead I check that everything before the count runs natively, which works for this case.

}
}
}
}