Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions crates/core/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -631,8 +631,29 @@ expr_fn_vec!(named_struct);
expr_fn!(from_unixtime, unixtime);
expr_fn!(arrow_typeof, arg_1);
expr_fn!(arrow_cast, arg_1 datatype);
expr_fn_vec!(arrow_metadata);
expr_fn!(union_tag, arg1);
expr_fn!(random);

/// Calls `functions::core::get_field` with `expr` and `name` (in that order)
/// as its two arguments and wraps the resulting expression as a `PyExpr`.
#[pyfunction]
fn get_field(expr: PyExpr, name: PyExpr) -> PyExpr {
    let func = functions::core::get_field();
    // Both Python-side wrappers are converted before being handed to `call`.
    let call_args = vec![expr.into(), name.into()];
    func.call(call_args).into()
}

/// Calls `functions::core::union_extract` with `union_expr` and `field_name`
/// (in that order) and wraps the resulting expression as a `PyExpr`.
#[pyfunction]
fn union_extract(union_expr: PyExpr, field_name: PyExpr) -> PyExpr {
    let func = functions::core::union_extract();
    // Both Python-side wrappers are converted before being handed to `call`.
    let call_args = vec![union_expr.into(), field_name.into()];
    func.call(call_args).into()
}

/// Calls `functions::core::version` with an empty argument list and wraps
/// the resulting expression as a `PyExpr`.
#[pyfunction]
fn version() -> PyExpr {
    let func = functions::core::version();
    // The function takes no arguments, so an empty vector is passed.
    let no_args = Vec::new();
    func.call(no_args).into()
}

// Array Functions
array_fn!(array_append, array element);
array_fn!(array_to_string, array delimiter);
Expand Down Expand Up @@ -940,6 +961,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(array_agg))?;
m.add_wrapped(wrap_pyfunction!(arrow_typeof))?;
m.add_wrapped(wrap_pyfunction!(arrow_cast))?;
m.add_wrapped(wrap_pyfunction!(arrow_metadata))?;
m.add_wrapped(wrap_pyfunction!(ascii))?;
m.add_wrapped(wrap_pyfunction!(asin))?;
m.add_wrapped(wrap_pyfunction!(asinh))?;
Expand Down Expand Up @@ -1063,6 +1085,10 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(trim))?;
m.add_wrapped(wrap_pyfunction!(trunc))?;
m.add_wrapped(wrap_pyfunction!(upper))?;
m.add_wrapped(wrap_pyfunction!(get_field))?;
m.add_wrapped(wrap_pyfunction!(union_extract))?;
m.add_wrapped(wrap_pyfunction!(union_tag))?;
m.add_wrapped(wrap_pyfunction!(version))?;
m.add_wrapped(wrap_pyfunction!(self::uuid))?; // Use self to avoid name collision
m.add_wrapped(wrap_pyfunction!(var_pop))?;
m.add_wrapped(wrap_pyfunction!(var_sample))?;
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,10 @@ extend-allowed-calls = ["datafusion.lit", "lit"]

[tool.codespell]
skip = [
"./python/tests/test_functions.py",
"./target",
"python/tests/test_functions.py",
"target",
"uv.lock",
"./examples/tpch/answers_sf1/*",
"examples/tpch/answers_sf1/*",
]
count = true
ignore-words-list = ["IST", "ans"]
Expand Down
86 changes: 86 additions & 0 deletions python/datafusion/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
"array_to_string",
"array_union",
"arrow_cast",
"arrow_metadata",
"arrow_typeof",
"ascii",
"asin",
Expand Down Expand Up @@ -149,6 +150,7 @@
"floor",
"from_unixtime",
"gcd",
"get_field",
"in_list",
"initcap",
"isnan",
Expand Down Expand Up @@ -242,6 +244,7 @@
"reverse",
"right",
"round",
"row",
"row_number",
"rpad",
"rtrim",
Expand Down Expand Up @@ -282,12 +285,15 @@
"translate",
"trim",
"trunc",
"union_extract",
"union_tag",
"upper",
"uuid",
"var",
"var_pop",
"var_samp",
"var_sample",
"version",
"when",
# Window Functions
"window",
Expand Down Expand Up @@ -2492,6 +2498,86 @@ def arrow_cast(expr: Expr, data_type: Expr) -> Expr:
return Expr(f.arrow_cast(expr.expr, data_type.expr))


def arrow_metadata(*args: Expr) -> Expr:
    """Return the metadata of the input expression.

    With one argument the result is a Map of all metadata key-value pairs.
    With two arguments the result is the value stored under the given
    metadata key.

    Args:
        args: An expression, optionally followed by a metadata key string.

    Returns:
        A Map of metadata or a specific metadata value.
    """
    # Unwrap each Python-level Expr to the inner expression object before
    # forwarding everything to the internal function.
    raw_exprs = (arg.expr for arg in args)
    return Expr(f.arrow_metadata(*raw_exprs))


def get_field(expr: Expr, name: Expr) -> Expr:
    """Extract a named field from a struct or map expression.

    Args:
        expr: A struct or map expression.
        name: The name of the field to extract.

    Returns:
        The value of the named field.
    """
    inner = f.get_field(expr.expr, name.expr)
    return Expr(inner)


def union_extract(union_expr: Expr, field_name: Expr) -> Expr:
    """Extract the value of a named field from a union-typed expression.

    The value of the named field is returned when that field is the
    currently selected variant; otherwise the result is NULL.

    Args:
        union_expr: A union-typed expression.
        field_name: The name of the field to extract.

    Returns:
        The extracted value or NULL.
    """
    inner = f.union_extract(union_expr.expr, field_name.expr)
    return Expr(inner)


def union_tag(union_expr: Expr) -> Expr:
    """Return the tag (active field name) of a union-typed expression.

    Args:
        union_expr: A union-typed expression.

    Returns:
        The name of the currently selected field in the union.
    """
    inner = f.union_tag(union_expr.expr)
    return Expr(inner)


def version() -> Expr:
    """Return the DataFusion version string.

    Returns:
        A string describing the DataFusion version.
    """
    inner = f.version()
    return Expr(inner)


def row(*args: Expr) -> Expr:
    """Build a struct from the given expressions.

    This is an alias for :py:func:`struct`.

    Args:
        args: The expressions to include in the struct.

    Returns:
        A struct expression.
    """
    # Pure delegation: every argument is forwarded unchanged to ``struct``.
    struct_expr = struct(*args)
    return struct_expr


def random() -> Expr:
"""Returns a random value in the range ``0.0 <= x < 1.0``.

Expand Down
70 changes: 70 additions & 0 deletions python/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1435,3 +1435,73 @@ def test_coalesce(df):
assert result.column(0) == pa.array(
["Hello", "fallback", "!"], type=pa.string_view()
)


def test_get_field(df):
    # Pack columns "a" and "b" into a struct column "s" with fields "x" and "y".
    struct_fields = [
        ("x", column("a")),
        ("y", column("b")),
    ]
    with_struct = df.with_column("s", f.named_struct(struct_fields))

    # Read each field back out of the struct by name.
    x_expr = f.get_field(column("s"), string_literal("x")).alias("x_val")
    y_expr = f.get_field(column("s"), string_literal("y")).alias("y_val")
    batch = with_struct.select(x_expr, y_expr).collect()[0]

    assert batch.column(0) == pa.array(["Hello", "World", "!"], type=pa.string_view())
    assert batch.column(1) == pa.array([4, 5, 6])


def test_arrow_metadata(df):
    meta_expr = f.arrow_metadata(column("a")).alias("meta")
    batch = df.select(meta_expr).collect()[0]
    # The metadata column should be returned as a map type (possibly empty)
    actual_type = batch.column(0).type
    assert actual_type == pa.map_(pa.utf8(), pa.utf8())


def test_version():
    ctx = SessionContext()
    # A one-row frame is enough to evaluate the version expression once.
    single_row = ctx.from_pydict({"a": [1]})
    batch = single_row.select(f.version().alias("v")).collect()[0]
    reported = batch.column(0)[0].as_py()
    assert "Apache DataFusion" in reported


def test_row(df):
    row_expr = f.row(column("a"), column("b")).alias("r")
    struct_expr = f.struct(column("a"), column("b")).alias("s")
    batch = df.select(row_expr, struct_expr).collect()[0]
    # row is an alias for struct, so both columns must hold the same output
    assert batch.column(0) == batch.column(1)


def test_union_tag():
    ctx = SessionContext()

    # Dense union: rows 0 and 2 use the "int" child, row 1 uses the "str" child.
    type_ids = pa.array([0, 1, 0], type=pa.int8())
    value_offsets = pa.array([0, 0, 1], type=pa.int32())
    child_arrays = [pa.array([1, 2]), pa.array(["hello"])]
    union_arr = pa.UnionArray.from_dense(
        type_ids, value_offsets, child_arrays, ["int", "str"], [0, 1]
    )
    input_batch = pa.RecordBatch.from_arrays([union_arr], names=["u"])
    df = ctx.create_dataframe([[input_batch]])

    tag_expr = f.union_tag(column("u")).alias("tag")
    batch = df.select(tag_expr).collect()[0]
    assert batch.column(0).to_pylist() == ["int", "str", "int"]


def test_union_extract():
    ctx = SessionContext()

    # Dense union: rows 0 and 2 use the "int" child, row 1 uses the "str" child.
    type_ids = pa.array([0, 1, 0], type=pa.int8())
    value_offsets = pa.array([0, 0, 1], type=pa.int32())
    child_arrays = [pa.array([1, 2]), pa.array(["hello"])]
    union_arr = pa.UnionArray.from_dense(
        type_ids, value_offsets, child_arrays, ["int", "str"], [0, 1]
    )
    input_batch = pa.RecordBatch.from_arrays([union_arr], names=["u"])
    df = ctx.create_dataframe([[input_batch]])

    # Extracting the "int" field yields NULL for the row whose variant is "str".
    extract_expr = f.union_extract(column("u"), string_literal("int")).alias("val")
    batch = df.select(extract_expr).collect()[0]
    assert batch.column(0).to_pylist() == [1, None, 2]
Loading