From 2ad88776722a7ed874a5b1ed20c36070a93c51b5 Mon Sep 17 00:00:00 2001 From: Marius Merkle Date: Fri, 27 Mar 2026 09:13:16 +0100 Subject: [PATCH 1/7] feat: Tolerances for inner lists and arrays --- diffly/_conditions.py | 8 ++---- diffly/comparison.py | 55 +++++++++++++++++++++++++++++------- tests/test_conditions.py | 61 ++++++++++++++++++++++++++++++++++------ 3 files changed, 99 insertions(+), 25 deletions(-) diff --git a/diffly/_conditions.py b/diffly/_conditions.py index 25f76d9..3675648 100644 --- a/diffly/_conditions.py +++ b/diffly/_conditions.py @@ -206,11 +206,7 @@ def _compare_sequence_columns( n_elements = dtype_right.shape[0] has_same_length = col_left.list.len().eq(pl.lit(n_elements)) else: # pl.List vs pl.List - if not isinstance(max_list_length, int): - # Fallback for nested list comparisons where no max_list_length is - # available: perform a direct equality comparison without element-wise - # unrolling. - return _eq_missing(col_left.eq_missing(col_right), col_left, col_right) + assert max_list_length is not None n_elements = max_list_length has_same_length = col_left.list.len().eq_missing(col_right.list.len()) @@ -232,7 +228,7 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex abs_tol=abs_tol, rel_tol=rel_tol, abs_tol_temporal=abs_tol_temporal, - max_list_length=None, + max_list_length=max_list_length, ) for i in range(n_elements) ] diff --git a/diffly/comparison.py b/diffly/comparison.py index ee46e55..e1024aa 100644 --- a/diffly/comparison.py +++ b/diffly/comparison.py @@ -711,22 +711,30 @@ def _validate_subset_of_common_columns(self, subset: Iterable[str]) -> list[str] @cached_property def _max_list_lengths_by_column(self) -> dict[str, int]: - list_columns = [ - col - for col in self._other_common_columns - if isinstance(self.left_schema[col], pl.List) - and isinstance(self.right_schema[col], pl.List) - ] - if not list_columns: + """Max list length across all nesting levels, for columns where either side + contains a List anywhere in its type tree.""" + left_exprs: list[pl.Expr] = [] + right_exprs: list[pl.Expr] = [] + columns: list[str] = [] + + for col in self._other_common_columns: + col_left = _list_length_exprs(pl.col(col), self.left_schema[col]) + col_right = _list_length_exprs(pl.col(col), self.right_schema[col]) + if not col_left and not col_right: + continue + columns.append(col) + left_exprs.append(_max_or_zero(col_left).alias(col)) + right_exprs.append(_max_or_zero(col_right).alias(col)) + + if not columns: return {} - exprs = [pl.col(col).list.len().max().alias(col) for col in list_columns] [left_max, right_max] = pl.collect_all( - [self.left.select(exprs), self.right.select(exprs)] + [self.left.select(left_exprs), self.right.select(right_exprs)] ) return { col: max(int(left_max[col].item() or 0), int(right_max[col].item() or 0)) - for col in list_columns + for col in columns } def _condition_equal_rows(self, columns: list[str]) -> pl.Expr: @@ -833,3 +841,30 @@ def right_only(self) -> Schema: """Columns that are only present in the right data frame, mapped to their data types.""" return self.right() - self.left() + + +def _list_length_exprs( + expr: pl.Expr, dtype: pl.DataType | pl.datatypes.DataTypeClass +) -> list[pl.Expr]: + """Collect max-list-length scalar expressions for every List level in the type + tree.""" + if isinstance(dtype, pl.List): + return [expr.list.len().max(), *_list_length_exprs(expr.explode(), dtype.inner)] + if isinstance(dtype, pl.Array): + return _list_length_exprs(expr.explode(), dtype.inner) + if isinstance(dtype, pl.Struct): + return [ + e + for field in dtype.fields + for e in _list_length_exprs(expr.struct[field.name], field.dtype) + ] + return [] + + +def _max_or_zero(exprs: list[pl.Expr]) -> pl.Expr: + """Return the horizontal max of scalar expressions, or literal 0 if empty.""" + if not exprs: + return pl.lit(0) + if len(exprs) == 1: + return exprs[0] + return pl.max_horizontal(exprs) diff --git a/tests/test_conditions.py b/tests/test_conditions.py index 8aaeedb..1b6c526 100644 --- a/tests/test_conditions.py +++ b/tests/test_conditions.py @@ -102,6 +102,10 @@ def test_condition_equal_columns_list_array_with_tolerance( schema={"pk": pl.Int64, "a_right": rhs_type}, ) + max_list_length: int | None = None + if isinstance(lhs_type, pl.List) or isinstance(rhs_type, pl.List): + max_list_length = 2 + # Act actual = ( lhs.join(rhs, on="pk", maintain_order="left") @@ -112,7 +116,7 @@ def test_condition_equal_columns_list_array_with_tolerance( dtype_right=rhs.schema["a_right"], abs_tol=0.5, rel_tol=0, - max_list_length=2, + max_list_length=max_list_length, ) ) .to_series() @@ -156,6 +160,10 @@ def test_condition_equal_columns_nested_list_array_with_tolerance( schema={"pk": pl.Int64, "a_right": rhs_type}, ) + max_list_length: int | None = None + if isinstance(lhs_type, pl.List) or isinstance(rhs_type, pl.List): + max_list_length = 3 + # Act actual = ( lhs.join(rhs, on="pk", maintain_order="left") @@ -166,16 +174,13 @@ def test_condition_equal_columns_nested_list_array_with_tolerance( dtype_right=rhs.schema["a_right"], abs_tol=0.5, rel_tol=0, - max_list_length=2, + max_list_length=max_list_length, ) ) .to_series() ) - if isinstance(lhs_type, pl.List) and isinstance(rhs_type, pl.List): - assert actual.to_list() == [True, False, False] - else: - assert actual.to_list() == [True, True, False] + assert actual.to_list() == [True, True, False] def test_condition_equal_columns_nested_dtype_mismatch() -> None: @@ -201,7 +206,7 @@ def test_condition_equal_columns_nested_dtype_mismatch() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=None, + max_list_length=2, ) ) .to_series() @@ -341,7 +346,7 @@ def test_condition_equal_columns_array_vs_list_length_mismatch() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=None, + max_list_length=2, abs_tol=0.5, rel_tol=0, ) @@ -406,6 +411,10 @@ def test_condition_equal_columns_empty_list_array( schema={"pk": pl.Int64, "a_right": rhs_type}, ) + max_list_length: int | None = None + if isinstance(lhs_type, pl.List) or isinstance(rhs_type, pl.List): + max_list_length = 0 + actual = ( lhs.join(rhs, on="pk", maintain_order="left") .select( @@ -413,7 +422,7 @@ def test_condition_equal_columns_empty_list_array( "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=None, + max_list_length=max_list_length, ) ) .to_series() @@ -421,6 +430,40 @@ def test_condition_equal_columns_empty_list_array( assert actual.to_list() == [True, True] +def test_condition_equal_columns_lists_only_inner() -> None: + # Arrange + lhs = pl.DataFrame( + { + "pk": [1, 2], + "a_left": [{"x": 1, "y": [1.0, 2.0, 3.0]}, {"x": 2, "y": [4.0, 5.0, 6.0]}], + }, + ) + rhs = pl.DataFrame( + { + "pk": [1, 2], + "a_right": [{"x": 1, "y": [1.0, 2.1, 3.0]}, {"x": 2, "y": [4.0, 5.3, 6.0]}], + }, + ) + + # Act + actual = ( + lhs.join(rhs, on="pk", maintain_order="left") + .select( + condition_equal_columns( + "a", + dtype_left=lhs.schema["a_left"], + dtype_right=rhs.schema["a_right"], + max_list_length=3, + abs_tol=0.2, + ) + ) + .to_series() + ) + + # Assert + assert actual.to_list() == [True, False] + + @pytest.mark.parametrize( ("dtype_left", "dtype_right", "can_compare_dtypes"), [ From ab746e925c8fa107e5a7499c13f3ef404c42f100 Mon Sep 17 00:00:00 2001 From: Marius Merkle Date: Fri, 27 Mar 2026 09:17:08 +0100 Subject: [PATCH 2/7] fix --- diffly/comparison.py | 6 +++--- tests/test_conditions.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/diffly/comparison.py b/diffly/comparison.py index e1024aa..1285f4b 100644 --- a/diffly/comparison.py +++ b/diffly/comparison.py @@ -711,8 +711,8 @@ def _validate_subset_of_common_columns(self, subset: Iterable[str]) -> list[str] @cached_property def _max_list_lengths_by_column(self) -> dict[str, int]: - """Max list length across all nesting levels, for columns where either side - contains a List anywhere in its type tree.""" + """Max list length across all nesting levels, for columns where both sides + contain a List anywhere in their type tree.""" left_exprs: list[pl.Expr] = [] right_exprs: list[pl.Expr] = [] columns: list[str] = [] @@ -720,7 +720,7 @@ def _max_list_lengths_by_column(self) -> dict[str, int]: for col in self._other_common_columns: col_left = _list_length_exprs(pl.col(col), self.left_schema[col]) col_right = _list_length_exprs(pl.col(col), self.right_schema[col]) - if not col_left and not col_right: + if not (col_left and col_right): continue columns.append(col) left_exprs.append(_max_or_zero(col_left).alias(col)) diff --git a/tests/test_conditions.py b/tests/test_conditions.py index 1b6c526..f700261 100644 --- a/tests/test_conditions.py +++ b/tests/test_conditions.py @@ -103,7 +103,7 @@ def test_condition_equal_columns_list_array_with_tolerance( ) max_list_length: int | None = None - if isinstance(lhs_type, pl.List) or isinstance(rhs_type, pl.List): + if isinstance(lhs_type, pl.List) and isinstance(rhs_type, pl.List): max_list_length = 2 # Act @@ -161,7 +161,7 @@ def test_condition_equal_columns_nested_list_array_with_tolerance( ) max_list_length: int | None = None - if isinstance(lhs_type, pl.List) or isinstance(rhs_type, pl.List): + if isinstance(lhs_type, pl.List) and isinstance(rhs_type, pl.List): max_list_length = 3 # Act From abb87092f218d9e537dfd1854d17a7ddb6b3aba9 Mon Sep 17 00:00:00 2001 From: Marius Merkle Date: Fri, 27 Mar 2026 09:35:08 +0100 Subject: [PATCH 3/7] remove _max_or_zero --- diffly/comparison.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/diffly/comparison.py b/diffly/comparison.py index 1285f4b..d4697fb 100644 --- a/diffly/comparison.py +++ b/diffly/comparison.py @@ -723,8 +723,8 @@ def _max_list_lengths_by_column(self) -> dict[str, int]: if not (col_left and col_right): continue columns.append(col) - left_exprs.append(_max_or_zero(col_left).alias(col)) - right_exprs.append(_max_or_zero(col_right).alias(col)) + left_exprs.append(pl.max_horizontal(col_left).alias(col)) + right_exprs.append(pl.max_horizontal(col_right).alias(col)) if not columns: return {} @@ -859,12 +859,3 @@ def _list_length_exprs( for e in _list_length_exprs(expr.struct[field.name], field.dtype) ] return [] - - -def _max_or_zero(exprs: list[pl.Expr]) -> pl.Expr: - """Return the horizontal max of scalar expressions, or literal 0 if empty.""" - if not exprs: - return pl.lit(0) - if len(exprs) == 1: - return exprs[0] - return pl.max_horizontal(exprs) From 8afddd2969f8a779076b7cd434b10dcc17f69e08 Mon Sep 17 00:00:00 2001 From: Marius Merkle Date: Fri, 27 Mar 2026 10:04:11 +0100 Subject: [PATCH 4/7] feedback copilot --- diffly/_conditions.py | 6 +++++- tests/test_conditions.py | 22 ++++++++++++++++++++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/diffly/_conditions.py b/diffly/_conditions.py index 3675648..12675fb 100644 --- a/diffly/_conditions.py +++ b/diffly/_conditions.py @@ -206,7 +206,11 @@ def _compare_sequence_columns( n_elements = dtype_right.shape[0] has_same_length = col_left.list.len().eq(pl.lit(n_elements)) else: # pl.List vs pl.List - assert max_list_length is not None + if max_list_length is None: + raise ValueError( + "max_list_length must be provided for List-vs-List comparisons " + "in _compare_sequence_columns()." + ) n_elements = max_list_length has_same_length = col_left.list.len().eq_missing(col_right.list.len()) diff --git a/tests/test_conditions.py b/tests/test_conditions.py index f700261..c33ad42 100644 --- a/tests/test_conditions.py +++ b/tests/test_conditions.py @@ -435,13 +435,31 @@ def test_condition_equal_columns_lists_only_inner() -> None: lhs = pl.DataFrame( { "pk": [1, 2], - "a_left": [{"x": 1, "y": [1.0, 2.0, 3.0]}, {"x": 2, "y": [4.0, 5.0, 6.0]}], + "a_left": [ + { + "x": 1, + "y": [1.0, 2.0, 3.0], + }, + { + "x": 2, + "y": [4.0, 5.0, 6.0], + }, + ], }, ) rhs = pl.DataFrame( { "pk": [1, 2], - "a_right": [{"x": 1, "y": [1.0, 2.1, 3.0]}, {"x": 2, "y": [4.0, 5.3, 6.0]}], + "a_right": [ + { + "x": 1, + "y": [1.0, 2.1, 3.0], + }, + { + "x": 2, + "y": [4.0, 5.3, 6.0], + }, + ], }, ) From 3d5467abb5f0dde6d7784fdb57bb97abcba961f5 Mon Sep 17 00:00:00 2001 From: Marius Merkle Date: Fri, 27 Mar 2026 10:08:33 +0100 Subject: [PATCH 5/7] fix test coverage --- tests/test_conditions.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/test_conditions.py b/tests/test_conditions.py index c33ad42..0610742 100644 --- a/tests/test_conditions.py +++ b/tests/test_conditions.py @@ -482,6 +482,34 @@ def test_condition_equal_columns_lists_only_inner() -> None: assert actual.to_list() == [True, False] +def test_condition_equal_columns_two_lists_no_max_length() -> None: + lhs = pl.DataFrame( + { + "pk": [1, 2], + "a_left": [[1.0, 2.0], [3.0, 4.0]], + }, + ) + rhs = pl.DataFrame( + { + "pk": [1, 2], + "a_right": [[1.0, 2.0], [3.0, 4.0]], + }, + ) + + with pytest.raises( + ValueError, + match="max_list_length must be provided for List-vs-List comparisons", + ): + lhs.join(rhs, on="pk", maintain_order="left").select( + condition_equal_columns( + "a", + dtype_left=lhs.schema["a_left"], + dtype_right=rhs.schema["a_right"], + max_list_length=None, + ) + ).to_series() + + @pytest.mark.parametrize( ("dtype_left", "dtype_right", "can_compare_dtypes"), [ From 36fc349b94dad1069b6ab82f46b218e159eb530c Mon Sep 17 00:00:00 2001 From: Marius Merkle <122545105+MariusMerkleQC@users.noreply.github.com> Date: Fri, 27 Mar 2026 23:06:43 +0100 Subject: [PATCH 6/7] test: Combine tests with `_max_list_lenghts_by_column` (#23) --- tests/test_conditions.py | 266 ++++++++++++++++++++++----------------- 1 file changed, 148 insertions(+), 118 deletions(-) diff --git a/tests/test_conditions.py b/tests/test_conditions.py index 0610742..d05e363 100644 --- a/tests/test_conditions.py +++ b/tests/test_conditions.py @@ -7,6 +7,7 @@ import pytest from diffly._conditions import _can_compare_dtypes, condition_equal_columns +from diffly.comparison import compare_frames def test_condition_equal_columns_struct() -> None: @@ -14,17 +15,20 @@ def test_condition_equal_columns_struct() -> None: lhs = pl.DataFrame( { "pk": [1, 2], - "a_left": [{"x": 1.0, "y": 2.0}, {"x": 2.0, "y": 2.1}], + "a": [{"x": 1.0, "y": 2.0}, {"x": 2.0, "y": 2.1}], } ) rhs = pl.DataFrame( { "pk": [1, 2], - "a_right": [{"y": 2.0, "x": 1.1}, {"y": 2.7, "x": 2.1}], + "a": [{"y": 2.0, "x": 1.1}, {"y": 2.7, "x": 2.1}], } ) + c = compare_frames(lhs, rhs, primary_key="pk", abs_tol=0.5, rel_tol=0) # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) actual = ( lhs.join(rhs, on="pk", maintain_order="left") .select( @@ -32,15 +36,16 @@ def test_condition_equal_columns_struct() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=None, - abs_tol=0.5, - rel_tol=0, + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], ) ) .to_series() ) # Assert + assert c._max_list_lengths_by_column == {} assert actual.to_list() == [True, False] @@ -49,17 +54,20 @@ def test_condition_equal_columns_different_struct_fields() -> None: lhs = pl.DataFrame( { "pk": [1, 2], - "a_left": [{"x": 1.0, "z": 2.0}, {"x": 2.0, "z": 2.1}], + "a": [{"x": 1.0, "z": 2.0}, {"x": 2.0, "z": 2.1}], } ) rhs = pl.DataFrame( { "pk": [1, 2], - "a_right": [{"y": 2.0, "x": 1.1}, {"y": 2.7, "x": 2.1}], + "a": [{"y": 2.0, "x": 1.1}, {"y": 2.7, "x": 2.1}], } ) + c = compare_frames(lhs, rhs, primary_key="pk") # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) actual = ( lhs.join(rhs, on="pk", maintain_order="left") .select( @@ -67,13 +75,16 @@ def test_condition_equal_columns_different_struct_fields() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=None, + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], ) ) .to_series() ) # Assert + assert c._max_list_lengths_by_column == {} assert actual.to_list() == [False, False] @@ -88,25 +99,18 @@ def test_condition_equal_columns_list_array_with_tolerance( ) -> None: # Arrange lhs = pl.DataFrame( - { - "pk": [1, 2, 3], - "a_left": [[1.0, 1.1], [2.0, 2.1], [3.0, 3.0]], - }, - schema={"pk": pl.Int64, "a_left": lhs_type}, + {"pk": [1, 2, 3], "a": [[1.0, 1.1], [2.0, 2.1], [3.0, 3.0]]}, + schema={"pk": pl.Int64, "a": lhs_type}, ) rhs = pl.DataFrame( - { - "pk": [1, 2, 3], - "a_right": [[1.0, 1.1], [2.0, 2.2], [3.0, 3.7]], - }, - schema={"pk": pl.Int64, "a_right": rhs_type}, + {"pk": [1, 2, 3], "a": [[1.0, 1.1], [2.0, 2.2], [3.0, 3.7]]}, + schema={"pk": pl.Int64, "a": rhs_type}, ) - - max_list_length: int | None = None - if isinstance(lhs_type, pl.List) and isinstance(rhs_type, pl.List): - max_list_length = 2 + c = compare_frames(lhs, rhs, primary_key="pk", abs_tol=0.5, rel_tol=0) # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) actual = ( lhs.join(rhs, on="pk", maintain_order="left") .select( @@ -114,14 +118,19 @@ def test_condition_equal_columns_list_array_with_tolerance( "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - abs_tol=0.5, - rel_tol=0, - max_list_length=max_list_length, + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], ) ) .to_series() ) + # Assert + if isinstance(lhs_type, pl.List) and isinstance(rhs_type, pl.List): + assert c._max_list_lengths_by_column == {"a": 2} + else: + assert c._max_list_lengths_by_column == {} assert actual.to_list() == [True, True, False] @@ -140,31 +149,30 @@ def test_condition_equal_columns_nested_list_array_with_tolerance( lhs = pl.DataFrame( { "pk": [1, 2, 3], - "a_left": [ + "a": [ [[1.0, 1.1, 1.3], [2.0, 2.1, 2.2]], [[3.0, 3.0, 3.1], [4.0, 4.0, 4.1]], [[5.0, 5.0, 5.1], [6.0, 6.0, 6.1]], ], }, - schema={"pk": pl.Int64, "a_left": lhs_type}, + schema={"pk": pl.Int64, "a": lhs_type}, ) rhs = pl.DataFrame( { "pk": [1, 2, 3], - "a_right": [ + "a": [ [[1.0, 1.1, 1.3], [2.0, 2.1, 2.2]], [[3.0, 3.0, 3.1], [4.0, 4.4, 4.1]], [[5.0, 5.0, 5.1], [6.0, 6.8, 6.1]], ], }, - schema={"pk": pl.Int64, "a_right": rhs_type}, + schema={"pk": pl.Int64, "a": rhs_type}, ) - - max_list_length: int | None = None - if isinstance(lhs_type, pl.List) and isinstance(rhs_type, pl.List): - max_list_length = 3 + c = compare_frames(lhs, rhs, primary_key="pk", abs_tol=0.5, rel_tol=0) # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) actual = ( lhs.join(rhs, on="pk", maintain_order="left") .select( @@ -172,33 +180,31 @@ def test_condition_equal_columns_nested_list_array_with_tolerance( "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - abs_tol=0.5, - rel_tol=0, - max_list_length=max_list_length, + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], ) ) .to_series() ) + # Assert + if isinstance(lhs_type, pl.List) and isinstance(rhs_type, pl.List): + assert c._max_list_lengths_by_column == {"a": 3} + else: + assert c._max_list_lengths_by_column == {} assert actual.to_list() == [True, True, False] def test_condition_equal_columns_nested_dtype_mismatch() -> None: # Arrange - lhs = pl.DataFrame( - { - "pk": [1, 2], - "a_left": [{"x": 1}, {"x": 2}], - }, - ) - rhs = pl.DataFrame( - { - "pk": [1, 2], - "a_right": [[1.0, 1.1], [2.0, 2.2]], - }, - ) + lhs = pl.DataFrame({"pk": [1, 2], "a": [{"x": 1}, {"x": 2}]}) + rhs = pl.DataFrame({"pk": [1, 2], "a": [[1.0, 1.1], [2.0, 2.2]]}) + c = compare_frames(lhs, rhs, primary_key="pk") # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) actual = ( lhs.join(rhs, on="pk", maintain_order="left") .select( @@ -206,32 +212,28 @@ def test_condition_equal_columns_nested_dtype_mismatch() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=2, + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], ) ) .to_series() ) # Assert + assert c._max_list_lengths_by_column == {} assert actual.to_list() == [False, False] def test_condition_equal_columns_exactly_one_nested() -> None: # Arrange - lhs = pl.DataFrame( - { - "pk": [1, 2], - "a_left": [{"x": 1}, {"x": 2}], - }, - ) - rhs = pl.DataFrame( - { - "pk": [1, 2], - "a_right": [1, 2], - }, - ) + lhs = pl.DataFrame({"pk": [1, 2], "a": [{"x": 1}, {"x": 2}]}) + rhs = pl.DataFrame({"pk": [1, 2], "a": [1, 2]}) + c = compare_frames(lhs, rhs, primary_key="pk") # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) actual = ( lhs.join(rhs, on="pk", maintain_order="left") .select( @@ -239,13 +241,16 @@ def test_condition_equal_columns_exactly_one_nested() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=None, + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], ) ) .to_series() ) # Assert + assert c._max_list_lengths_by_column == {} assert actual.to_list() == [False, False] @@ -254,7 +259,7 @@ def test_condition_equal_columns_temporal_tolerance() -> None: lhs = pl.DataFrame( { "pk": [1, 2, 3, 4], - "a_left": [ + "a": [ dt.datetime(2025, 1, 1, 9, 0, 0), dt.datetime(2025, 1, 1, 10, 0, 0), None, @@ -265,7 +270,7 @@ def test_condition_equal_columns_temporal_tolerance() -> None: rhs = pl.DataFrame( { "pk": [1, 2, 3, 4], - "a_right": [ + "a": [ dt.datetime(2025, 1, 1, 9, 0, 1), dt.datetime(2025, 1, 1, 10, 0, 5), dt.datetime(2025, 1, 1, 10, 0, 0), @@ -273,8 +278,13 @@ def test_condition_equal_columns_temporal_tolerance() -> None: ], }, ) + c = compare_frames( + lhs, rhs, primary_key="pk", abs_tol_temporal=dt.timedelta(seconds=2) + ) # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) actual = ( lhs.join(rhs, on="pk", maintain_order="left") .select( @@ -282,31 +292,36 @@ def test_condition_equal_columns_temporal_tolerance() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=None, - abs_tol_temporal=dt.timedelta(seconds=2), + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], + abs_tol_temporal=c.abs_tol_temporal_by_column["a"], ) ) .to_series() ) # Assert + assert c._max_list_lengths_by_column == {} assert actual.to_list() == [True, False, False, True] def test_condition_equal_columns_two_lists() -> None: + # Arrange lhs = pl.DataFrame( - { - "pk": [1, 2, 3, 4, 5], - "a_left": [[1.0, 2.0], [3.0], [5.0, None], None, None], - }, + {"pk": [1, 2, 3, 4, 5], "a": [[1.0, 2.0], [3.0], [5.0, None], None, None]}, ) rhs = pl.DataFrame( { "pk": [1, 2, 3, 4, 5], - "a_right": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0], None], + "a": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0], None], }, ) + c = compare_frames(lhs, rhs, primary_key="pk", abs_tol=0.5, rel_tol=0) + # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) actual = ( lhs.join(rhs, on="pk", maintain_order="left") .select( @@ -314,31 +329,31 @@ def test_condition_equal_columns_two_lists() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - abs_tol=0.5, - rel_tol=0, - max_list_length=2, + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], ) ) .to_series() ) + + # Assert + assert c._max_list_lengths_by_column == {"a": 2} assert actual.to_list() == [True, False, False, False, True] def test_condition_equal_columns_array_vs_list_length_mismatch() -> None: + # Arrange lhs = pl.DataFrame( - { - "pk": [1, 2], - "a_left": [[1.0, 2.0], [3.0, 4.0]], - }, - schema={"pk": pl.Int64, "a_left": pl.Array(pl.Float64, shape=2)}, - ) - rhs = pl.DataFrame( - { - "pk": [1, 2], - "a_right": [[1.0, 2.0], [3.0]], - }, + {"pk": [1, 2], "a": [[1.0, 2.0], [3.0, 4.0]]}, + schema={"pk": pl.Int64, "a": pl.Array(pl.Float64, shape=2)}, ) + rhs = pl.DataFrame({"pk": [1, 2], "a": [[1.0, 2.0], [3.0]]}) + c = compare_frames(lhs, rhs, primary_key="pk", abs_tol=0.5, rel_tol=0) + # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) actual = ( lhs.join(rhs, on="pk", maintain_order="left") .select( @@ -346,32 +361,34 @@ def test_condition_equal_columns_array_vs_list_length_mismatch() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=2, - abs_tol=0.5, - rel_tol=0, + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], ) ) .to_series() ) + + # Assert + assert c._max_list_lengths_by_column == {} assert actual.to_list() == [True, False] def test_condition_equal_columns_two_arrays_different_shapes() -> None: + # Arrange lhs = pl.DataFrame( - { - "pk": [1], - "a_left": [[1.0, 2.0]], - }, - schema={"pk": pl.Int64, "a_left": pl.Array(pl.Float64, shape=2)}, + {"pk": [1], "a": [[1.0, 2.0]]}, + schema={"pk": pl.Int64, "a": pl.Array(pl.Float64, shape=2)}, ) rhs = pl.DataFrame( - { - "pk": [1], - "a_right": [[1.0, 2.0, 3.0]], - }, - schema={"pk": pl.Int64, "a_right": pl.Array(pl.Float64, shape=3)}, + {"pk": [1], "a": [[1.0, 2.0, 3.0]]}, + schema={"pk": pl.Int64, "a": pl.Array(pl.Float64, shape=3)}, ) + c = compare_frames(lhs, rhs, primary_key="pk") + # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) actual = ( lhs.join(rhs, on="pk", maintain_order="left") .select( @@ -379,11 +396,16 @@ def test_condition_equal_columns_two_arrays_different_shapes() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=None, + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], ) ) .to_series() ) + + # Assert + assert c._max_list_lengths_by_column == {} assert actual.to_list() == [False] @@ -396,25 +418,20 @@ def test_condition_equal_columns_two_arrays_different_shapes() -> None: def test_condition_equal_columns_empty_list_array( lhs_type: pl.DataType, rhs_type: pl.DataType ) -> None: + # Arrange lhs = pl.DataFrame( - { - "pk": [1, 2], - "a_left": [[], None], - }, - schema={"pk": pl.Int64, "a_left": lhs_type}, + {"pk": [1, 2], "a": [[], None]}, + schema={"pk": pl.Int64, "a": lhs_type}, ) rhs = pl.DataFrame( - { - "pk": [1, 2], - "a_right": [[], None], - }, - schema={"pk": pl.Int64, "a_right": rhs_type}, + {"pk": [1, 2], "a": [[], None]}, + schema={"pk": pl.Int64, "a": rhs_type}, ) + c = compare_frames(lhs, rhs, primary_key="pk") - max_list_length: int | None = None - if isinstance(lhs_type, pl.List) or isinstance(rhs_type, pl.List): - max_list_length = 0 - + # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) actual = ( lhs.join(rhs, on="pk", maintain_order="left") .select( @@ -422,11 +439,19 @@ def test_condition_equal_columns_empty_list_array( "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=max_list_length, + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], ) ) .to_series() ) + + # Assert + if isinstance(lhs_type, pl.List) and isinstance(rhs_type, pl.List): + assert c._max_list_lengths_by_column == {"a": 0} + else: + assert c._max_list_lengths_by_column == {} assert actual.to_list() == [True, True] @@ -435,7 +460,7 @@ def test_condition_equal_columns_lists_only_inner() -> None: lhs = pl.DataFrame( { "pk": [1, 2], - "a_left": [ + "a": [ { "x": 1, "y": [1.0, 2.0, 3.0], @@ -450,7 +475,7 @@ def test_condition_equal_columns_lists_only_inner() -> None: rhs = pl.DataFrame( { "pk": [1, 2], - "a_right": [ + "a": [ { "x": 1, "y": [1.0, 2.1, 3.0], @@ -462,8 +487,11 @@ def test_condition_equal_columns_lists_only_inner() -> None: ], }, ) + c = compare_frames(lhs, rhs, primary_key="pk", abs_tol=0.2, rel_tol=0) # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) actual = ( lhs.join(rhs, on="pk", maintain_order="left") .select( @@ -471,14 +499,16 @@ def test_condition_equal_columns_lists_only_inner() -> None: "a", dtype_left=lhs.schema["a_left"], dtype_right=rhs.schema["a_right"], - max_list_length=3, - abs_tol=0.2, + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], ) ) .to_series() ) # Assert + assert c._max_list_lengths_by_column == {"a": 3} assert actual.to_list() == [True, False] From d2097c9322aceed5e202df2253500ce578cb68be Mon Sep 17 00:00:00 2001 From: Marius Merkle Date: Fri, 27 Mar 2026 23:08:59 +0100 Subject: [PATCH 7/7] feedback OB --- diffly/_conditions.py | 8 ++------ tests/test_conditions.py | 28 ---------------------------- 2 files changed, 2 insertions(+), 34 deletions(-) diff --git a/diffly/_conditions.py b/diffly/_conditions.py index 12675fb..d6970a6 100644 --- a/diffly/_conditions.py +++ b/diffly/_conditions.py @@ -3,6 +3,7 @@ import datetime as dt from collections.abc import Mapping +from typing import cast import polars as pl from polars.datatypes import DataType, DataTypeClass @@ -206,12 +207,7 @@ def _compare_sequence_columns( n_elements = dtype_right.shape[0] has_same_length = col_left.list.len().eq(pl.lit(n_elements)) else: # pl.List vs pl.List - if max_list_length is None: - raise ValueError( - "max_list_length must be provided for List-vs-List comparisons " - "in _compare_sequence_columns()." - ) - n_elements = max_list_length + n_elements = cast(int, max_list_length) has_same_length = col_left.list.len().eq_missing(col_right.list.len()) if n_elements == 0: diff --git a/tests/test_conditions.py b/tests/test_conditions.py index d05e363..308d2a3 100644 --- a/tests/test_conditions.py +++ b/tests/test_conditions.py @@ -512,34 +512,6 @@ def test_condition_equal_columns_lists_only_inner() -> None: assert actual.to_list() == [True, False] -def test_condition_equal_columns_two_lists_no_max_length() -> None: - lhs = pl.DataFrame( - { - "pk": [1, 2], - "a_left": [[1.0, 2.0], [3.0, 4.0]], - }, - ) - rhs = pl.DataFrame( - { - "pk": [1, 2], - "a_right": [[1.0, 2.0], [3.0, 4.0]], - }, - ) - - with pytest.raises( - ValueError, - match="max_list_length must be provided for List-vs-List comparisons", - ): - lhs.join(rhs, on="pk", maintain_order="left").select( - condition_equal_columns( - "a", - dtype_left=lhs.schema["a_left"], - dtype_right=rhs.schema["a_right"], - max_list_length=None, - ) - ).to_series() - - @pytest.mark.parametrize( ("dtype_left", "dtype_right", "can_compare_dtypes"), [