# Copyright (c) QuantCo 2025-2026
# SPDX-License-Identifier: BSD-3-Clause
"""Infer schema from a Polars DataFrame."""

from __future__ import annotations

import json
import keyword
import re

import polars as pl

# Maps a Polars dtype class to the name of the matching dataframely column
# constructor (rendered as ``dy.<name>`` in the generated code). Dtypes not
# present in this map fall back to ``dy.Any``.
_POLARS_DTYPE_MAP: dict[type[pl.DataType], str] = {
    pl.Boolean: "Bool",
    pl.Int8: "Int8",
    pl.Int16: "Int16",
    pl.Int32: "Int32",
    pl.Int64: "Int64",
    pl.UInt8: "UInt8",
    pl.UInt16: "UInt16",
    pl.UInt32: "UInt32",
    pl.UInt64: "UInt64",
    pl.Float32: "Float32",
    pl.Float64: "Float64",
    pl.String: "String",
    pl.Binary: "Binary",
    pl.Date: "Date",
    pl.Time: "Time",
    pl.Object: "Object",
    pl.Categorical: "Categorical",
    pl.Duration: "Duration",
    pl.Datetime: "Datetime",
    pl.Decimal: "Decimal",
    pl.Enum: "Enum",
    pl.List: "List",
    pl.Array: "Array",
    pl.Struct: "Struct",
}


def infer_schema(
    df: pl.DataFrame,
    schema_name: str = "Schema",
) -> str:
    """Infer a dataframely schema from a Polars DataFrame.

    This function inspects a DataFrame's schema and generates corresponding
    dataframely Schema code as a string.

    Args:
        df: The Polars DataFrame to infer the schema from.
        schema_name: The name for the generated schema class.

    Returns:
        The schema code as a string.

    Example:
        >>> import polars as pl
        >>> from dataframely.experimental import infer_schema
        >>> df = pl.DataFrame({
        ...     "name": ["Alice", "Bob"],
        ...     "age": [25, 30],
        ...     "score": [95.5, None],
        ... })
        >>> print(infer_schema(df, "PersonSchema"))
        class PersonSchema(dy.Schema):
            name = dy.String()
            age = dy.Int64()
            score = dy.Float64(nullable=True)

    Attention:
        This functionality is considered unstable. It may be changed at any time
        without it being considered a breaking change.

    Raises:
        ValueError: If ``schema_name`` is not a valid Python identifier.
    """
    if not schema_name.isidentifier():
        msg = f"schema_name must be a valid Python identifier, got {schema_name!r}"
        raise ValueError(msg)

    return _generate_schema_code(df, schema_name)


def _string_literal(value: str) -> str:
    """Render ``value`` as a double-quoted Python string literal.

    ``json.dumps`` escapes embedded double quotes and backslashes, so column
    names such as ``he"llo`` still produce syntactically valid generated code.
    JSON string escapes are a subset of Python's, so the result is always a
    valid Python literal; ``ensure_ascii=False`` keeps non-ASCII names
    readable.
    """
    return json.dumps(value, ensure_ascii=False)


def _generate_schema_code(df: pl.DataFrame, schema_name: str) -> str:
    """Generate schema code string from a DataFrame."""
    lines = [f"class {schema_name}(dy.Schema):"]
    used_identifiers: set[str] = set()

    for col_index, series in enumerate(df.get_columns()):
        col_name = series.name
        attr_name = _make_valid_identifier(col_name, col_index)
        # Two different column names may sanitize to the same identifier
        # (e.g. "a b" and "a@b"); de-duplicate with a numeric suffix.
        if attr_name in used_identifiers:
            # Remove a trailing "_" (e.g. from keyword escaping) as the
            # suffix adds its own separator anyway.
            if attr_name.endswith("_"):
                attr_name = attr_name[:-1]
            idx = 1
            while f"{attr_name}_{idx}" in used_identifiers:
                idx += 1
            attr_name = f"{attr_name}_{idx}"
        used_identifiers.add(attr_name)
        # Only emit an alias when the attribute name differs from the column.
        alias = col_name if attr_name != col_name else None
        col_code = _dtype_to_column_code(series, alias=alias)
        lines.append(f"    {attr_name} = {col_code}")

    return "\n".join(lines)


def _make_valid_identifier(name: str, col_index: int) -> str:
    """Convert a string to a valid Python identifier.

    Invalid characters become underscores; names that are empty or consist
    only of special characters become ``column_<index>``; digit-leading names
    get a leading underscore; Python keywords get a trailing underscore.
    """
    # Replace invalid characters with underscores
    valid_identifier = re.sub(r"[^a-zA-Z0-9_]", "_", name)

    # Handle empty names and names made of special characters only (the empty
    # set is a subset of {"_"}, so "" is covered too).
    if set(valid_identifier).issubset({"_"}):
        return f"column_{col_index}"
    # Ensure it doesn't start with a digit
    if valid_identifier[0].isdigit():
        return "_" + valid_identifier
    if keyword.iskeyword(valid_identifier):
        return valid_identifier + "_"
    return valid_identifier


def _get_dtype_args(dtype: pl.DataType, series: pl.Series) -> list[str]:
    """Get extra constructor arguments for parameterized types.

    Non-parameterized dtypes — and parameterized dtypes whose settings match
    the defaults — yield an empty list.
    """
    if isinstance(dtype, pl.Datetime):
        args: list[str] = []
        if dtype.time_zone is not None:
            args.append(f"time_zone={_string_literal(dtype.time_zone)}")
        if dtype.time_unit != "us":  # "us" is the Polars default
            args.append(f'time_unit="{dtype.time_unit}"')
        return args

    if isinstance(dtype, pl.Duration):
        if dtype.time_unit != "us":  # us is the default
            return [f'time_unit="{dtype.time_unit}"']

    if isinstance(dtype, pl.Decimal):
        decimal_args: list[str] = []
        if dtype.precision is not None:
            decimal_args.append(f"precision={dtype.precision}")
        if dtype.scale != 0:
            decimal_args.append(f"scale={dtype.scale}")
        return decimal_args

    if isinstance(dtype, pl.Enum):
        return [repr(dtype.categories.to_list())]

    if isinstance(dtype, pl.List):
        # Recurse into the element type; exploding also surfaces inner nulls.
        return [_dtype_to_column_code(series.explode())]

    if isinstance(dtype, pl.Array):
        return [_dtype_to_column_code(series.explode()), f"shape={dtype.size}"]

    if isinstance(dtype, pl.Struct):
        fields_parts = []
        for field in dtype.fields:
            field_code = _dtype_to_column_code(series.struct.field(field.name))
            fields_parts.append(f"{_string_literal(field.name)}: {field_code}")
        return ["{" + ", ".join(fields_parts) + "}"]

    return []


def _format_args(*args: str, nullable: bool = False, alias: str | None = None) -> str:
    """Format arguments for a generated column constructor.

    Args:
        args: Positional argument snippets, already rendered as code.
        nullable: Whether to append ``nullable=True``.
        alias: Original column name to append as ``alias=...``, if any.

    Returns:
        The comma-separated argument list (possibly empty).
    """
    all_args = list(args)
    if nullable:
        all_args.append("nullable=True")
    if alias is not None:
        # Escape so aliases containing quotes/backslashes stay valid code.
        all_args.append(f"alias={_string_literal(alias)}")
    return ", ".join(all_args)


def _dtype_to_column_code(series: pl.Series, *, alias: str | None = None) -> str:
    """Convert a Polars Series to dataframely column constructor code."""
    dtype = series.dtype
    nullable = series.null_count() > 0
    dy_name = _POLARS_DTYPE_MAP.get(type(dtype))

    # Unknown dtypes (e.g. Null) fall back to dy.Any; note that nullable is
    # deliberately not emitted for dy.Any.
    if dy_name is None:
        return f"dy.Any({_format_args(alias=alias)})"

    args = _get_dtype_args(dtype, series)
    return f"dy.{dy_name}({_format_args(*args, nullable=nullable, alias=alias)})"
+
+You can use {func}`dataframely.experimental.infer_schema` to get a basic {class}`~dataframely.Schema` definition for
+your dataframe. The function infers column names, types and nullability from the dataframe and returns a code
+representation of a {class}`~dataframely.Schema`.
+Starting from this definition, you can then refine the schema by adding additional rules, adjusting types, etc.
+
 ## What versions of `polars` does `dataframely` support?
 
 Our CI automatically tests `dataframely` for a minimal supported version of `polars`, which is currently `1.35.*`,
diff --git a/docs/guides/migration/index.md b/docs/guides/migration/index.md
index 3c06221a..167d06c7 100644
--- a/docs/guides/migration/index.md
+++ b/docs/guides/migration/index.md
@@ -37,3 +37,9 @@ Users can disable `FutureWarnings` either through builtins from tools like
 [pytest](https://docs.pytest.org/en/stable/how-to/capture-warnings.html#controlling-warnings),
 or by setting the `DATAFRAMELY_NO_FUTURE_WARNINGS` environment variable to
 `true` or `1`.
+
+## Experimental features
+
+Experimental features are published in a dedicated namespace `dataframely.experimental`.
+The versioning policy above does not apply to this namespace, and we may introduce breaking changes to experimental
+features in minor releases.
diff --git a/docs/guides/quickstart.md b/docs/guides/quickstart.md
index 4eb341ae..51814ce0 100644
@@ -175,7 +175,7 @@ print(failure.counts())
 ```
 
 In this case, `good` remains to be a `dy.DataFrame[HouseSchema]`, albeit with potentially fewer rows than `df`.
-The `failure` object is of type :class:`~dataframely.FailureInfo` and provides means to inspect
+The `failure` object is of type {class}`~dataframely.FailureInfo` and provides means to inspect
 the reasons for validation failures for invalid rows.
# Copyright (c) QuantCo 2025-2026
# SPDX-License-Identifier: BSD-3-Clause
"""Tests for :func:`dataframely.experimental.infer_schema`.

Most tests compare the generated schema code against an exact expected
string, so they pin quote style, argument order, and indentation of the
code generator.
"""

import datetime
import textwrap

import polars as pl
import pytest

import dataframely as dy
from dataframely.experimental import infer_schema


def test_basic_types() -> None:
    df = pl.DataFrame(
        {
            "int_col": [1, 2, 3],
            "float_col": [1.0, 2.0, 3.0],
            "str_col": ["a", "b", "c"],
            "bool_col": [True, False, True],
        }
    )
    result = infer_schema(df, "BasicSchema")
    expected = textwrap.dedent("""\
        class BasicSchema(dy.Schema):
            int_col = dy.Int64()
            float_col = dy.Float64()
            str_col = dy.String()
            bool_col = dy.Bool()""")
    assert result == expected


def test_nullable_detection() -> None:
    # Only columns that actually contain nulls should get nullable=True.
    df = pl.DataFrame(
        {
            "nullable_int": [1, None, 3],
            "non_nullable_int": [1, 2, 3],
        }
    )
    result = infer_schema(df, "NullableSchema")
    expected = textwrap.dedent("""\
        class NullableSchema(dy.Schema):
            nullable_int = dy.Int64(nullable=True)
            non_nullable_int = dy.Int64()""")
    assert result == expected


def test_datetime_types() -> None:
    df = pl.DataFrame(
        {
            "date_col": [datetime.date(2024, 1, 1)],
            "time_col": [datetime.time(12, 0, 0)],
            "datetime_col": [datetime.datetime(2024, 1, 1, 12, 0, 0)],
        }
    )
    result = infer_schema(df, schema_name="DatetimeSchema")
    expected = textwrap.dedent("""\
        class DatetimeSchema(dy.Schema):
            date_col = dy.Date()
            time_col = dy.Time()
            datetime_col = dy.Datetime()""")
    assert result == expected


def test_datetime_with_timezone() -> None:
    df = pl.DataFrame(
        {
            "utc_time": pl.Series([datetime.datetime(2024, 1, 1)]).dt.replace_time_zone(
                "UTC"
            ),
        }
    )
    result = infer_schema(df, schema_name="TzSchema")
    expected = textwrap.dedent("""\
        class TzSchema(dy.Schema):
            utc_time = dy.Datetime(time_zone="UTC")""")
    assert result == expected


def test_enum_type() -> None:
    # The full category list comes from the dtype, not the observed values.
    df = pl.DataFrame(
        {
            "status": pl.Series(["active", "pending"]).cast(
                pl.Enum(["active", "pending", "inactive"])
            ),
        }
    )
    result = infer_schema(df, schema_name="EnumSchema")
    expected = textwrap.dedent("""\
        class EnumSchema(dy.Schema):
            status = dy.Enum(['active', 'pending', 'inactive'])""")
    assert result == expected


def test_decimal_type() -> None:
    df = pl.DataFrame(
        {
            "amount": pl.Series(["10.50"]).cast(pl.Decimal(precision=10, scale=2)),
        }
    )
    result = infer_schema(df, schema_name="DecimalSchema")
    expected = textwrap.dedent("""\
        class DecimalSchema(dy.Schema):
            amount = dy.Decimal(precision=10, scale=2)""")
    assert result == expected


def test_list_type() -> None:
    df = pl.DataFrame(
        {
            "tags": [["a", "b"], ["c"]],
        }
    )
    result = infer_schema(df, schema_name="ListSchema")
    expected = textwrap.dedent("""\
        class ListSchema(dy.Schema):
            tags = dy.List(dy.String())""")
    assert result == expected


def test_struct_type() -> None:
    df = pl.DataFrame(
        {
            "metadata": [{"key": "value"}, {"key": "other"}],
        }
    )
    result = infer_schema(df, schema_name="StructSchema")
    expected = textwrap.dedent("""\
        class StructSchema(dy.Schema):
            metadata = dy.Struct({"key": dy.String()})""")
    assert result == expected


def test_list_with_nullable_inner() -> None:
    # Nulls inside the list elements mark the *inner* column as nullable.
    df = pl.DataFrame({"names": [["Alice"], [None]]})
    result = infer_schema(df, schema_name="ListNullableInnerSchema")
    expected = textwrap.dedent("""\
        class ListNullableInnerSchema(dy.Schema):
            names = dy.List(dy.String(nullable=True))""")
    assert result == expected


def test_struct_with_nullable_field() -> None:
    df = pl.DataFrame({"data": [{"key": "value"}, {"key": None}]})
    result = infer_schema(df, schema_name="StructNullableFieldSchema")
    expected = textwrap.dedent("""\
        class StructNullableFieldSchema(dy.Schema):
            data = dy.Struct({"key": dy.String(nullable=True)})""")
    assert result == expected


def test_array_type() -> None:
    df = pl.DataFrame({"vector": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]}).cast(
        {"vector": pl.Array(pl.Float64(), 3)}
    )
    result = infer_schema(df, schema_name="ArraySchema")
    expected = textwrap.dedent("""\
        class ArraySchema(dy.Schema):
            vector = dy.Array(dy.Float64(), shape=3)""")
    assert result == expected


def test_invalid_identifier() -> None:
    # Digit-leading column names get a leading underscore plus an alias.
    df = pl.DataFrame(
        {
            "123invalid": ["test"],
        }
    )
    result = infer_schema(df, schema_name="InvalidIdSchema")
    expected = textwrap.dedent("""\
        class InvalidIdSchema(dy.Schema):
            _123invalid = dy.String(alias="123invalid")""")
    assert result == expected


def test_python_keyword() -> None:
    # Python keywords get a trailing underscore plus an alias.
    df = pl.DataFrame(
        {
            "class": ["test"],
        }
    )
    result = infer_schema(df, schema_name="KeywordSchema")
    expected = textwrap.dedent("""\
        class KeywordSchema(dy.Schema):
            class_ = dy.String(alias="class")""")
    assert result == expected


def test_all_integer_types() -> None:
    df = pl.DataFrame(
        {
            "i8": pl.Series([1], dtype=pl.Int8),
            "i16": pl.Series([1], dtype=pl.Int16),
            "i32": pl.Series([1], dtype=pl.Int32),
            "i64": pl.Series([1], dtype=pl.Int64),
            "u8": pl.Series([1], dtype=pl.UInt8),
            "u16": pl.Series([1], dtype=pl.UInt16),
            "u32": pl.Series([1], dtype=pl.UInt32),
            "u64": pl.Series([1], dtype=pl.UInt64),
        }
    )
    result = infer_schema(df, schema_name="IntSchema")
    assert "dy.Int8()" in result
    assert "dy.Int16()" in result
    assert "dy.Int32()" in result
    assert "dy.Int64()" in result
    assert "dy.UInt8()" in result
    assert "dy.UInt16()" in result
    assert "dy.UInt32()" in result
    assert "dy.UInt64()" in result


def test_float_types() -> None:
    df = pl.DataFrame(
        {
            "f32": pl.Series([1.0], dtype=pl.Float32),
            "f64": pl.Series([1.0], dtype=pl.Float64),
        }
    )
    result = infer_schema(df, schema_name="FloatSchema")
    assert "dy.Float32()" in result
    assert "dy.Float64()" in result


def test_invalid_schema_name_raises_error() -> None:
    df = pl.DataFrame({"col": [1]})
    with pytest.raises(
        ValueError, match="schema_name must be a valid Python identifier"
    ):
        infer_schema(df, "Invalid Name")


def test_default_schema_name() -> None:
    df = pl.DataFrame({"col": [1]})
    result = infer_schema(df)
    assert "class Schema(dy.Schema):" in result


def test_binary_type() -> None:
    df = pl.DataFrame({"data": pl.Series([b"hello"], dtype=pl.Binary)})
    result = infer_schema(df, schema_name="BinarySchema")
    assert "dy.Binary()" in result


def test_null_type() -> None:
    # Null dtype is not mapped and falls back to dy.Any (without nullable).
    df = pl.DataFrame({"null_col": pl.Series([None, None], dtype=pl.Null)})
    result = infer_schema(df, schema_name="NullSchema")
    assert "dy.Any()" in result


def test_object_type() -> None:
    df = pl.DataFrame({"obj": pl.Series([object()], dtype=pl.Object)})
    result = infer_schema(df, schema_name="ObjectSchema")
    assert "dy.Object()" in result


def test_categorical_type() -> None:
    df = pl.DataFrame({"cat": pl.Series(["a", "b"]).cast(pl.Categorical())})
    result = infer_schema(df, schema_name="CatSchema")
    assert "dy.Categorical()" in result


def test_duration_type() -> None:
    # The default time unit ("us") must not be emitted.
    df = pl.DataFrame(
        {"dur": pl.Series([datetime.timedelta(days=1)], dtype=pl.Duration)}
    )
    result = infer_schema(df, schema_name="DurSchema")
    assert "dy.Duration()" in result


def test_duration_with_time_unit_ms() -> None:
    df = pl.DataFrame(
        {"dur": pl.Series([datetime.timedelta(days=1)]).cast(pl.Duration("ms"))}
    )
    result = infer_schema(df, schema_name="DurSchema")
    assert 'time_unit="ms"' in result


def test_duration_with_time_unit_ns() -> None:
    df = pl.DataFrame(
        {"dur": pl.Series([datetime.timedelta(days=1)]).cast(pl.Duration("ns"))}
    )
    result = infer_schema(df, schema_name="DurSchema")
    assert 'time_unit="ns"' in result


def test_datetime_with_time_unit_ms() -> None:
    df = pl.DataFrame(
        {"dt": pl.Series([datetime.datetime(2024, 1, 1)]).cast(pl.Datetime("ms"))}
    )
    result = infer_schema(df, schema_name="DtSchema")
    assert 'time_unit="ms"' in result


def test_datetime_with_time_unit_ns() -> None:
    df = pl.DataFrame(
        {"dt": pl.Series([datetime.datetime(2024, 1, 1)]).cast(pl.Datetime("ns"))}
    )
    result = infer_schema(df, schema_name="DtSchema")
    assert 'time_unit="ns"' in result


def test_decimal_without_scale() -> None:
    # scale=0 is the default and must be omitted from the generated code.
    df = pl.DataFrame(
        {"amount": pl.Series(["10"]).cast(pl.Decimal(precision=5, scale=0))}
    )
    result = infer_schema(df, schema_name="DecSchema")
    assert "precision=5" in result
    assert "scale=" not in result


def test_column_sanitization() -> None:
    # Names without any usable characters become column_<index>.
    df = pl.DataFrame({"$": ["test"], "valid": ["test"], "!!!": ["test"], "": ["test"]})
    result = infer_schema(df)
    assert 'column_0 = dy.String(alias="$")' in result
    assert 'column_2 = dy.String(alias="!!!")' in result
    assert 'column_3 = dy.String(alias="")' in result


def test_column_with_spaces() -> None:
    # Collisions between sanitized names are resolved with numeric suffixes,
    # and a pre-existing "_<n>" suffix gets its own additional "_<n>".
    df = pl.DataFrame(
        {
            "col name": ["test"],
            "col_name": ["test"],
            "col@name": ["test"],
            "col_name_1": ["test"],
        }
    )
    result = infer_schema(df, schema_name="SpaceSchema")
    assert 'col_name = dy.String(alias="col name")' in result
    assert 'col_name_1 = dy.String(alias="col_name")' in result
    assert 'col_name_2 = dy.String(alias="col@name")' in result
    assert 'col_name_1_1 = dy.String(alias="col_name_1")' in result


def test_column_sanitization_conflict_with_existing() -> None:
    # "col!" sanitizes to "col_", which already exists; the trailing "_" is
    # dropped before the numeric suffix is appended.
    df = pl.DataFrame({"col_": ["test"], "col!": ["test"]})
    result = infer_schema(df, schema_name="ConflictSchema")
    expected = textwrap.dedent("""\
        class ConflictSchema(dy.Schema):
            col_ = dy.String()
            col_1 = dy.String(alias="col!")""")
    assert result == expected


# Round-trip check: the generated code must exec cleanly and the resulting
# schema must validate the very dataframe it was inferred from.
@pytest.mark.parametrize(
    "df",
    [
        # Basic types
        pl.DataFrame(
            {
                "int_col": [1, 2, 3],
                "float_col": [1.0, 2.0, 3.0],
                "str_col": ["a", "b", "c"],
                "bool_col": [True, False, True],
            }
        ),
        # Nullable
        pl.DataFrame({"nullable_int": [1, None, 3], "non_nullable_int": [1, 2, 3]}),
        # Datetime types
        pl.DataFrame(
            {
                "date_col": [datetime.date(2024, 1, 1)],
                "time_col": [datetime.time(12, 0, 0)],
                "datetime_col": [datetime.datetime(2024, 1, 1, 12, 0, 0)],
            }
        ),
        # Enum
        pl.DataFrame(
            {
                "status": pl.Series(["active", "pending"]).cast(
                    pl.Enum(["active", "pending", "inactive"])
                )
            }
        ),
        # List
        pl.DataFrame({"tags": [["a", "b"], ["c"]]}),
        # Struct
        pl.DataFrame({"metadata": [{"key": "value"}]}),
        # Array
        pl.DataFrame({"vector": [[1.0, 2.0, 3.0]]}).cast(
            {"vector": pl.Array(pl.Float64(), 3)}
        ),
        # Invalid identifiers and keywords
        pl.DataFrame({"123invalid": ["test"], "class": ["test"]}),
        # Decimal
        pl.DataFrame(
            {"amount": pl.Series(["10.50"]).cast(pl.Decimal(precision=10, scale=2))}
        ),
        # Nested types
        pl.DataFrame({"nested_list": [[["a", "b"]]]}),
        pl.DataFrame({"nested_struct": [{"outer": {"inner": "value"}}]}),
        # Nullable inner types
        pl.DataFrame({"list_with_nulls": [["a"], [None]]}),
        pl.DataFrame({"struct_with_nulls": [{"key": "value"}, {"key": None}]}),
    ],
)
def test_inferred_schema_validates_dataframe(df: pl.DataFrame) -> None:
    code = infer_schema(df, "TestSchema")
    namespace: dict = {"dy": dy}
    exec(code, namespace)  # noqa: S102
    schema = namespace["TestSchema"]
    assert schema.is_valid(df)