diff --git a/diffly/_conditions.py b/diffly/_conditions.py index d6970a6..fa77b6b 100644 --- a/diffly/_conditions.py +++ b/diffly/_conditions.py @@ -140,22 +140,22 @@ def _compare_columns( elif isinstance(dtype_left, pl.List | pl.Array) and isinstance( dtype_right, pl.List | pl.Array ): - return _compare_sequence_columns( - col_left=col_left, - col_right=col_right, - dtype_left=dtype_left, - dtype_right=dtype_right, - max_list_length=max_list_length, - abs_tol=abs_tol, - rel_tol=rel_tol, - abs_tol_temporal=abs_tol_temporal, - ) - - if ( - isinstance(dtype_left, pl.Enum) - and isinstance(dtype_right, pl.Enum) - and dtype_left != dtype_right - ) or _enum_and_categorical(dtype_left, dtype_right): + if _needs_element_wise_comparison(dtype_left.inner, dtype_right.inner): + return _compare_sequence_columns( + col_left=col_left, + col_right=col_right, + dtype_left=dtype_left, + dtype_right=dtype_right, + max_list_length=max_list_length, + abs_tol=abs_tol, + rel_tol=rel_tol, + abs_tol_temporal=abs_tol_temporal, + ) + return col_left.eq_missing(col_right) + + if _different_enums(dtype_left, dtype_right) or _enum_and_categorical( + dtype_left, dtype_right + ): # Enums with different categories as well as enums and categoricals # can't be compared directly. # Fall back to comparison of strings. 
@@ -237,6 +237,55 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex return _eq_missing(has_same_length & elements_match, col_left, col_right) +def _is_float_numeric_pair( + dtype_left: DataType | DataTypeClass, + dtype_right: DataType | DataTypeClass, +) -> bool: + return (dtype_left.is_float() or dtype_right.is_float()) and ( + dtype_left.is_numeric() and dtype_right.is_numeric() + ) + + +def _is_temporal_pair( + dtype_left: DataType | DataTypeClass, + dtype_right: DataType | DataTypeClass, +) -> bool: + return dtype_left.is_temporal() and dtype_right.is_temporal() + + +def _needs_element_wise_comparison( + dtype_left: DataType | DataTypeClass, + dtype_right: DataType | DataTypeClass, +) -> bool: + """Check if two dtypes require element-wise comparison (tolerances or special + handling). + + Returns False when eq_missing() on the whole column would produce identical results, + allowing us to skip the expensive element-wise iteration for list/array columns. 
+ """ + if _is_float_numeric_pair(dtype_left, dtype_right): + return True + if _is_temporal_pair(dtype_left, dtype_right): + return True + if _different_enums(dtype_left, dtype_right) or _enum_and_categorical( + dtype_left, dtype_right + ): + return True + if isinstance(dtype_left, pl.Struct) and isinstance(dtype_right, pl.Struct): + fields_left = {f.name: f.dtype for f in dtype_left.fields} + fields_right = {f.name: f.dtype for f in dtype_right.fields} + return any( + _needs_element_wise_comparison(fields_left[name], fields_right[name]) + for name in fields_left + if name in fields_right + ) + if isinstance(dtype_left, pl.List | pl.Array) and isinstance( + dtype_right, pl.List | pl.Array + ): + return _needs_element_wise_comparison(dtype_left.inner, dtype_right.inner) + return False + + def _compare_primitive_columns( col_left: pl.Expr, col_right: pl.Expr, @@ -246,13 +295,11 @@ def _compare_primitive_columns( rel_tol: float, abs_tol_temporal: dt.timedelta, ) -> pl.Expr: - if (dtype_left.is_float() or dtype_right.is_float()) and ( - dtype_left.is_numeric() and dtype_right.is_numeric() - ): + if _is_float_numeric_pair(dtype_left, dtype_right): return col_left.is_close(col_right, abs_tol=abs_tol, rel_tol=rel_tol).pipe( _eq_missing_with_nan, lhs=col_left, rhs=col_right ) - elif dtype_left.is_temporal() and dtype_right.is_temporal(): + elif _is_temporal_pair(dtype_left, dtype_right): diff_less_than_tolerance = (col_left - col_right).abs() <= abs_tol_temporal return diff_less_than_tolerance.pipe(_eq_missing, lhs=col_left, rhs=col_right) @@ -270,6 +317,12 @@ def _eq_missing_with_nan(expr: pl.Expr, lhs: pl.Expr, rhs: pl.Expr) -> pl.Expr: return _eq_missing(expr, lhs, rhs) | both_nan +def _different_enums( + left: DataType | DataTypeClass, right: DataType | DataTypeClass +) -> bool: + return isinstance(left, pl.Enum) and isinstance(right, pl.Enum) and left != right + + def _enum_and_categorical( left: DataType | DataTypeClass, right: DataType | DataTypeClass ) -> bool: 
diff --git a/tests/test_conditions.py b/tests/test_conditions.py index 308d2a3..5895be8 100644 --- a/tests/test_conditions.py +++ b/tests/test_conditions.py @@ -6,7 +6,11 @@ import polars as pl import pytest -from diffly._conditions import _can_compare_dtypes, condition_equal_columns +from diffly._conditions import ( + _can_compare_dtypes, + _needs_element_wise_comparison, + condition_equal_columns, +) from diffly.comparison import compare_frames @@ -512,6 +516,45 @@ def test_condition_equal_columns_lists_only_inner() -> None: assert actual.to_list() == [True, False] +def test_condition_equal_columns_list_of_different_enums() -> None: + # Arrange + first_enum = pl.Enum(["one", "two"]) + second_enum = pl.Enum(["one", "two", "three"]) + + lhs = pl.DataFrame( + {"pk": [1, 2], "a": [["one", "two"], ["one", "one"]]}, + schema_overrides={"a": pl.List(first_enum)}, + ) + rhs = pl.DataFrame( + {"pk": [1, 2], "a": [["one", "two"], ["one", "three"]]}, + schema_overrides={"a": pl.List(second_enum)}, + ) + c = compare_frames(lhs, rhs, primary_key="pk") + + # Act + lhs = lhs.rename({"a": "a_left"}) + rhs = rhs.rename({"a": "a_right"}) + actual = ( + lhs.join(rhs, on="pk", maintain_order="left") + .select( + condition_equal_columns( + "a", + dtype_left=lhs.schema["a_left"], + dtype_right=rhs.schema["a_right"], + max_list_length=c._max_list_lengths_by_column.get("a"), + abs_tol=c.abs_tol_by_column["a"], + rel_tol=c.rel_tol_by_column["a"], + ) + ) + .to_series() + ) + + # Assert + assert c._max_list_lengths_by_column == {"a": 2} + assert _needs_element_wise_comparison(first_enum, second_enum) + assert actual.to_list() == [True, False] + + @pytest.mark.parametrize( ("dtype_left", "dtype_right", "can_compare_dtypes"), [ @@ -534,3 +577,73 @@ def test_can_compare_dtypes( dtype_left=dtype_left, dtype_right=dtype_right ) assert can_compare_dtypes_actual == can_compare_dtypes + + +@pytest.mark.parametrize( + ("dtype_left", "dtype_right", "expected"), + [ + # Primitives that don't need 
element-wise comparison + (pl.Int64, pl.Int64, False), + (pl.String, pl.String, False), + (pl.Boolean, pl.Boolean, False), + # Float/numeric pairs + (pl.Float64, pl.Float64, True), + (pl.Int64, pl.Float64, True), + (pl.Float32, pl.Int32, True), + # Temporal pairs + (pl.Datetime, pl.Datetime, True), + (pl.Date, pl.Date, True), + (pl.Datetime, pl.Date, True), + # Enum/categorical + (pl.Enum(["a", "b"]), pl.Enum(["a", "b"]), False), + (pl.Enum(["a", "b"]), pl.Enum(["a", "b", "c"]), True), + (pl.Enum(["a"]), pl.Categorical(), True), + (pl.Categorical(), pl.Enum(["a"]), True), + # Struct with no tolerance-requiring fields + ( + pl.Struct({"x": pl.Int64, "y": pl.String}), + pl.Struct({"x": pl.Int64, "y": pl.String}), + False, + ), + # Struct with a float field + ( + pl.Struct({"x": pl.Int64, "y": pl.Float64}), + pl.Struct({"x": pl.Int64, "y": pl.Float64}), + True, + ), + # Struct with different-category enums + ( + pl.Struct({"x": pl.Enum(["a"])}), + pl.Struct({"x": pl.Enum(["b"])}), + True, + ), + # List/Array with non-tolerance inner type + (pl.List(pl.Int64), pl.List(pl.Int64), False), + (pl.Array(pl.String, shape=3), pl.Array(pl.String, shape=3), False), + # List/Array with tolerance-requiring inner type + (pl.List(pl.Float64), pl.List(pl.Float64), True), + (pl.Array(pl.Datetime, shape=2), pl.Array(pl.Datetime, shape=2), True), + # Nested: list of structs with a float field + ( + pl.List(pl.Struct({"x": pl.Float64})), + pl.List(pl.Struct({"x": pl.Float64})), + True, + ), + # Nested: list of structs without tolerance-requiring fields + ( + pl.List(pl.Struct({"x": pl.Int64})), + pl.List(pl.Struct({"x": pl.Int64})), + False, + ), + # Deeply nested: list of structs with a list-of-floats field + ( + pl.List(pl.Struct({"x": pl.String, "y": pl.List(pl.Float64)})), + pl.List(pl.Struct({"x": pl.String, "y": pl.List(pl.Float64)})), + True, + ), + ], +) +def test_needs_element_wise_comparison( + dtype_left: pl.DataType, dtype_right: pl.DataType, expected: bool +) -> None: 
+ assert _needs_element_wise_comparison(dtype_left, dtype_right) == expected diff --git a/tests/test_performance.py b/tests/test_performance.py index b776028..3bd81dc 100644 --- a/tests/test_performance.py +++ b/tests/test_performance.py @@ -83,10 +83,10 @@ def expensive_computation(col: pl.Expr) -> pl.Expr: ) -def test_element_wise_comparison_slower_than_eq_missing_for_list_columns() -> None: - """Confirm that comparing list columns with non-tolerance inner types via - eq_missing() is significantly faster than the element-wise - _compare_sequence_columns() path.""" +def test_eq_missing_not_slower_than_element_wise_for_list_columns() -> None: + """Ensure that condition_equal_columns() on list columns with non-tolerance + inner types is not significantly slower than calling eq_missing() directly + (ratio below 1.25), i.e. it should not pay for the element-wise path.""" n_rows = 500_000 list_len = 20 num_runs_measured = 10 @@ -126,10 +126,10 @@ def test_element_wise_comparison_slower_than_eq_missing_for_list_columns() -> No mean_time_cond = statistics.mean(times_cond[num_runs_warmup:]) ratio = mean_time_cond / mean_time_eq - assert ratio > 2.0, ( - f"Element-wise comparison was only {ratio:.1f}x slower than eq_missing " + assert ratio < 1.25, ( + f"condition_equal_columns was {ratio:.1f}x slower than eq_missing " f"({mean_time_cond:.3f}s vs {mean_time_eq:.3f}s). " - f"Expected at least 2x slowdown to justify the optimization." + f"Expected comparable performance since list should use eq_missing directly." )