Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions diffly/_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import datetime as dt
from collections.abc import Mapping
from typing import cast

import polars as pl
from polars.datatypes import DataType, DataTypeClass
Expand Down Expand Up @@ -206,12 +207,7 @@ def _compare_sequence_columns(
n_elements = dtype_right.shape[0]
has_same_length = col_left.list.len().eq(pl.lit(n_elements))
else: # pl.List vs pl.List
if not isinstance(max_list_length, int):
# Fallback for nested list comparisons where no max_list_length is
# available: perform a direct equality comparison without element-wise
# unrolling.
return _eq_missing(col_left.eq_missing(col_right), col_left, col_right)
n_elements = max_list_length
n_elements = cast(int, max_list_length)
has_same_length = col_left.list.len().eq_missing(col_right.list.len())

if n_elements == 0:
Expand All @@ -232,7 +228,7 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex
abs_tol=abs_tol,
rel_tol=rel_tol,
abs_tol_temporal=abs_tol_temporal,
max_list_length=None,
max_list_length=max_list_length,
)
for i in range(n_elements)
]
Expand Down
46 changes: 36 additions & 10 deletions diffly/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,22 +711,30 @@ def _validate_subset_of_common_columns(self, subset: Iterable[str]) -> list[str]

@cached_property
def _max_list_lengths_by_column(self) -> dict[str, int]:
list_columns = [
col
for col in self._other_common_columns
if isinstance(self.left_schema[col], pl.List)
and isinstance(self.right_schema[col], pl.List)
]
if not list_columns:
"""Max list length across all nesting levels, for columns where both sides
contain a List anywhere in their type tree."""
left_exprs: list[pl.Expr] = []
right_exprs: list[pl.Expr] = []
columns: list[str] = []

for col in self._other_common_columns:
col_left = _list_length_exprs(pl.col(col), self.left_schema[col])
col_right = _list_length_exprs(pl.col(col), self.right_schema[col])
if not (col_left and col_right):
continue
columns.append(col)
left_exprs.append(pl.max_horizontal(col_left).alias(col))
right_exprs.append(pl.max_horizontal(col_right).alias(col))

if not columns:
return {}

exprs = [pl.col(col).list.len().max().alias(col) for col in list_columns]
[left_max, right_max] = pl.collect_all(
[self.left.select(exprs), self.right.select(exprs)]
[self.left.select(left_exprs), self.right.select(right_exprs)]
)
return {
col: max(int(left_max[col].item() or 0), int(right_max[col].item() or 0))
for col in list_columns
for col in columns
}

def _condition_equal_rows(self, columns: list[str]) -> pl.Expr:
Expand Down Expand Up @@ -833,3 +841,21 @@ def right_only(self) -> Schema:
"""Columns that are only present in the right data frame, mapped to their data
types."""
return self.right() - self.left()


def _list_length_exprs(
    expr: pl.Expr, dtype: pl.DataType | pl.datatypes.DataTypeClass
) -> list[pl.Expr]:
    """Gather one max-length expression per ``List`` level found inside ``dtype``.

    The type tree is walked recursively:

    * ``List``: contributes ``expr.list.len().max()`` and then descends into the
      inner dtype using the exploded expression.
    * ``Array``: contributes nothing itself, but descends into the inner dtype
      using the exploded expression.
    * ``Struct``: descends into every field, in field order.
    * anything else: contributes nothing.

    Args:
        expr: Expression addressing the values that have type ``dtype``.
        dtype: Data type to inspect.

    Returns:
        The collected expressions, outermost level first; empty when ``dtype``
        contains no ``List`` at any depth.
    """
    if isinstance(dtype, pl.Struct):
        gathered: list[pl.Expr] = []
        for member in dtype.fields:
            gathered.extend(
                _list_length_exprs(expr.struct[member.name], member.dtype)
            )
        return gathered

    if not isinstance(dtype, (pl.List, pl.Array)):
        # Leaf (non-nested) type: there is no list level to measure.
        return []

    # Both sequence types descend into their element type via explode().
    inner_lengths = _list_length_exprs(expr.explode(), dtype.inner)
    if isinstance(dtype, pl.Array):
        return inner_lengths

    # Only List levels report a length of their own, ahead of nested levels.
    own_length = expr.list.len().max()
    return [own_length, *inner_lengths]
Loading
Loading