diff --git a/AGENTS.md b/AGENTS.md index f6fdfbd90..86c2e9c3b 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -42,3 +42,15 @@ Every Python function must include a docstring with usage examples. - **Alias functions**: Functions that are simple aliases (e.g., `list_sort` aliasing `array_sort`) only need a one-line description and a `See Also` reference to the primary function. They do not need their own examples. + +## Aggregate and Window Function Documentation + +When adding or updating an aggregate or window function, ensure the corresponding +site documentation is kept in sync: + +- **Aggregations**: `docs/source/user-guide/common-operations/aggregations.rst` — + add new aggregate functions to the "Aggregate Functions" list and include usage + examples if appropriate. +- **Window functions**: `docs/source/user-guide/common-operations/windows.rst` — + add new window functions to the "Available Functions" list and include usage + examples if appropriate. diff --git a/crates/core/src/expr/grouping_set.rs b/crates/core/src/expr/grouping_set.rs index 549a866ed..11d8f4fcd 100644 --- a/crates/core/src/expr/grouping_set.rs +++ b/crates/core/src/expr/grouping_set.rs @@ -15,9 +15,11 @@ // specific language governing permissions and limitations // under the License. 
-use datafusion::logical_expr::GroupingSet; +use datafusion::logical_expr::{Expr, GroupingSet}; use pyo3::prelude::*; +use crate::expr::PyExpr; + #[pyclass( from_py_object, frozen, @@ -30,6 +32,39 @@ pub struct PyGroupingSet { grouping_set: GroupingSet, } +#[pymethods] +impl PyGroupingSet { + #[staticmethod] + #[pyo3(signature = (*exprs))] + fn rollup(exprs: Vec) -> PyExpr { + Expr::GroupingSet(GroupingSet::Rollup( + exprs.into_iter().map(|e| e.expr).collect(), + )) + .into() + } + + #[staticmethod] + #[pyo3(signature = (*exprs))] + fn cube(exprs: Vec) -> PyExpr { + Expr::GroupingSet(GroupingSet::Cube( + exprs.into_iter().map(|e| e.expr).collect(), + )) + .into() + } + + #[staticmethod] + #[pyo3(signature = (*expr_lists))] + fn grouping_sets(expr_lists: Vec>) -> PyExpr { + Expr::GroupingSet(GroupingSet::GroupingSets( + expr_lists + .into_iter() + .map(|list| list.into_iter().map(|e| e.expr).collect()) + .collect(), + )) + .into() + } +} + impl From for GroupingSet { fn from(grouping_set: PyGroupingSet) -> Self { grouping_set.grouping_set diff --git a/crates/core/src/functions.rs b/crates/core/src/functions.rs index 74654ce46..f173aaa51 100644 --- a/crates/core/src/functions.rs +++ b/crates/core/src/functions.rs @@ -791,9 +791,10 @@ aggregate_function!(var_pop); aggregate_function!(approx_distinct); aggregate_function!(approx_median); -// Code is commented out since grouping is not yet implemented -// https://github.com/apache/datafusion-python/issues/861 -// aggregate_function!(grouping); +// The grouping function's physical plan is not implemented, but the +// ResolveGroupingFunction analyzer rule rewrites it before the physical +// planner sees it, so it works correctly at runtime. 
+aggregate_function!(grouping); #[pyfunction] #[pyo3(signature = (sort_expression, percentile, num_centroids=None, filter=None))] @@ -831,6 +832,19 @@ pub fn approx_percentile_cont_with_weight( add_builder_fns_to_aggregate(agg_fn, None, filter, None, None) } +#[pyfunction] +#[pyo3(signature = (sort_expression, percentile, filter=None))] +pub fn percentile_cont( + sort_expression: PySortExpr, + percentile: f64, + filter: Option, +) -> PyDataFusionResult { + let agg_fn = + functions_aggregate::expr_fn::percentile_cont(sort_expression.sort, lit(percentile)); + + add_builder_fns_to_aggregate(agg_fn, None, filter, None, None) +} + // We handle last_value explicitly because the signature expects an order_by // https://github.com/apache/datafusion/issues/12376 #[pyfunction] @@ -1031,6 +1045,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(approx_median))?; m.add_wrapped(wrap_pyfunction!(approx_percentile_cont))?; m.add_wrapped(wrap_pyfunction!(approx_percentile_cont_with_weight))?; + m.add_wrapped(wrap_pyfunction!(percentile_cont))?; m.add_wrapped(wrap_pyfunction!(range))?; m.add_wrapped(wrap_pyfunction!(array_agg))?; m.add_wrapped(wrap_pyfunction!(arrow_typeof))?; @@ -1080,7 +1095,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(from_unixtime))?; m.add_wrapped(wrap_pyfunction!(gcd))?; m.add_wrapped(wrap_pyfunction!(greatest))?; - // m.add_wrapped(wrap_pyfunction!(grouping))?; + m.add_wrapped(wrap_pyfunction!(grouping))?; m.add_wrapped(wrap_pyfunction!(in_list))?; m.add_wrapped(wrap_pyfunction!(initcap))?; m.add_wrapped(wrap_pyfunction!(isnan))?; diff --git a/docs/source/user-guide/common-operations/aggregations.rst b/docs/source/user-guide/common-operations/aggregations.rst index e458e5fcb..de24a2ba5 100644 --- a/docs/source/user-guide/common-operations/aggregations.rst +++ b/docs/source/user-guide/common-operations/aggregations.rst @@ -163,6 +163,168 @@ 
Suppose we want to find the speed values for only Pokemon that have low Attack v f.avg(col_speed, filter=col_attack < lit(50)).alias("Avg Speed Low Attack")]) +Grouping Sets +------------- + +The default style of aggregation produces one row per group. Sometimes you want a single query to +produce rows at multiple levels of detail — for example, totals per type *and* an overall grand +total, or subtotals for every combination of two columns plus the individual column totals. Writing +separate queries and concatenating them is tedious and runs the data multiple times. Grouping sets +solve this by letting you specify several grouping levels in one pass. + +DataFusion supports three grouping set styles through the +:py:class:`~datafusion.expr.GroupingSet` class: + +- :py:meth:`~datafusion.expr.GroupingSet.rollup` — hierarchical subtotals, like a drill-down report +- :py:meth:`~datafusion.expr.GroupingSet.cube` — every possible subtotal combination, like a pivot table +- :py:meth:`~datafusion.expr.GroupingSet.grouping_sets` — explicitly list exactly which grouping levels you want + +Because result rows come from different grouping levels, a column that is *not* part of a +particular level will be ``null`` in that row. Use :py:func:`~datafusion.functions.grouping` to +distinguish a real ``null`` in the data from one that means "this column was aggregated across." +It returns ``0`` when the column is a grouping key for that row, and ``1`` when it is not. + +Rollup +^^^^^^ + +:py:meth:`~datafusion.expr.GroupingSet.rollup` creates a hierarchy. ``rollup(a, b)`` produces +grouping sets ``(a, b)``, ``(a)``, and ``()`` — like nested subtotals in a report. This is useful +when your columns have a natural hierarchy, such as region → city or type → subtype. + +Suppose we want to summarize Pokemon stats by ``Type 1`` with subtotals and a grand total. With +the default aggregation style we would need two separate queries. With ``rollup`` we get it all at +once: + +.. 
ipython:: python + + from datafusion.expr import GroupingSet + + df.aggregate( + [GroupingSet.rollup(col_type_1)], + [f.count(col_speed).alias("Count"), + f.avg(col_speed).alias("Avg Speed"), + f.max(col_speed).alias("Max Speed")] + ).sort(col_type_1.sort(ascending=True, nulls_first=True)) + +The first row — where ``Type 1`` is ``null`` — is the grand total across all types. But how do you +tell a grand-total ``null`` apart from a Pokemon that genuinely has no type? The +:py:func:`~datafusion.functions.grouping` function returns ``0`` when the column is a grouping key +for that row and ``1`` when it is aggregated across. + +.. note:: + + Due to an upstream DataFusion limitation + (`apache/datafusion#21411 `_), + ``.alias()`` cannot be applied directly to a ``grouping()`` expression — it will raise an + error at execution time. Instead, use + :py:meth:`~datafusion.dataframe.DataFrame.with_column_renamed` on the result DataFrame to + give the column a readable name. Once the upstream issue is resolved, you will be able to + use ``.alias()`` directly and the workaround below will no longer be necessary. + +The raw column name generated by ``grouping()`` contains internal identifiers, so we use +:py:meth:`~datafusion.dataframe.DataFrame.with_column_renamed` to clean it up: + +.. ipython:: python + + result = df.aggregate( + [GroupingSet.rollup(col_type_1)], + [f.count(col_speed).alias("Count"), + f.avg(col_speed).alias("Avg Speed"), + f.grouping(col_type_1)] + ) + for field in result.schema(): + if field.name.startswith("grouping("): + result = result.with_column_renamed(field.name, "Is Total") + result.sort(col_type_1.sort(ascending=True, nulls_first=True)) + +With two columns the hierarchy becomes more apparent. ``rollup(Type 1, Type 2)`` produces: + +- one row per ``(Type 1, Type 2)`` pair — the most detailed level +- one row per ``Type 1`` — subtotals +- one grand total row + +.. 
ipython:: python + + df.aggregate( + [GroupingSet.rollup(col_type_1, col_type_2)], + [f.count(col_speed).alias("Count"), + f.avg(col_speed).alias("Avg Speed")] + ).sort( + col_type_1.sort(ascending=True, nulls_first=True), + col_type_2.sort(ascending=True, nulls_first=True) + ) + +Cube +^^^^ + +:py:meth:`~datafusion.expr.GroupingSet.cube` produces every possible subset. ``cube(a, b)`` +produces grouping sets ``(a, b)``, ``(a)``, ``(b)``, and ``()`` — one more than ``rollup`` because +it also includes ``(b)`` alone. This is useful when neither column is "above" the other in a +hierarchy and you want all cross-tabulations. + +For our Pokemon data, ``cube(Type 1, Type 2)`` gives us stats broken down by the type pair, +by ``Type 1`` alone, by ``Type 2`` alone, and a grand total — all in one query: + +.. ipython:: python + + df.aggregate( + [GroupingSet.cube(col_type_1, col_type_2)], + [f.count(col_speed).alias("Count"), + f.avg(col_speed).alias("Avg Speed")] + ).sort( + col_type_1.sort(ascending=True, nulls_first=True), + col_type_2.sort(ascending=True, nulls_first=True) + ) + +Compared to the ``rollup`` example above, notice the extra rows where ``Type 1`` is ``null`` but +``Type 2`` has a value — those are the per-``Type 2`` subtotals that ``rollup`` does not include. + +Explicit Grouping Sets +^^^^^^^^^^^^^^^^^^^^^^ + +:py:meth:`~datafusion.expr.GroupingSet.grouping_sets` lets you list exactly which grouping levels +you need when ``rollup`` or ``cube`` would produce too many or too few. Each argument is a list of +columns forming one grouping set. + +For example, if we want only the per-``Type 1`` totals and per-``Type 2`` totals — but *not* the +full ``(Type 1, Type 2)`` detail rows or the grand total — we can ask for exactly that: + +.. 
ipython:: python + + df.aggregate( + [GroupingSet.grouping_sets([col_type_1], [col_type_2])], + [f.count(col_speed).alias("Count"), + f.avg(col_speed).alias("Avg Speed")] + ).sort( + col_type_1.sort(ascending=True, nulls_first=True), + col_type_2.sort(ascending=True, nulls_first=True) + ) + +Each row belongs to exactly one grouping level. The :py:func:`~datafusion.functions.grouping` +function tells you which level each row comes from: + +.. ipython:: python + + result = df.aggregate( + [GroupingSet.grouping_sets([col_type_1], [col_type_2])], + [f.count(col_speed).alias("Count"), + f.avg(col_speed).alias("Avg Speed"), + f.grouping(col_type_1), + f.grouping(col_type_2)] + ) + for field in result.schema(): + if field.name.startswith("grouping("): + clean = field.name.split(".")[-1].rstrip(")") + result = result.with_column_renamed(field.name, f"grouping({clean})") + result.sort( + col_type_1.sort(ascending=True, nulls_first=True), + col_type_2.sort(ascending=True, nulls_first=True) + ) + +Where ``grouping(Type 1)`` is ``0`` the row is a per-``Type 1`` total (and ``Type 2`` is ``null``). +Where ``grouping(Type 2)`` is ``0`` the row is a per-``Type 2`` total (and ``Type 1`` is ``null``). + + Aggregate Functions ------------------- @@ -192,6 +354,7 @@ The available aggregate functions are: - :py:func:`datafusion.functions.stddev_pop` - :py:func:`datafusion.functions.var_samp` - :py:func:`datafusion.functions.var_pop` + - :py:func:`datafusion.functions.var_population` 6. Linear Regression Functions - :py:func:`datafusion.functions.regr_count` - :py:func:`datafusion.functions.regr_slope` @@ -208,9 +371,16 @@ The available aggregate functions are: - :py:func:`datafusion.functions.nth_value` 8. String Functions - :py:func:`datafusion.functions.string_agg` -9. Approximation Functions +9. 
Percentile Functions + - :py:func:`datafusion.functions.percentile_cont` + - :py:func:`datafusion.functions.quantile_cont` - :py:func:`datafusion.functions.approx_distinct` - :py:func:`datafusion.functions.approx_median` - :py:func:`datafusion.functions.approx_percentile_cont` - :py:func:`datafusion.functions.approx_percentile_cont_with_weight` +10. Grouping Set Functions + - :py:func:`datafusion.functions.grouping` + - :py:meth:`datafusion.expr.GroupingSet.rollup` + - :py:meth:`datafusion.expr.GroupingSet.cube` + - :py:meth:`datafusion.expr.GroupingSet.grouping_sets` diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index 10e2a913f..9907eae8b 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -633,8 +633,22 @@ def aggregate( ) -> DataFrame: """Aggregates the rows of the current DataFrame. + By default each unique combination of the ``group_by`` columns + produces one row. To get multiple levels of subtotals in a + single pass, pass a + :py:class:`~datafusion.expr.GroupingSet` expression + (created via + :py:meth:`~datafusion.expr.GroupingSet.rollup`, + :py:meth:`~datafusion.expr.GroupingSet.cube`, or + :py:meth:`~datafusion.expr.GroupingSet.grouping_sets`) + as the ``group_by`` argument. See the + :ref:`aggregation` user guide for detailed examples. + Args: - group_by: Sequence of expressions or column names to group by. + group_by: Sequence of expressions or column names to group + by. A :py:class:`~datafusion.expr.GroupingSet` + expression may be included to produce multiple grouping + levels (rollup, cube, or explicit grouping sets). aggs: Sequence of expressions to aggregate. 
Returns: diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 14753a4f5..35388468c 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -91,7 +91,6 @@ Extension = expr_internal.Extension FileType = expr_internal.FileType Filter = expr_internal.Filter -GroupingSet = expr_internal.GroupingSet Join = expr_internal.Join ILike = expr_internal.ILike InList = expr_internal.InList @@ -1430,3 +1429,129 @@ def __repr__(self) -> str: SortKey = Expr | SortExpr | str + + +class GroupingSet: + """Factory for creating grouping set expressions. + + Grouping sets control how + :py:meth:`~datafusion.dataframe.DataFrame.aggregate` groups rows. + Instead of a single ``GROUP BY``, they produce multiple grouping + levels in one pass — subtotals, cross-tabulations, or arbitrary + column subsets. + + Use :py:func:`~datafusion.functions.grouping` in the aggregate list + to tell which columns are aggregated across in each result row. + """ + + @staticmethod + def rollup(*exprs: Expr | str) -> Expr: + """Create a ``ROLLUP`` grouping set for use with ``aggregate()``. + + ``ROLLUP`` generates all prefixes of the given column list as + grouping sets. For example, ``rollup(a, b)`` produces grouping + sets ``(a, b)``, ``(a)``, and ``()`` (grand total). + + This is equivalent to ``GROUP BY ROLLUP(a, b)`` in SQL. + + Args: + *exprs: Column expressions or column name strings to + include in the rollup. + + Examples: + >>> from datafusion.expr import GroupingSet + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]}) + >>> result = df.aggregate( + ... [GroupingSet.rollup(dfn.col("a"))], + ... [dfn.functions.sum(dfn.col("b")).alias("s"), + ... dfn.functions.grouping(dfn.col("a"))], + ... 
).sort(dfn.col("a").sort(nulls_first=False)) + >>> result.collect_column("s").to_pylist() + [30, 30, 60] + + See Also: + :py:meth:`cube`, :py:meth:`grouping_sets`, + :py:func:`~datafusion.functions.grouping` + """ + args = [_to_raw_expr(e) for e in exprs] + return Expr(expr_internal.GroupingSet.rollup(*args)) + + @staticmethod + def cube(*exprs: Expr | str) -> Expr: + """Create a ``CUBE`` grouping set for use with ``aggregate()``. + + ``CUBE`` generates all possible subsets of the given column list + as grouping sets. For example, ``cube(a, b)`` produces grouping + sets ``(a, b)``, ``(a)``, ``(b)``, and ``()`` (grand total). + + This is equivalent to ``GROUP BY CUBE(a, b)`` in SQL. + + Args: + *exprs: Column expressions or column name strings to + include in the cube. + + Examples: + With a single column, ``cube`` behaves identically to + :py:meth:`rollup`: + + >>> from datafusion.expr import GroupingSet + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]}) + >>> result = df.aggregate( + ... [GroupingSet.cube(dfn.col("a"))], + ... [dfn.functions.sum(dfn.col("b")).alias("s"), + ... dfn.functions.grouping(dfn.col("a"))], + ... ).sort(dfn.col("a").sort(nulls_first=False)) + >>> result.collect_column("s").to_pylist() + [30, 30, 60] + + See Also: + :py:meth:`rollup`, :py:meth:`grouping_sets`, + :py:func:`~datafusion.functions.grouping` + """ + args = [_to_raw_expr(e) for e in exprs] + return Expr(expr_internal.GroupingSet.cube(*args)) + + @staticmethod + def grouping_sets(*expr_lists: list[Expr | str]) -> Expr: + """Create explicit grouping sets for use with ``aggregate()``. + + Each argument is a list of column expressions or column name + strings representing one grouping set. For example, + ``grouping_sets([a], [b])`` groups by ``a`` alone and by ``b`` + alone in a single query. + + This is equivalent to ``GROUP BY GROUPING SETS ((a), (b))`` in + SQL. 
+ + Args: + *expr_lists: Each positional argument is a list of + expressions or column name strings forming one + grouping set. + + Examples: + >>> from datafusion.expr import GroupingSet + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict( + ... {"a": ["x", "x", "y"], "b": ["m", "n", "m"], + ... "c": [1, 2, 3]}) + >>> result = df.aggregate( + ... [GroupingSet.grouping_sets( + ... [dfn.col("a")], [dfn.col("b")])], + ... [dfn.functions.sum(dfn.col("c")).alias("s"), + ... dfn.functions.grouping(dfn.col("a")), + ... dfn.functions.grouping(dfn.col("b"))], + ... ).sort( + ... dfn.col("a").sort(nulls_first=False), + ... dfn.col("b").sort(nulls_first=False), + ... ) + >>> result.collect_column("s").to_pylist() + [3, 3, 4, 2] + + See Also: + :py:meth:`rollup`, :py:meth:`cube`, + :py:func:`~datafusion.functions.grouping` + """ + raw_lists = [[_to_raw_expr(e) for e in lst] for lst in expr_lists] + return Expr(expr_internal.GroupingSet.grouping_sets(*raw_lists)) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index aa7f28746..9dfabb62d 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -166,6 +166,7 @@ "generate_series", "get_field", "greatest", + "grouping", "ifnull", "in_list", "initcap", @@ -256,9 +257,11 @@ "order_by", "overlay", "percent_rank", + "percentile_cont", "pi", "pow", "power", + "quantile_cont", "radians", "random", "range", @@ -331,6 +334,7 @@ "uuid", "var", "var_pop", + "var_population", "var_samp", "var_sample", "version", @@ -2654,7 +2658,6 @@ def arrow_cast(expr: Expr, data_type: Expr | str | pa.DataType) -> Expr: >>> result.collect_column("c")[0].as_py() 1.0 - >>> import pyarrow as pa >>> result = df.select( ... dfn.functions.arrow_cast( ... dfn.col("a"), data_type=pa.float64() @@ -2677,7 +2680,6 @@ def arrow_metadata(expr: Expr, key: Expr | str | None = None) -> Expr: If called with two arguments, returns the value for the specified metadata key. 
Examples: - >>> import pyarrow as pa >>> field = pa.field("val", pa.int64(), metadata={"k": "v"}) >>> schema = pa.schema([field]) >>> batch = pa.RecordBatch.from_arrays([pa.array([1])], schema=schema) @@ -2746,7 +2748,6 @@ def union_extract(union_expr: Expr, field_name: Expr | str) -> Expr: variant, otherwise returns NULL. Examples: - >>> import pyarrow as pa >>> ctx = dfn.SessionContext() >>> types = pa.array([0, 1, 0], type=pa.int8()) >>> offsets = pa.array([0, 0, 1], type=pa.int32()) @@ -2771,7 +2772,6 @@ def union_tag(union_expr: Expr) -> Expr: """Returns the tag (active field name) of a union type. Examples: - >>> import pyarrow as pa >>> ctx = dfn.SessionContext() >>> types = pa.array([0, 1, 0], type=pa.int8()) >>> offsets = pa.array([0, 0, 1], type=pa.int32()) @@ -4306,6 +4306,60 @@ def approx_percentile_cont_with_weight( ) +def percentile_cont( + sort_expression: Expr | SortExpr, + percentile: float, + filter: Expr | None = None, +) -> Expr: + """Computes the exact percentile of input values using continuous interpolation. + + Unlike :py:func:`approx_percentile_cont`, this function computes the exact + percentile value rather than an approximation. + + If using the builder functions described in :ref:`aggregation` this function ignores + the options ``order_by``, ``null_treatment``, and ``distinct``. + + Args: + sort_expression: Values for which to find the percentile + percentile: This must be between 0.0 and 1.0, inclusive + filter: If provided, only compute against rows for which the filter is True + + Examples: + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1.0, 2.0, 3.0, 4.0, 5.0]}) + >>> result = df.aggregate( + ... [], [dfn.functions.percentile_cont( + ... dfn.col("a"), 0.5 + ... ).alias("v")]) + >>> result.collect_column("v")[0].as_py() + 3.0 + + >>> result = df.aggregate( + ... [], [dfn.functions.percentile_cont( + ... dfn.col("a"), 0.5, + ... filter=dfn.col("a") > dfn.lit(1.0), + ... 
).alias("v")]) + >>> result.collect_column("v")[0].as_py() + 3.5 + """ + sort_expr_raw = sort_or_default(sort_expression) + filter_raw = filter.expr if filter is not None else None + return Expr(f.percentile_cont(sort_expr_raw, percentile, filter=filter_raw)) + + +def quantile_cont( + sort_expression: Expr | SortExpr, + percentile: float, + filter: Expr | None = None, +) -> Expr: + """Computes the exact percentile of input values using continuous interpolation. + + See Also: + This is an alias for :py:func:`percentile_cont`. + """ + return percentile_cont(sort_expression, percentile, filter) + + def array_agg( expression: Expr, distinct: bool = False, @@ -4364,6 +4418,65 @@ def array_agg( ) +def grouping( + expression: Expr, + distinct: bool = False, + filter: Expr | None = None, +) -> Expr: + """Indicates whether a column is aggregated across in the current row. + + Returns 0 when the column is part of the grouping key for that row + (i.e., the row contains per-group results for that column). Returns 1 + when the column is *not* part of the grouping key (i.e., the row's + aggregate spans all values of that column). + + This function is meaningful with + :py:meth:`GroupingSet.rollup `, + :py:meth:`GroupingSet.cube `, or + :py:meth:`GroupingSet.grouping_sets `, + where different rows are grouped by different subsets of columns. In a + default aggregation without grouping sets every column is always part + of the key, so ``grouping()`` always returns 0. + + .. warning:: + + Due to an upstream DataFusion limitation + (`#21411 `_), + ``.alias()`` cannot be applied directly to a ``grouping()`` + expression. Doing so will raise an error at execution time. To + rename the column, use + :py:meth:`~datafusion.dataframe.DataFrame.with_column_renamed` + on the result DataFrame instead. 
+ + Args: + expression: The column to check grouping status for + distinct: If True, compute on distinct values only + filter: If provided, only compute against rows for which the filter is True + + Examples: + With :py:meth:`~datafusion.expr.GroupingSet.rollup`, the result + includes both per-group rows (``grouping(a) = 0``) and a + grand-total row where ``a`` is aggregated across + (``grouping(a) = 1``): + + >>> from datafusion.expr import GroupingSet + >>> ctx = dfn.SessionContext() + >>> df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]}) + >>> result = df.aggregate( + ... [GroupingSet.rollup(dfn.col("a"))], + ... [dfn.functions.sum(dfn.col("b")).alias("s"), + ... dfn.functions.grouping(dfn.col("a"))], + ... ).sort(dfn.col("a").sort(nulls_first=False)) + >>> result.collect_column("s").to_pylist() + [30, 30, 60] + + See Also: + :py:class:`~datafusion.expr.GroupingSet` + """ + filter_raw = filter.expr if filter is not None else None + return Expr(f.grouping(expression.expr, distinct=distinct, filter=filter_raw)) + + def avg( expression: Expr, filter: Expr | None = None, @@ -4835,6 +4948,15 @@ def var_pop(expression: Expr, filter: Expr | None = None) -> Expr: return Expr(f.var_pop(expression.expr, filter=filter_raw)) +def var_population(expression: Expr, filter: Expr | None = None) -> Expr: + """Computes the population variance of the argument. + + See Also: + This is an alias for :py:func:`var_pop`. + """ + return var_pop(expression, filter) + + def var_samp(expression: Expr, filter: Expr | None = None) -> Expr: """Computes the sample variance of the argument. 
diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 4e99fa9e3..11e94af1c 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -22,6 +22,7 @@ import pytest from datafusion import SessionContext, column, literal from datafusion import functions as f +from datafusion.expr import GroupingSet np.seterr(invalid="ignore") @@ -1820,6 +1821,114 @@ def test_conditional_functions(df_with_nulls, expr, expected): assert result.column(0) == expected +@pytest.mark.parametrize( + ("func", "filter_expr", "expected"), + [ + (f.percentile_cont, None, 3.0), + (f.percentile_cont, column("a") > literal(1.0), 3.5), + (f.quantile_cont, None, 3.0), + ], + ids=["no_filter", "with_filter", "quantile_cont_alias"], +) +def test_percentile_cont(func, filter_expr, expected): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1.0, 2.0, 3.0, 4.0, 5.0]}) + result = df.aggregate( + [], [func(column("a"), 0.5, filter=filter_expr).alias("v")] + ).collect()[0] + assert result.column(0)[0].as_py() == expected + + +@pytest.mark.parametrize( + ("grouping_set_expr", "expected_grouping", "expected_sums"), + [ + (GroupingSet.rollup(column("a")), [0, 0, 1], [30, 30, 60]), + (GroupingSet.cube(column("a")), [0, 0, 1], [30, 30, 60]), + (GroupingSet.rollup("a"), [0, 0, 1], [30, 30, 60]), + (GroupingSet.cube("a"), [0, 0, 1], [30, 30, 60]), + ], + ids=["rollup", "cube", "rollup_str", "cube_str"], +) +def test_grouping_set_single_column( + grouping_set_expr, expected_grouping, expected_sums +): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]}) + result = df.aggregate( + [grouping_set_expr], + [f.sum(column("b")).alias("s"), f.grouping(column("a"))], + ).sort(column("a").sort(ascending=True, nulls_first=False)) + batches = result.collect() + g = pa.concat_arrays([b.column(2) for b in batches]).to_pylist() + s = pa.concat_arrays([b.column("s") for b in batches]).to_pylist() + assert g == expected_grouping + assert s == 
expected_sums + + +@pytest.mark.parametrize( + ("grouping_set_expr", "expected_rows"), + [ + # rollup(a, b) => (a,b), (a), () => 3 + 2 + 1 = 6 + (GroupingSet.rollup(column("a"), column("b")), 6), + # cube(a, b) => (a,b), (a), (b), () => 3 + 2 + 2 + 1 = 8 + (GroupingSet.cube(column("a"), column("b")), 8), + (GroupingSet.rollup("a", "b"), 6), + (GroupingSet.cube("a", "b"), 8), + ], + ids=["rollup", "cube", "rollup_str", "cube_str"], +) +def test_grouping_set_multi_column(grouping_set_expr, expected_rows): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1, 1, 2], "b": ["x", "y", "x"], "c": [10, 20, 30]}) + result = df.aggregate( + [grouping_set_expr], + [f.sum(column("c")).alias("s")], + ) + total_rows = sum(b.num_rows for b in result.collect()) + assert total_rows == expected_rows + + +@pytest.mark.parametrize( + "grouping_set_expr", + [ + GroupingSet.grouping_sets([column("a")], [column("b")]), + GroupingSet.grouping_sets(["a"], ["b"]), + ], + ids=["expr", "str"], +) +def test_grouping_sets_explicit(grouping_set_expr): + # Each row's grouping() value tells you which columns are aggregated across. 
+ ctx = SessionContext() + df = ctx.from_pydict({"a": ["x", "x", "y"], "b": ["m", "n", "m"], "c": [1, 2, 3]}) + result = df.aggregate( + [grouping_set_expr], + [ + f.sum(column("c")).alias("s"), + f.grouping(column("a")), + f.grouping(column("b")), + ], + ).sort( + column("a").sort(ascending=True, nulls_first=False), + column("b").sort(ascending=True, nulls_first=False), + ) + batches = result.collect() + ga = pa.concat_arrays([b.column(3) for b in batches]).to_pylist() + gb = pa.concat_arrays([b.column(4) for b in batches]).to_pylist() + # Rows grouped by (a): ga=0 (a is a key), gb=1 (b is aggregated across) + # Rows grouped by (b): ga=1 (a is aggregated across), gb=0 (b is a key) + assert ga == [0, 0, 1, 1] + assert gb == [1, 1, 0, 0] + + +def test_var_population(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [-1.0, 0.0, 2.0]}) + result = df.aggregate([], [f.var_population(column("a")).alias("v")]).collect()[0] + # var_population is an alias for var_pop + expected = df.aggregate([], [f.var_pop(column("a")).alias("v")]).collect()[0] + assert abs(result.column(0)[0].as_py() - expected.column(0)[0].as_py()) < 1e-10 + + def test_get_field(df): df = df.with_column( "s",