Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions diffly/_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,11 @@ def _compare_sequence_columns(
n_elements = dtype_right.shape[0]
has_same_length = col_left.list.len().eq(pl.lit(n_elements))
else: # pl.List vs pl.List
if not isinstance(max_list_length, int):
# Fallback for nested list comparisons where no max_list_length is
# available: perform a direct equality comparison without element-wise
# unrolling.
return _eq_missing(col_left.eq_missing(col_right), col_left, col_right)
if max_list_length is None:
raise ValueError(
"max_list_length must be provided for List-vs-List comparisons "
"in _compare_sequence_columns()."
)
n_elements = max_list_length
Comment thread
MariusMerkleQC marked this conversation as resolved.
Outdated
has_same_length = col_left.list.len().eq_missing(col_right.list.len())

Expand All @@ -232,7 +232,7 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex
abs_tol=abs_tol,
rel_tol=rel_tol,
abs_tol_temporal=abs_tol_temporal,
max_list_length=None,
max_list_length=max_list_length,
)
for i in range(n_elements)
]
Expand Down
46 changes: 36 additions & 10 deletions diffly/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,22 +711,30 @@ def _validate_subset_of_common_columns(self, subset: Iterable[str]) -> list[str]

@cached_property
def _max_list_lengths_by_column(self) -> dict[str, int]:
list_columns = [
col
for col in self._other_common_columns
if isinstance(self.left_schema[col], pl.List)
and isinstance(self.right_schema[col], pl.List)
]
if not list_columns:
"""Max list length across all nesting levels, for columns where both sides
contain a List anywhere in their type tree."""
left_exprs: list[pl.Expr] = []
right_exprs: list[pl.Expr] = []
columns: list[str] = []

for col in self._other_common_columns:
col_left = _list_length_exprs(pl.col(col), self.left_schema[col])
col_right = _list_length_exprs(pl.col(col), self.right_schema[col])
if not (col_left and col_right):
continue
columns.append(col)
left_exprs.append(pl.max_horizontal(col_left).alias(col))
right_exprs.append(pl.max_horizontal(col_right).alias(col))

if not columns:
return {}

exprs = [pl.col(col).list.len().max().alias(col) for col in list_columns]
[left_max, right_max] = pl.collect_all(
[self.left.select(exprs), self.right.select(exprs)]
[self.left.select(left_exprs), self.right.select(right_exprs)]
)
return {
col: max(int(left_max[col].item() or 0), int(right_max[col].item() or 0))
for col in list_columns
for col in columns
}

def _condition_equal_rows(self, columns: list[str]) -> pl.Expr:
Expand Down Expand Up @@ -833,3 +841,21 @@ def right_only(self) -> Schema:
"""Columns that are only present in the right data frame, mapped to their data
types."""
return self.right() - self.left()


def _list_length_exprs(
    expr: pl.Expr, dtype: pl.DataType | pl.datatypes.DataTypeClass
) -> list[pl.Expr]:
    """Gather a max-list-length expression for each List level inside ``dtype``.

    The type tree is walked recursively, starting from ``expr``:

    * a ``List`` contributes ``expr.list.len().max()`` and then recurses into its
      inner dtype on the exploded expression,
    * an ``Array`` contributes nothing itself and only recurses into its inner
      dtype on the exploded expression,
    * a ``Struct`` recurses into each entry of ``dtype.fields`` in order,
    * any other dtype yields no expressions.

    Returns:
        The collected expressions in the order the levels are visited; empty
        when the type tree contains no ``List``.
    """
    if isinstance(dtype, pl.Struct):
        collected: list[pl.Expr] = []
        for field in dtype.fields:
            collected.extend(
                _list_length_exprs(expr.struct[field.name], field.dtype)
            )
        return collected

    if not isinstance(dtype, (pl.List, pl.Array)):
        # Leaf dtype: nothing to measure.
        return []

    # Both List and Array descend into their element type via explode().
    nested = _list_length_exprs(expr.explode(), dtype.inner)
    if isinstance(dtype, pl.List):
        # Only List levels have a variable length worth recording.
        return [expr.list.len().max(), *nested]
    return nested
107 changes: 98 additions & 9 deletions tests/test_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ def test_condition_equal_columns_list_array_with_tolerance(
schema={"pk": pl.Int64, "a_right": rhs_type},
)

max_list_length: int | None = None
Comment thread
MariusMerkleQC marked this conversation as resolved.
Outdated
if isinstance(lhs_type, pl.List) and isinstance(rhs_type, pl.List):
max_list_length = 2

# Act
actual = (
lhs.join(rhs, on="pk", maintain_order="left")
Expand All @@ -112,7 +116,7 @@ def test_condition_equal_columns_list_array_with_tolerance(
dtype_right=rhs.schema["a_right"],
abs_tol=0.5,
rel_tol=0,
max_list_length=2,
max_list_length=max_list_length,
)
)
.to_series()
Expand Down Expand Up @@ -156,6 +160,10 @@ def test_condition_equal_columns_nested_list_array_with_tolerance(
schema={"pk": pl.Int64, "a_right": rhs_type},
)

max_list_length: int | None = None
if isinstance(lhs_type, pl.List) and isinstance(rhs_type, pl.List):
max_list_length = 3

# Act
actual = (
lhs.join(rhs, on="pk", maintain_order="left")
Expand All @@ -166,16 +174,13 @@ def test_condition_equal_columns_nested_list_array_with_tolerance(
dtype_right=rhs.schema["a_right"],
abs_tol=0.5,
rel_tol=0,
max_list_length=2,
max_list_length=max_list_length,
)
)
.to_series()
)

if isinstance(lhs_type, pl.List) and isinstance(rhs_type, pl.List):
assert actual.to_list() == [True, False, False]
else:
assert actual.to_list() == [True, True, False]
assert actual.to_list() == [True, True, False]
Comment thread
MariusMerkleQC marked this conversation as resolved.


def test_condition_equal_columns_nested_dtype_mismatch() -> None:
Expand All @@ -201,7 +206,7 @@ def test_condition_equal_columns_nested_dtype_mismatch() -> None:
"a",
dtype_left=lhs.schema["a_left"],
dtype_right=rhs.schema["a_right"],
max_list_length=None,
max_list_length=2,
)
)
.to_series()
Expand Down Expand Up @@ -341,7 +346,7 @@ def test_condition_equal_columns_array_vs_list_length_mismatch() -> None:
"a",
dtype_left=lhs.schema["a_left"],
dtype_right=rhs.schema["a_right"],
max_list_length=None,
max_list_length=2,
abs_tol=0.5,
rel_tol=0,
)
Expand Down Expand Up @@ -406,21 +411,105 @@ def test_condition_equal_columns_empty_list_array(
schema={"pk": pl.Int64, "a_right": rhs_type},
)

max_list_length: int | None = None
if isinstance(lhs_type, pl.List) or isinstance(rhs_type, pl.List):
max_list_length = 0

actual = (
lhs.join(rhs, on="pk", maintain_order="left")
.select(
condition_equal_columns(
"a",
dtype_left=lhs.schema["a_left"],
dtype_right=rhs.schema["a_right"],
max_list_length=None,
max_list_length=max_list_length,
)
)
.to_series()
)
assert actual.to_list() == [True, True]


def test_condition_equal_columns_lists_only_inner() -> None:
    # Arrange: a struct column whose only list is the nested field "y". The two
    # sides differ solely in the middle element of "y": by 0.1 for pk=1 and by
    # 0.3 for pk=2.
    left_rows = [
        {"x": 1, "y": [1.0, 2.0, 3.0]},
        {"x": 2, "y": [4.0, 5.0, 6.0]},
    ]
    right_rows = [
        {"x": 1, "y": [1.0, 2.1, 3.0]},
        {"x": 2, "y": [4.0, 5.3, 6.0]},
    ]
    lhs = pl.DataFrame({"pk": [1, 2], "a_left": left_rows})
    rhs = pl.DataFrame({"pk": [1, 2], "a_right": right_rows})

    # Act
    joined = lhs.join(rhs, on="pk", maintain_order="left")
    condition = condition_equal_columns(
        "a",
        dtype_left=lhs.schema["a_left"],
        dtype_right=rhs.schema["a_right"],
        max_list_length=3,
        abs_tol=0.2,
    )
    actual = joined.select(condition).to_series()

    # Assert: with abs_tol=0.2 only the first row is expected to compare equal.
    assert actual.to_list() == [True, False]


def test_condition_equal_columns_two_lists_no_max_length() -> None:
    # Arrange: identical List columns on both sides, so the only thing under
    # test is the missing ``max_list_length``.
    lhs = pl.DataFrame({"pk": [1, 2], "a_left": [[1.0, 2.0], [3.0, 4.0]]})
    rhs = pl.DataFrame({"pk": [1, 2], "a_right": [[1.0, 2.0], [3.0, 4.0]]})
    expected_message = "max_list_length must be provided for List-vs-List comparisons"

    # Act / Assert: passing ``max_list_length=None`` must be rejected.
    with pytest.raises(ValueError, match=expected_message):
        joined = lhs.join(rhs, on="pk", maintain_order="left")
        condition = condition_equal_columns(
            "a",
            dtype_left=lhs.schema["a_left"],
            dtype_right=rhs.schema["a_right"],
            max_list_length=None,
        )
        joined.select(condition).to_series()


@pytest.mark.parametrize(
("dtype_left", "dtype_right", "can_compare_dtypes"),
[
Expand Down
Loading