From fa3b9fa8454ac1eb0df7ba788309a1dd123b96ad Mon Sep 17 00:00:00 2001 From: gabriel Date: Thu, 5 Mar 2026 10:46:39 +0100 Subject: [PATCH 01/13] mvp infer schema --- dataframely/__init__.py | 2 + dataframely/_generate_schema.py | 247 +++++++++++++++++++++++++++++ tests/test_infer_schema.py | 272 ++++++++++++++++++++++++++++++++ 3 files changed, 521 insertions(+) create mode 100644 dataframely/_generate_schema.py create mode 100644 tests/test_infer_schema.py diff --git a/dataframely/__init__.py b/dataframely/__init__.py index 399f971..ca2b586 100644 --- a/dataframely/__init__.py +++ b/dataframely/__init__.py @@ -12,6 +12,7 @@ from . import random from ._filter import filter +from ._generate_schema import infer_schema from ._rule import rule from ._typing import DataFrame, LazyFrame, Validation from .collection import ( @@ -78,6 +79,7 @@ "deserialize_schema", "read_parquet_metadata_schema", "read_parquet_metadata_collection", + "infer_schema", "Any", "Binary", "Bool", diff --git a/dataframely/_generate_schema.py b/dataframely/_generate_schema.py new file mode 100644 index 0000000..4fea170 --- /dev/null +++ b/dataframely/_generate_schema.py @@ -0,0 +1,247 @@ +# Copyright (c) QuantCo 2025-2026 +# SPDX-License-Identifier: BSD-3-Clause +"""Infer schema from a Polars DataFrame.""" + +from __future__ import annotations + +import keyword +import re +from typing import TYPE_CHECKING, Literal, overload + +import polars as pl + +if TYPE_CHECKING: + from dataframely.schema import Schema + + +@overload +def infer_schema( + df: pl.DataFrame, + schema_name: str = ..., + *, + return_type: None = ..., +) -> None: ... + + +@overload +def infer_schema( + df: pl.DataFrame, + schema_name: str = ..., + *, + return_type: Literal["string"], +) -> str: ... + + +@overload +def infer_schema( + df: pl.DataFrame, + schema_name: str = ..., + *, + return_type: Literal["schema"], +) -> type[Schema]: ... 
+ + +def infer_schema( + df: pl.DataFrame, + schema_name: str = "InferredSchema", + *, + return_type: Literal["string", "schema"] | None = None, +) -> str | type[Schema] | None: + """Infer a dataframely schema from a Polars DataFrame. + + This function inspects a DataFrame's schema and generates a corresponding + dataframely Schema. It can print the schema code, return it as a string, + or return an actual Schema class. + + Args: + df: The Polars DataFrame to infer the schema from. + schema_name: The name for the generated schema class. + return_type: Controls the return format: + + - ``None`` (default): Print the schema code to stdout, return ``None``. + - ``"string"``: Return the schema code as a string. + - ``"schema"``: Return an actual Schema class. + + Returns: + Depends on ``return_type``: + + - ``None``: Returns ``None`` (prints to stdout). + - ``"string"``: Returns the schema code as a string. + - ``"schema"``: Returns a Schema class that can be used directly. + + Example: + >>> import polars as pl + >>> import dataframely as dy + >>> df = pl.DataFrame({ + ... "name": ["Alice", "Bob"], + ... "age": [25, 30], + ... "score": [95.5, None], + ... 
}) + >>> dy.infer_schema(df, "PersonSchema") + class PersonSchema(dy.Schema): + name = dy.String() + age = dy.Int64() + score = dy.Float64(nullable=True) + >>> schema = dy.infer_schema(df, "PersonSchema", return_type="schema") + >>> schema.is_valid(df) + True + """ + code = _generate_schema_code(df, schema_name) + + if return_type is None: + print(code) # noqa: T201 + return None + if return_type == "string": + return code + if return_type == "schema": + import dataframely as dy + + namespace: dict = {"dy": dy} + exec(code, namespace) # noqa: S102 + return namespace[schema_name] + + msg = f"Invalid return_type: {return_type!r}" + raise ValueError(msg) + + +def _generate_schema_code(df: pl.DataFrame, schema_name: str) -> str: + """Generate schema code string from a DataFrame.""" + lines = [f"class {schema_name}(dy.Schema):"] + + for col_name, series in df.to_dict().items(): + if _is_valid_identifier(col_name): + attr_name = col_name + alias = None + else: + attr_name = _make_valid_identifier(col_name) + alias = col_name + col_code = _dtype_to_column_code(series, alias=alias) + lines.append(f" {attr_name} = {col_code}") + + return "\n".join(lines) + + +def _is_valid_identifier(name: str) -> bool: + """Check if a string is a valid Python identifier and not a keyword.""" + return name.isidentifier() and not keyword.iskeyword(name) + + +def _make_valid_identifier(name: str) -> str: + """Convert a string to a valid Python identifier.""" + # Replace invalid characters with underscores + result = re.sub(r"[^a-zA-Z0-9_]", "_", name) + # Ensure it doesn't start with a digit + if result and result[0].isdigit(): + result = "_" + result + # Ensure it's not empty + if not result: + result = "_column" + # Handle keywords + if keyword.iskeyword(result): + result = result + "_" + return result + + +def _format_args(*args: str, nullable: bool = False, alias: str | None = None) -> str: + """Format arguments for column constructor.""" + all_args = list(args) + if nullable: + 
all_args.insert(0, "nullable=True") + if alias: + all_args.insert(0, f'alias="{alias}"') + return ", ".join(all_args) + + +def _dtype_to_column_code(series: pl.Series, *, alias: str | None = None) -> str: + """Convert a Polars Series to dataframely column constructor code.""" + dtype = series.dtype + nullable = series.null_count() > 0 + + # Simple types + if dtype == pl.Boolean(): + return f"dy.Bool({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.Int8(): + return f"dy.Int8({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.Int16(): + return f"dy.Int16({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.Int32(): + return f"dy.Int32({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.Int64(): + return f"dy.Int64({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.UInt8(): + return f"dy.UInt8({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.UInt16(): + return f"dy.UInt16({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.UInt32(): + return f"dy.UInt32({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.UInt64(): + return f"dy.UInt64({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.Float32(): + return f"dy.Float32({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.Float64(): + return f"dy.Float64({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.String(): + return f"dy.String({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.Binary(): + return f"dy.Binary({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.Date(): + return f"dy.Date({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.Time(): + return f"dy.Time({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.Null(): + return f"dy.Any({_format_args(alias=alias)})" + if dtype == pl.Object(): + return f"dy.Object({_format_args(nullable=nullable, alias=alias)})" + if dtype == pl.Categorical(): + 
return f"dy.Categorical({_format_args(nullable=nullable, alias=alias)})" + + # Datetime with parameters + if isinstance(dtype, pl.Datetime): + args = [] + if dtype.time_zone is not None: + args.append(f'time_zone="{dtype.time_zone}"') + if dtype.time_unit != "us": # us is the default + args.append(f'time_unit="{dtype.time_unit}"') + return f"dy.Datetime({_format_args(*args, nullable=nullable, alias=alias)})" + + # Duration with time_unit + if isinstance(dtype, pl.Duration): + return f"dy.Duration({_format_args(nullable=nullable, alias=alias)})" + + # Decimal with precision and scale + if isinstance(dtype, pl.Decimal): + args = [] + if dtype.precision is not None: + args.append(f"precision={dtype.precision}") + if dtype.scale != 0: + args.append(f"scale={dtype.scale}") + return f"dy.Decimal({_format_args(*args, nullable=nullable, alias=alias)})" + + # Enum with categories + if isinstance(dtype, pl.Enum): + categories = dtype.categories.to_list() + return ( + f"dy.Enum({_format_args(repr(categories), nullable=nullable, alias=alias)})" + ) + + # List with inner type + if isinstance(dtype, pl.List): + inner_code = _dtype_to_column_code(series.explode()) + return f"dy.List({_format_args(inner_code, nullable=nullable, alias=alias)})" + + # Array with inner type and shape + if isinstance(dtype, pl.Array): + inner_code = _dtype_to_column_code(series.explode()) + return f"dy.Array({_format_args(inner_code, f'shape={dtype.size}', nullable=nullable, alias=alias)})" + + # Struct with fields + if isinstance(dtype, pl.Struct): + fields_parts = [] + for field in dtype.fields: + field_code = _dtype_to_column_code(series.struct.field(field.name)) + fields_parts.append(f'"{field.name}": {field_code}') + fields_dict = "{" + ", ".join(fields_parts) + "}" + return f"dy.Struct({_format_args(fields_dict, nullable=nullable, alias=alias)})" + + # Fallback for unknown types + return f"dy.Any({_format_args(alias=alias)}) # Unknown dtype: {dtype}" diff --git a/tests/test_infer_schema.py 
b/tests/test_infer_schema.py new file mode 100644 index 0000000..2dfec52 --- /dev/null +++ b/tests/test_infer_schema.py @@ -0,0 +1,272 @@ +# Copyright (c) QuantCo 2025-2026 +# SPDX-License-Identifier: BSD-3-Clause + +import datetime +import textwrap + +import polars as pl + +import dataframely as dy + + +class TestInferSchema: + def test_basic_types(self) -> None: + df = pl.DataFrame( + { + "int_col": [1, 2, 3], + "float_col": [1.0, 2.0, 3.0], + "str_col": ["a", "b", "c"], + "bool_col": [True, False, True], + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="BasicSchema") + expected = textwrap.dedent("""\ + class BasicSchema(dy.Schema): + int_col = dy.Int64() + float_col = dy.Float64() + str_col = dy.String() + bool_col = dy.Bool()""") + assert result == expected + + def test_nullable_detection(self) -> None: + df = pl.DataFrame( + { + "nullable_int": [1, None, 3], + "non_nullable_int": [1, 2, 3], + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="NullableSchema") + expected = textwrap.dedent("""\ + class NullableSchema(dy.Schema): + nullable_int = dy.Int64(nullable=True) + non_nullable_int = dy.Int64()""") + assert result == expected + + def test_datetime_types(self) -> None: + df = pl.DataFrame( + { + "date_col": [datetime.date(2024, 1, 1)], + "time_col": [datetime.time(12, 0, 0)], + "datetime_col": [datetime.datetime(2024, 1, 1, 12, 0, 0)], + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="DatetimeSchema") + expected = textwrap.dedent("""\ + class DatetimeSchema(dy.Schema): + date_col = dy.Date() + time_col = dy.Time() + datetime_col = dy.Datetime()""") + assert result == expected + + def test_datetime_with_timezone(self) -> None: + df = pl.DataFrame( + { + "utc_time": pl.Series( + [datetime.datetime(2024, 1, 1)] + ).dt.replace_time_zone("UTC"), + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="TzSchema") + expected = textwrap.dedent("""\ + class TzSchema(dy.Schema): + 
utc_time = dy.Datetime(time_zone="UTC")""") + assert result == expected + + def test_enum_type(self) -> None: + df = pl.DataFrame( + { + "status": pl.Series(["active", "pending"]).cast( + pl.Enum(["active", "pending", "inactive"]) + ), + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="EnumSchema") + expected = textwrap.dedent("""\ + class EnumSchema(dy.Schema): + status = dy.Enum(['active', 'pending', 'inactive'])""") + assert result == expected + + def test_decimal_type(self) -> None: + df = pl.DataFrame( + { + "amount": pl.Series(["10.50"]).cast(pl.Decimal(precision=10, scale=2)), + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="DecimalSchema") + expected = textwrap.dedent("""\ + class DecimalSchema(dy.Schema): + amount = dy.Decimal(precision=10, scale=2)""") + assert result == expected + + def test_list_type(self) -> None: + df = pl.DataFrame( + { + "tags": [["a", "b"], ["c"]], + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="ListSchema") + expected = textwrap.dedent("""\ + class ListSchema(dy.Schema): + tags = dy.List(dy.String())""") + assert result == expected + + def test_struct_type(self) -> None: + df = pl.DataFrame( + { + "metadata": [{"key": "value"}, {"key": "other"}], + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="StructSchema") + expected = textwrap.dedent("""\ + class StructSchema(dy.Schema): + metadata = dy.Struct({"key": dy.String()})""") + assert result == expected + + def test_list_with_nullable_inner(self) -> None: + df = pl.DataFrame({"names": [["Alice"], [None]]}) + result = dy.infer_schema( + df, return_type="string", schema_name="ListNullableInnerSchema" + ) + expected = textwrap.dedent("""\ + class ListNullableInnerSchema(dy.Schema): + names = dy.List(dy.String(nullable=True))""") + assert result == expected + + def test_struct_with_nullable_field(self) -> None: + df = pl.DataFrame({"data": [{"key": "value"}, {"key": None}]}) + result = 
dy.infer_schema( + df, return_type="string", schema_name="StructNullableFieldSchema" + ) + expected = textwrap.dedent("""\ + class StructNullableFieldSchema(dy.Schema): + data = dy.Struct({"key": dy.String(nullable=True)})""") + assert result == expected + + def test_array_type(self) -> None: + df = pl.DataFrame({"vector": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]}).cast( + {"vector": pl.Array(pl.Float64(), 3)} + ) + result = dy.infer_schema(df, return_type="string", schema_name="ArraySchema") + expected = textwrap.dedent("""\ + class ArraySchema(dy.Schema): + vector = dy.Array(dy.Float64(), shape=3)""") + assert result == expected + + def test_invalid_identifier(self) -> None: + df = pl.DataFrame( + { + "123invalid": ["test"], + } + ) + result = dy.infer_schema( + df, return_type="string", schema_name="InvalidIdSchema" + ) + expected = textwrap.dedent("""\ + class InvalidIdSchema(dy.Schema): + _123invalid = dy.String(alias="123invalid")""") + assert result == expected + + def test_python_keyword(self) -> None: + df = pl.DataFrame( + { + "class": ["test"], + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="KeywordSchema") + expected = textwrap.dedent("""\ + class KeywordSchema(dy.Schema): + class_ = dy.String(alias="class")""") + assert result == expected + + def test_all_integer_types(self) -> None: + df = pl.DataFrame( + { + "i8": pl.Series([1], dtype=pl.Int8), + "i16": pl.Series([1], dtype=pl.Int16), + "i32": pl.Series([1], dtype=pl.Int32), + "i64": pl.Series([1], dtype=pl.Int64), + "u8": pl.Series([1], dtype=pl.UInt8), + "u16": pl.Series([1], dtype=pl.UInt16), + "u32": pl.Series([1], dtype=pl.UInt32), + "u64": pl.Series([1], dtype=pl.UInt64), + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="IntSchema") + assert "dy.Int8()" in result + assert "dy.Int16()" in result + assert "dy.Int32()" in result + assert "dy.Int64()" in result + assert "dy.UInt8()" in result + assert "dy.UInt16()" in result + assert "dy.UInt32()" in 
result + assert "dy.UInt64()" in result + + def test_float_types(self) -> None: + df = pl.DataFrame( + { + "f32": pl.Series([1.0], dtype=pl.Float32), + "f64": pl.Series([1.0], dtype=pl.Float64), + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="FloatSchema") + assert "dy.Float32()" in result + assert "dy.Float64()" in result + + +class TestInferSchemaReturnsSchema: + """Test that return_type='schema' produces working schemas.""" + + def test_inferred_schema_validates_dataframe(self) -> None: + """Verify inferred schema validates the original dataframe.""" + dataframes = [ + # Basic types + pl.DataFrame( + { + "int_col": [1, 2, 3], + "float_col": [1.0, 2.0, 3.0], + "str_col": ["a", "b", "c"], + "bool_col": [True, False, True], + } + ), + # Nullable + pl.DataFrame({"nullable_int": [1, None, 3], "non_nullable_int": [1, 2, 3]}), + # Datetime types + pl.DataFrame( + { + "date_col": [datetime.date(2024, 1, 1)], + "time_col": [datetime.time(12, 0, 0)], + "datetime_col": [datetime.datetime(2024, 1, 1, 12, 0, 0)], + } + ), + # Enum + pl.DataFrame( + { + "status": pl.Series(["active", "pending"]).cast( + pl.Enum(["active", "pending", "inactive"]) + ) + } + ), + # List and struct + pl.DataFrame({"tags": [["a", "b"], ["c"]]}), + pl.DataFrame({"metadata": [{"key": "value"}]}), + # Array + pl.DataFrame({"vector": [[1.0, 2.0, 3.0]]}).cast( + {"vector": pl.Array(pl.Float64(), 3)} + ), + # Invalid identifiers and keywords + pl.DataFrame({"123invalid": ["test"], "class": ["test"]}), + # Decimal + pl.DataFrame( + {"amount": pl.Series(["10.50"]).cast(pl.Decimal(precision=10, scale=2))} + ), + # Nested types + pl.DataFrame({"nested_list": [[["a", "b"]]]}), + pl.DataFrame({"nested_struct": [{"outer": {"inner": "value"}}]}), + # Nullable inner types + pl.DataFrame({"list_with_nulls": [["a"], [None]]}), + pl.DataFrame({"struct_with_nulls": [{"key": "value"}, {"key": None}]}), + ] + + for i, df in enumerate(dataframes): + schema = dy.infer_schema(df, f"Schema{i}", 
return_type="schema") + assert schema.is_valid(df), f"Schema{i} failed for {df.schema}" From 6c19bfab2ff5a006e2d99400297160d8e5e8dc3e Mon Sep 17 00:00:00 2001 From: gabriel Date: Thu, 5 Mar 2026 11:20:21 +0100 Subject: [PATCH 02/13] increase code coverage --- dataframely/_generate_schema.py | 2 +- tests/test_infer_schema.py | 110 ++++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 1 deletion(-) diff --git a/dataframely/_generate_schema.py b/dataframely/_generate_schema.py index 4fea170..0ce7c31 100644 --- a/dataframely/_generate_schema.py +++ b/dataframely/_generate_schema.py @@ -43,7 +43,7 @@ def infer_schema( def infer_schema( df: pl.DataFrame, - schema_name: str = "InferredSchema", + schema_name: str = "Schema", *, return_type: Literal["string", "schema"] | None = None, ) -> str | type[Schema] | None: diff --git a/tests/test_infer_schema.py b/tests/test_infer_schema.py index 2dfec52..e57aa95 100644 --- a/tests/test_infer_schema.py +++ b/tests/test_infer_schema.py @@ -5,6 +5,7 @@ import textwrap import polars as pl +import pytest import dataframely as dy @@ -213,6 +214,115 @@ def test_float_types(self) -> None: assert "dy.Float64()" in result +class TestInferSchemaReturnTypes: + """Test the different return_type options.""" + + def test_return_type_none_prints_to_stdout( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + df = pl.DataFrame({"col": [1, 2, 3]}) + result = dy.infer_schema(df, "TestSchema") + assert result is None + captured = capsys.readouterr() + assert "class TestSchema(dy.Schema):" in captured.out + assert "col = dy.Int64()" in captured.out + + def test_return_type_string(self) -> None: + df = pl.DataFrame({"col": [1, 2, 3]}) + result = dy.infer_schema(df, "TestSchema", return_type="string") + assert isinstance(result, str) + assert "class TestSchema(dy.Schema):" in result + + def test_return_type_schema(self) -> None: + df = pl.DataFrame({"col": [1, 2, 3]}) + schema = dy.infer_schema(df, "TestSchema", return_type="schema") 
+ assert schema.is_valid(df) + + def test_invalid_return_type_raises_error(self) -> None: + df = pl.DataFrame({"col": [1]}) + with pytest.raises(ValueError, match="Invalid return_type"): + dy.infer_schema(df, "Test", return_type="invalid") # type: ignore[call-overload] + + def test_default_schema_name(self) -> None: + df = pl.DataFrame({"col": [1]}) + result = dy.infer_schema(df, return_type="string") + assert "class Schema(dy.Schema):" in result + + +class TestSpecialTypes: + """Test special column types.""" + + def test_binary_type(self) -> None: + df = pl.DataFrame({"data": pl.Series([b"hello"], dtype=pl.Binary)}) + result = dy.infer_schema(df, return_type="string", schema_name="BinarySchema") + assert "dy.Binary()" in result + + def test_null_type(self) -> None: + df = pl.DataFrame({"null_col": pl.Series([None, None], dtype=pl.Null)}) + result = dy.infer_schema(df, return_type="string", schema_name="NullSchema") + assert "dy.Any()" in result + + def test_object_type(self) -> None: + df = pl.DataFrame({"obj": pl.Series([object()], dtype=pl.Object)}) + result = dy.infer_schema(df, return_type="string", schema_name="ObjectSchema") + assert "dy.Object()" in result + + def test_categorical_type(self) -> None: + df = pl.DataFrame({"cat": pl.Series(["a", "b"]).cast(pl.Categorical())}) + result = dy.infer_schema(df, return_type="string", schema_name="CatSchema") + assert "dy.Categorical()" in result + + def test_duration_type(self) -> None: + df = pl.DataFrame( + {"dur": pl.Series([datetime.timedelta(days=1)], dtype=pl.Duration)} + ) + result = dy.infer_schema(df, return_type="string", schema_name="DurSchema") + assert "dy.Duration()" in result + + def test_datetime_with_time_unit_ms(self) -> None: + df = pl.DataFrame( + {"dt": pl.Series([datetime.datetime(2024, 1, 1)]).cast(pl.Datetime("ms"))} + ) + result = dy.infer_schema(df, return_type="string", schema_name="DtSchema") + assert 'time_unit="ms"' in result + + def test_datetime_with_time_unit_ns(self) -> None: + df 
= pl.DataFrame( + {"dt": pl.Series([datetime.datetime(2024, 1, 1)]).cast(pl.Datetime("ns"))} + ) + result = dy.infer_schema(df, return_type="string", schema_name="DtSchema") + assert 'time_unit="ns"' in result + + def test_decimal_without_scale(self) -> None: + df = pl.DataFrame( + {"amount": pl.Series(["10"]).cast(pl.Decimal(precision=5, scale=0))} + ) + result = dy.infer_schema(df, return_type="string", schema_name="DecSchema") + assert "precision=5" in result + assert "scale=" not in result + + +class TestMakeValidIdentifier: + """Test edge cases of _make_valid_identifier.""" + + def test_column_with_special_chars_replaced(self) -> None: + df = pl.DataFrame({"!!!": ["test"]}) + result = dy.infer_schema(df, return_type="string", schema_name="SpecialSchema") + assert '___ = dy.String(alias="!!!")' in result + + def test_column_empty_after_sanitization(self) -> None: + # Empty string column name results in _column fallback + df = pl.DataFrame({"": ["test"]}) + result = dy.infer_schema(df, return_type="string", schema_name="EmptySchema") + # Empty string alias is not included (falsy), but _column is generated + assert "_column = dy.String()" in result + + def test_column_with_spaces(self) -> None: + df = pl.DataFrame({"col name": ["test"]}) + result = dy.infer_schema(df, return_type="string", schema_name="SpaceSchema") + assert 'col_name = dy.String(alias="col name")' in result + + class TestInferSchemaReturnsSchema: """Test that return_type='schema' produces working schemas.""" From f0e07fb14426abd8ea6120d0fe2a90d70c9ec59a Mon Sep 17 00:00:00 2001 From: gabriel Date: Thu, 5 Mar 2026 11:36:07 +0100 Subject: [PATCH 03/13] copilot --- dataframely/_generate_schema.py | 11 +++++++++-- docs/api/schema/index.rst | 1 + docs/api/schema/inference.rst | 9 +++++++++ tests/test_infer_schema.py | 28 +++++++++++++++++----------- 4 files changed, 36 insertions(+), 13 deletions(-) create mode 100644 docs/api/schema/inference.rst diff --git a/dataframely/_generate_schema.py 
b/dataframely/_generate_schema.py index 0ce7c31..68a5501 100644 --- a/dataframely/_generate_schema.py +++ b/dataframely/_generate_schema.py @@ -85,7 +85,14 @@ class PersonSchema(dy.Schema): >>> schema = dy.infer_schema(df, "PersonSchema", return_type="schema") >>> schema.is_valid(df) True + + Raises: + ValueError: If ``schema_name`` is not a valid Python identifier. """ + if not schema_name.isidentifier(): + msg = f"schema_name must be a valid Python identifier, got {schema_name!r}" + raise ValueError(msg) + code = _generate_schema_code(df, schema_name) if return_type is None: @@ -146,9 +153,9 @@ def _format_args(*args: str, nullable: bool = False, alias: str | None = None) - """Format arguments for column constructor.""" all_args = list(args) if nullable: - all_args.insert(0, "nullable=True") + all_args.append("nullable=True") if alias: - all_args.insert(0, f'alias="{alias}"') + all_args.append(f'alias="{alias}"') return ", ".join(all_args) diff --git a/docs/api/schema/index.rst b/docs/api/schema/index.rst index 77e0323..5ed25ba 100644 --- a/docs/api/schema/index.rst +++ b/docs/api/schema/index.rst @@ -9,6 +9,7 @@ Schema validation io generation + inference conversion metadata diff --git a/docs/api/schema/inference.rst b/docs/api/schema/inference.rst new file mode 100644 index 0000000..29d335e --- /dev/null +++ b/docs/api/schema/inference.rst @@ -0,0 +1,9 @@ +========= +Inference +========= + +.. currentmodule:: dataframely +.. 
autosummary:: + :toctree: _gen/ + + infer_schema diff --git a/tests/test_infer_schema.py b/tests/test_infer_schema.py index e57aa95..62dbd72 100644 --- a/tests/test_infer_schema.py +++ b/tests/test_infer_schema.py @@ -243,6 +243,13 @@ def test_invalid_return_type_raises_error(self) -> None: with pytest.raises(ValueError, match="Invalid return_type"): dy.infer_schema(df, "Test", return_type="invalid") # type: ignore[call-overload] + def test_invalid_schema_name_raises_error(self) -> None: + df = pl.DataFrame({"col": [1]}) + with pytest.raises( + ValueError, match="schema_name must be a valid Python identifier" + ): + dy.infer_schema(df, "Invalid Name") + def test_default_schema_name(self) -> None: df = pl.DataFrame({"col": [1]}) result = dy.infer_schema(df, return_type="string") @@ -324,11 +331,9 @@ def test_column_with_spaces(self) -> None: class TestInferSchemaReturnsSchema: - """Test that return_type='schema' produces working schemas.""" - - def test_inferred_schema_validates_dataframe(self) -> None: - """Verify inferred schema validates the original dataframe.""" - dataframes = [ + @pytest.mark.parametrize( + "df", + [ # Basic types pl.DataFrame( { @@ -356,8 +361,9 @@ def test_inferred_schema_validates_dataframe(self) -> None: ) } ), - # List and struct + # List pl.DataFrame({"tags": [["a", "b"], ["c"]]}), + # Struct pl.DataFrame({"metadata": [{"key": "value"}]}), # Array pl.DataFrame({"vector": [[1.0, 2.0, 3.0]]}).cast( @@ -375,8 +381,8 @@ def test_inferred_schema_validates_dataframe(self) -> None: # Nullable inner types pl.DataFrame({"list_with_nulls": [["a"], [None]]}), pl.DataFrame({"struct_with_nulls": [{"key": "value"}, {"key": None}]}), - ] - - for i, df in enumerate(dataframes): - schema = dy.infer_schema(df, f"Schema{i}", return_type="schema") - assert schema.is_valid(df), f"Schema{i} failed for {df.schema}" + ], + ) + def test_inferred_schema_validates_dataframe(self, df: pl.DataFrame) -> None: + schema = dy.infer_schema(df, "TestSchema", 
return_type="schema") + assert schema.is_valid(df) From 7ee32cf436d6a752b24b77675bed69ed6832c08c Mon Sep 17 00:00:00 2001 From: gabriel Date: Thu, 5 Mar 2026 13:46:59 +0100 Subject: [PATCH 04/13] pragma: no cover --- dataframely/_generate_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataframely/_generate_schema.py b/dataframely/_generate_schema.py index 68a5501..934c657 100644 --- a/dataframely/_generate_schema.py +++ b/dataframely/_generate_schema.py @@ -251,4 +251,4 @@ def _dtype_to_column_code(series: pl.Series, *, alias: str | None = None) -> str return f"dy.Struct({_format_args(fields_dict, nullable=nullable, alias=alias)})" # Fallback for unknown types - return f"dy.Any({_format_args(alias=alias)}) # Unknown dtype: {dtype}" + return f"dy.Any({_format_args(alias=alias)}) # Unknown dtype: {dtype}" # pragma: no cover From 50b723d4574e91226c5a6a9cc022c22f394c6e60 Mon Sep 17 00:00:00 2001 From: gabriel Date: Mon, 16 Mar 2026 15:09:45 +0100 Subject: [PATCH 05/13] more concise --- dataframely/_generate_schema.py | 139 +++--- tests/test_infer_schema.py | 749 +++++++++++++++++--------------- 2 files changed, 452 insertions(+), 436 deletions(-) diff --git a/dataframely/_generate_schema.py b/dataframely/_generate_schema.py index 934c657..0f6c0a9 100644 --- a/dataframely/_generate_schema.py +++ b/dataframely/_generate_schema.py @@ -14,6 +14,34 @@ from dataframely.schema import Schema +_POLARS_DTYPE_MAP: dict[type[pl.DataType], str] = { + pl.Boolean: "Bool", + pl.Int8: "Int8", + pl.Int16: "Int16", + pl.Int32: "Int32", + pl.Int64: "Int64", + pl.UInt8: "UInt8", + pl.UInt16: "UInt16", + pl.UInt32: "UInt32", + pl.UInt64: "UInt64", + pl.Float32: "Float32", + pl.Float64: "Float64", + pl.String: "String", + pl.Binary: "Binary", + pl.Date: "Date", + pl.Time: "Time", + pl.Object: "Object", + pl.Categorical: "Categorical", + pl.Duration: "Duration", + pl.Datetime: "Datetime", + pl.Decimal: "Decimal", + pl.Enum: "Enum", + pl.List: "List", + pl.Array: 
"Array", + pl.Struct: "Struct", +} + + @overload def infer_schema( df: pl.DataFrame, @@ -149,106 +177,65 @@ def _make_valid_identifier(name: str) -> str: return result -def _format_args(*args: str, nullable: bool = False, alias: str | None = None) -> str: - """Format arguments for column constructor.""" - all_args = list(args) - if nullable: - all_args.append("nullable=True") - if alias: - all_args.append(f'alias="{alias}"') - return ", ".join(all_args) - - -def _dtype_to_column_code(series: pl.Series, *, alias: str | None = None) -> str: - """Convert a Polars Series to dataframely column constructor code.""" - dtype = series.dtype - nullable = series.null_count() > 0 - - # Simple types - if dtype == pl.Boolean(): - return f"dy.Bool({_format_args(nullable=nullable, alias=alias)})" - if dtype == pl.Int8(): - return f"dy.Int8({_format_args(nullable=nullable, alias=alias)})" - if dtype == pl.Int16(): - return f"dy.Int16({_format_args(nullable=nullable, alias=alias)})" - if dtype == pl.Int32(): - return f"dy.Int32({_format_args(nullable=nullable, alias=alias)})" - if dtype == pl.Int64(): - return f"dy.Int64({_format_args(nullable=nullable, alias=alias)})" - if dtype == pl.UInt8(): - return f"dy.UInt8({_format_args(nullable=nullable, alias=alias)})" - if dtype == pl.UInt16(): - return f"dy.UInt16({_format_args(nullable=nullable, alias=alias)})" - if dtype == pl.UInt32(): - return f"dy.UInt32({_format_args(nullable=nullable, alias=alias)})" - if dtype == pl.UInt64(): - return f"dy.UInt64({_format_args(nullable=nullable, alias=alias)})" - if dtype == pl.Float32(): - return f"dy.Float32({_format_args(nullable=nullable, alias=alias)})" - if dtype == pl.Float64(): - return f"dy.Float64({_format_args(nullable=nullable, alias=alias)})" - if dtype == pl.String(): - return f"dy.String({_format_args(nullable=nullable, alias=alias)})" - if dtype == pl.Binary(): - return f"dy.Binary({_format_args(nullable=nullable, alias=alias)})" - if dtype == pl.Date(): - return 
f"dy.Date({_format_args(nullable=nullable, alias=alias)})" - if dtype == pl.Time(): - return f"dy.Time({_format_args(nullable=nullable, alias=alias)})" - if dtype == pl.Null(): - return f"dy.Any({_format_args(alias=alias)})" - if dtype == pl.Object(): - return f"dy.Object({_format_args(nullable=nullable, alias=alias)})" - if dtype == pl.Categorical(): - return f"dy.Categorical({_format_args(nullable=nullable, alias=alias)})" - - # Datetime with parameters +def _get_dtype_args(dtype: pl.DataType, series: pl.Series) -> list[str]: + """Get extra arguments for parameterized types.""" if isinstance(dtype, pl.Datetime): args = [] if dtype.time_zone is not None: args.append(f'time_zone="{dtype.time_zone}"') - if dtype.time_unit != "us": # us is the default + if dtype.time_unit != "us": args.append(f'time_unit="{dtype.time_unit}"') - return f"dy.Datetime({_format_args(*args, nullable=nullable, alias=alias)})" + return args - # Duration with time_unit if isinstance(dtype, pl.Duration): - return f"dy.Duration({_format_args(nullable=nullable, alias=alias)})" + if dtype.time_unit != "us": # us is the default + return [f'time_unit="{dtype.time_unit}"'] - # Decimal with precision and scale if isinstance(dtype, pl.Decimal): args = [] if dtype.precision is not None: args.append(f"precision={dtype.precision}") if dtype.scale != 0: args.append(f"scale={dtype.scale}") - return f"dy.Decimal({_format_args(*args, nullable=nullable, alias=alias)})" + return args - # Enum with categories if isinstance(dtype, pl.Enum): - categories = dtype.categories.to_list() - return ( - f"dy.Enum({_format_args(repr(categories), nullable=nullable, alias=alias)})" - ) + return [repr(dtype.categories.to_list())] - # List with inner type if isinstance(dtype, pl.List): - inner_code = _dtype_to_column_code(series.explode()) - return f"dy.List({_format_args(inner_code, nullable=nullable, alias=alias)})" + return [_dtype_to_column_code(series.explode())] - # Array with inner type and shape if isinstance(dtype, 
pl.Array): - inner_code = _dtype_to_column_code(series.explode()) - return f"dy.Array({_format_args(inner_code, f'shape={dtype.size}', nullable=nullable, alias=alias)})" + return [_dtype_to_column_code(series.explode()), f"shape={dtype.size}"] - # Struct with fields if isinstance(dtype, pl.Struct): fields_parts = [] for field in dtype.fields: field_code = _dtype_to_column_code(series.struct.field(field.name)) fields_parts.append(f'"{field.name}": {field_code}') - fields_dict = "{" + ", ".join(fields_parts) + "}" - return f"dy.Struct({_format_args(fields_dict, nullable=nullable, alias=alias)})" + return ["{" + ", ".join(fields_parts) + "}"] + + return [] + + +def _format_args(*args: str, nullable: bool = False, alias: str | None = None) -> str: + """Format arguments for column constructor.""" + all_args = list(args) + if nullable: + all_args.append("nullable=True") + if alias: + all_args.append(f'alias="{alias}"') + return ", ".join(all_args) + + +def _dtype_to_column_code(series: pl.Series, *, alias: str | None = None) -> str: + """Convert a Polars Series to dataframely column constructor code.""" + dtype = series.dtype + nullable = series.null_count() > 0 + dy_name = _POLARS_DTYPE_MAP.get(type(dtype)) + + if dy_name is None: + return f"dy.Any({_format_args(alias=alias)}) # Unknown dtype: {dtype}" - # Fallback for unknown types - return f"dy.Any({_format_args(alias=alias)}) # Unknown dtype: {dtype}" # pragma: no cover + args = _get_dtype_args(dtype, series) + return f"dy.{dy_name}({_format_args(*args, nullable=nullable, alias=alias)})" diff --git a/tests/test_infer_schema.py b/tests/test_infer_schema.py index 62dbd72..0df76a4 100644 --- a/tests/test_infer_schema.py +++ b/tests/test_infer_schema.py @@ -10,379 +10,408 @@ import dataframely as dy -class TestInferSchema: - def test_basic_types(self) -> None: - df = pl.DataFrame( +def test_basic_types() -> None: + df = pl.DataFrame( + { + "int_col": [1, 2, 3], + "float_col": [1.0, 2.0, 3.0], + "str_col": ["a", "b", 
"c"], + "bool_col": [True, False, True], + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="BasicSchema") + expected = textwrap.dedent("""\ + class BasicSchema(dy.Schema): + int_col = dy.Int64() + float_col = dy.Float64() + str_col = dy.String() + bool_col = dy.Bool()""") + assert result == expected + + +def test_nullable_detection() -> None: + df = pl.DataFrame( + { + "nullable_int": [1, None, 3], + "non_nullable_int": [1, 2, 3], + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="NullableSchema") + expected = textwrap.dedent("""\ + class NullableSchema(dy.Schema): + nullable_int = dy.Int64(nullable=True) + non_nullable_int = dy.Int64()""") + assert result == expected + + +def test_datetime_types() -> None: + df = pl.DataFrame( + { + "date_col": [datetime.date(2024, 1, 1)], + "time_col": [datetime.time(12, 0, 0)], + "datetime_col": [datetime.datetime(2024, 1, 1, 12, 0, 0)], + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="DatetimeSchema") + expected = textwrap.dedent("""\ + class DatetimeSchema(dy.Schema): + date_col = dy.Date() + time_col = dy.Time() + datetime_col = dy.Datetime()""") + assert result == expected + + +def test_datetime_with_timezone() -> None: + df = pl.DataFrame( + { + "utc_time": pl.Series([datetime.datetime(2024, 1, 1)]).dt.replace_time_zone( + "UTC" + ), + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="TzSchema") + expected = textwrap.dedent("""\ + class TzSchema(dy.Schema): + utc_time = dy.Datetime(time_zone="UTC")""") + assert result == expected + + +def test_enum_type() -> None: + df = pl.DataFrame( + { + "status": pl.Series(["active", "pending"]).cast( + pl.Enum(["active", "pending", "inactive"]) + ), + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="EnumSchema") + expected = textwrap.dedent("""\ + class EnumSchema(dy.Schema): + status = dy.Enum(['active', 'pending', 'inactive'])""") + assert result == expected + + +def 
test_decimal_type() -> None: + df = pl.DataFrame( + { + "amount": pl.Series(["10.50"]).cast(pl.Decimal(precision=10, scale=2)), + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="DecimalSchema") + expected = textwrap.dedent("""\ + class DecimalSchema(dy.Schema): + amount = dy.Decimal(precision=10, scale=2)""") + assert result == expected + + +def test_list_type() -> None: + df = pl.DataFrame( + { + "tags": [["a", "b"], ["c"]], + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="ListSchema") + expected = textwrap.dedent("""\ + class ListSchema(dy.Schema): + tags = dy.List(dy.String())""") + assert result == expected + + +def test_struct_type() -> None: + df = pl.DataFrame( + { + "metadata": [{"key": "value"}, {"key": "other"}], + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="StructSchema") + expected = textwrap.dedent("""\ + class StructSchema(dy.Schema): + metadata = dy.Struct({"key": dy.String()})""") + assert result == expected + + +def test_list_with_nullable_inner() -> None: + df = pl.DataFrame({"names": [["Alice"], [None]]}) + result = dy.infer_schema( + df, return_type="string", schema_name="ListNullableInnerSchema" + ) + expected = textwrap.dedent("""\ + class ListNullableInnerSchema(dy.Schema): + names = dy.List(dy.String(nullable=True))""") + assert result == expected + + +def test_struct_with_nullable_field() -> None: + df = pl.DataFrame({"data": [{"key": "value"}, {"key": None}]}) + result = dy.infer_schema( + df, return_type="string", schema_name="StructNullableFieldSchema" + ) + expected = textwrap.dedent("""\ + class StructNullableFieldSchema(dy.Schema): + data = dy.Struct({"key": dy.String(nullable=True)})""") + assert result == expected + + +def test_array_type() -> None: + df = pl.DataFrame({"vector": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]}).cast( + {"vector": pl.Array(pl.Float64(), 3)} + ) + result = dy.infer_schema(df, return_type="string", schema_name="ArraySchema") + expected = 
textwrap.dedent("""\ + class ArraySchema(dy.Schema): + vector = dy.Array(dy.Float64(), shape=3)""") + assert result == expected + + +def test_invalid_identifier() -> None: + df = pl.DataFrame( + { + "123invalid": ["test"], + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="InvalidIdSchema") + expected = textwrap.dedent("""\ + class InvalidIdSchema(dy.Schema): + _123invalid = dy.String(alias="123invalid")""") + assert result == expected + + +def test_python_keyword() -> None: + df = pl.DataFrame( + { + "class": ["test"], + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="KeywordSchema") + expected = textwrap.dedent("""\ + class KeywordSchema(dy.Schema): + class_ = dy.String(alias="class")""") + assert result == expected + + +def test_all_integer_types() -> None: + df = pl.DataFrame( + { + "i8": pl.Series([1], dtype=pl.Int8), + "i16": pl.Series([1], dtype=pl.Int16), + "i32": pl.Series([1], dtype=pl.Int32), + "i64": pl.Series([1], dtype=pl.Int64), + "u8": pl.Series([1], dtype=pl.UInt8), + "u16": pl.Series([1], dtype=pl.UInt16), + "u32": pl.Series([1], dtype=pl.UInt32), + "u64": pl.Series([1], dtype=pl.UInt64), + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="IntSchema") + assert "dy.Int8()" in result + assert "dy.Int16()" in result + assert "dy.Int32()" in result + assert "dy.Int64()" in result + assert "dy.UInt8()" in result + assert "dy.UInt16()" in result + assert "dy.UInt32()" in result + assert "dy.UInt64()" in result + + +def test_float_types() -> None: + df = pl.DataFrame( + { + "f32": pl.Series([1.0], dtype=pl.Float32), + "f64": pl.Series([1.0], dtype=pl.Float64), + } + ) + result = dy.infer_schema(df, return_type="string", schema_name="FloatSchema") + assert "dy.Float32()" in result + assert "dy.Float64()" in result + + +def test_return_type_none_prints_to_stdout(capsys: pytest.CaptureFixture[str]) -> None: + df = pl.DataFrame({"col": [1, 2, 3]}) + result = dy.infer_schema(df, 
"TestSchema") + assert result is None + captured = capsys.readouterr() + assert "class TestSchema(dy.Schema):" in captured.out + assert "col = dy.Int64()" in captured.out + + +def test_return_type_string() -> None: + df = pl.DataFrame({"col": [1, 2, 3]}) + result = dy.infer_schema(df, "TestSchema", return_type="string") + assert isinstance(result, str) + assert "class TestSchema(dy.Schema):" in result + + +def test_return_type_schema() -> None: + df = pl.DataFrame({"col": [1, 2, 3]}) + schema = dy.infer_schema(df, "TestSchema", return_type="schema") + assert schema.is_valid(df) + + +def test_invalid_return_type_raises_error() -> None: + df = pl.DataFrame({"col": [1]}) + with pytest.raises(ValueError, match="Invalid return_type"): + dy.infer_schema(df, "Test", return_type="invalid") # type: ignore[call-overload] + + +def test_invalid_schema_name_raises_error() -> None: + df = pl.DataFrame({"col": [1]}) + with pytest.raises( + ValueError, match="schema_name must be a valid Python identifier" + ): + dy.infer_schema(df, "Invalid Name") + + +def test_default_schema_name() -> None: + df = pl.DataFrame({"col": [1]}) + result = dy.infer_schema(df, return_type="string") + assert "class Schema(dy.Schema):" in result + + +def test_binary_type() -> None: + df = pl.DataFrame({"data": pl.Series([b"hello"], dtype=pl.Binary)}) + result = dy.infer_schema(df, return_type="string", schema_name="BinarySchema") + assert "dy.Binary()" in result + + +def test_null_type() -> None: + df = pl.DataFrame({"null_col": pl.Series([None, None], dtype=pl.Null)}) + result = dy.infer_schema(df, return_type="string", schema_name="NullSchema") + assert "dy.Any()" in result + + +def test_object_type() -> None: + df = pl.DataFrame({"obj": pl.Series([object()], dtype=pl.Object)}) + result = dy.infer_schema(df, return_type="string", schema_name="ObjectSchema") + assert "dy.Object()" in result + + +def test_categorical_type() -> None: + df = pl.DataFrame({"cat": pl.Series(["a", 
"b"]).cast(pl.Categorical())}) + result = dy.infer_schema(df, return_type="string", schema_name="CatSchema") + assert "dy.Categorical()" in result + + +def test_duration_type() -> None: + df = pl.DataFrame( + {"dur": pl.Series([datetime.timedelta(days=1)], dtype=pl.Duration)} + ) + result = dy.infer_schema(df, return_type="string", schema_name="DurSchema") + assert "dy.Duration()" in result + + +def test_duration_with_time_unit_ms() -> None: + df = pl.DataFrame( + {"dur": pl.Series([datetime.timedelta(days=1)]).cast(pl.Duration("ms"))} + ) + result = dy.infer_schema(df, return_type="string", schema_name="DurSchema") + assert 'time_unit="ms"' in result + + +def test_duration_with_time_unit_ns() -> None: + df = pl.DataFrame( + {"dur": pl.Series([datetime.timedelta(days=1)]).cast(pl.Duration("ns"))} + ) + result = dy.infer_schema(df, return_type="string", schema_name="DurSchema") + assert 'time_unit="ns"' in result + + +def test_datetime_with_time_unit_ms() -> None: + df = pl.DataFrame( + {"dt": pl.Series([datetime.datetime(2024, 1, 1)]).cast(pl.Datetime("ms"))} + ) + result = dy.infer_schema(df, return_type="string", schema_name="DtSchema") + assert 'time_unit="ms"' in result + + +def test_datetime_with_time_unit_ns() -> None: + df = pl.DataFrame( + {"dt": pl.Series([datetime.datetime(2024, 1, 1)]).cast(pl.Datetime("ns"))} + ) + result = dy.infer_schema(df, return_type="string", schema_name="DtSchema") + assert 'time_unit="ns"' in result + + +def test_decimal_without_scale() -> None: + df = pl.DataFrame( + {"amount": pl.Series(["10"]).cast(pl.Decimal(precision=5, scale=0))} + ) + result = dy.infer_schema(df, return_type="string", schema_name="DecSchema") + assert "precision=5" in result + assert "scale=" not in result + + +def test_column_with_special_chars_replaced() -> None: + df = pl.DataFrame({"!!!": ["test"]}) + result = dy.infer_schema(df, return_type="string", schema_name="SpecialSchema") + assert '___ = dy.String(alias="!!!")' in result + + +def 
test_column_empty_after_sanitization() -> None: + # Empty string column name results in _column fallback + df = pl.DataFrame({"": ["test"]}) + result = dy.infer_schema(df, return_type="string", schema_name="EmptySchema") + # Empty string alias is not included (falsy), but _column is generated + assert "_column = dy.String()" in result + + +def test_column_with_spaces() -> None: + df = pl.DataFrame({"col name": ["test"]}) + result = dy.infer_schema(df, return_type="string", schema_name="SpaceSchema") + assert 'col_name = dy.String(alias="col name")' in result + + +@pytest.mark.parametrize( + "df", + [ + # Basic types + pl.DataFrame( { "int_col": [1, 2, 3], "float_col": [1.0, 2.0, 3.0], "str_col": ["a", "b", "c"], "bool_col": [True, False, True], } - ) - result = dy.infer_schema(df, return_type="string", schema_name="BasicSchema") - expected = textwrap.dedent("""\ - class BasicSchema(dy.Schema): - int_col = dy.Int64() - float_col = dy.Float64() - str_col = dy.String() - bool_col = dy.Bool()""") - assert result == expected - - def test_nullable_detection(self) -> None: - df = pl.DataFrame( - { - "nullable_int": [1, None, 3], - "non_nullable_int": [1, 2, 3], - } - ) - result = dy.infer_schema(df, return_type="string", schema_name="NullableSchema") - expected = textwrap.dedent("""\ - class NullableSchema(dy.Schema): - nullable_int = dy.Int64(nullable=True) - non_nullable_int = dy.Int64()""") - assert result == expected - - def test_datetime_types(self) -> None: - df = pl.DataFrame( + ), + # Nullable + pl.DataFrame({"nullable_int": [1, None, 3], "non_nullable_int": [1, 2, 3]}), + # Datetime types + pl.DataFrame( { "date_col": [datetime.date(2024, 1, 1)], "time_col": [datetime.time(12, 0, 0)], "datetime_col": [datetime.datetime(2024, 1, 1, 12, 0, 0)], } - ) - result = dy.infer_schema(df, return_type="string", schema_name="DatetimeSchema") - expected = textwrap.dedent("""\ - class DatetimeSchema(dy.Schema): - date_col = dy.Date() - time_col = dy.Time() - datetime_col = 
dy.Datetime()""") - assert result == expected - - def test_datetime_with_timezone(self) -> None: - df = pl.DataFrame( - { - "utc_time": pl.Series( - [datetime.datetime(2024, 1, 1)] - ).dt.replace_time_zone("UTC"), - } - ) - result = dy.infer_schema(df, return_type="string", schema_name="TzSchema") - expected = textwrap.dedent("""\ - class TzSchema(dy.Schema): - utc_time = dy.Datetime(time_zone="UTC")""") - assert result == expected - - def test_enum_type(self) -> None: - df = pl.DataFrame( + ), + # Enum + pl.DataFrame( { "status": pl.Series(["active", "pending"]).cast( pl.Enum(["active", "pending", "inactive"]) - ), - } - ) - result = dy.infer_schema(df, return_type="string", schema_name="EnumSchema") - expected = textwrap.dedent("""\ - class EnumSchema(dy.Schema): - status = dy.Enum(['active', 'pending', 'inactive'])""") - assert result == expected - - def test_decimal_type(self) -> None: - df = pl.DataFrame( - { - "amount": pl.Series(["10.50"]).cast(pl.Decimal(precision=10, scale=2)), - } - ) - result = dy.infer_schema(df, return_type="string", schema_name="DecimalSchema") - expected = textwrap.dedent("""\ - class DecimalSchema(dy.Schema): - amount = dy.Decimal(precision=10, scale=2)""") - assert result == expected - - def test_list_type(self) -> None: - df = pl.DataFrame( - { - "tags": [["a", "b"], ["c"]], - } - ) - result = dy.infer_schema(df, return_type="string", schema_name="ListSchema") - expected = textwrap.dedent("""\ - class ListSchema(dy.Schema): - tags = dy.List(dy.String())""") - assert result == expected - - def test_struct_type(self) -> None: - df = pl.DataFrame( - { - "metadata": [{"key": "value"}, {"key": "other"}], + ) } - ) - result = dy.infer_schema(df, return_type="string", schema_name="StructSchema") - expected = textwrap.dedent("""\ - class StructSchema(dy.Schema): - metadata = dy.Struct({"key": dy.String()})""") - assert result == expected - - def test_list_with_nullable_inner(self) -> None: - df = pl.DataFrame({"names": [["Alice"], 
[None]]}) - result = dy.infer_schema( - df, return_type="string", schema_name="ListNullableInnerSchema" - ) - expected = textwrap.dedent("""\ - class ListNullableInnerSchema(dy.Schema): - names = dy.List(dy.String(nullable=True))""") - assert result == expected - - def test_struct_with_nullable_field(self) -> None: - df = pl.DataFrame({"data": [{"key": "value"}, {"key": None}]}) - result = dy.infer_schema( - df, return_type="string", schema_name="StructNullableFieldSchema" - ) - expected = textwrap.dedent("""\ - class StructNullableFieldSchema(dy.Schema): - data = dy.Struct({"key": dy.String(nullable=True)})""") - assert result == expected - - def test_array_type(self) -> None: - df = pl.DataFrame({"vector": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]}).cast( + ), + # List + pl.DataFrame({"tags": [["a", "b"], ["c"]]}), + # Struct + pl.DataFrame({"metadata": [{"key": "value"}]}), + # Array + pl.DataFrame({"vector": [[1.0, 2.0, 3.0]]}).cast( {"vector": pl.Array(pl.Float64(), 3)} - ) - result = dy.infer_schema(df, return_type="string", schema_name="ArraySchema") - expected = textwrap.dedent("""\ - class ArraySchema(dy.Schema): - vector = dy.Array(dy.Float64(), shape=3)""") - assert result == expected - - def test_invalid_identifier(self) -> None: - df = pl.DataFrame( - { - "123invalid": ["test"], - } - ) - result = dy.infer_schema( - df, return_type="string", schema_name="InvalidIdSchema" - ) - expected = textwrap.dedent("""\ - class InvalidIdSchema(dy.Schema): - _123invalid = dy.String(alias="123invalid")""") - assert result == expected - - def test_python_keyword(self) -> None: - df = pl.DataFrame( - { - "class": ["test"], - } - ) - result = dy.infer_schema(df, return_type="string", schema_name="KeywordSchema") - expected = textwrap.dedent("""\ - class KeywordSchema(dy.Schema): - class_ = dy.String(alias="class")""") - assert result == expected - - def test_all_integer_types(self) -> None: - df = pl.DataFrame( - { - "i8": pl.Series([1], dtype=pl.Int8), - "i16": 
pl.Series([1], dtype=pl.Int16), - "i32": pl.Series([1], dtype=pl.Int32), - "i64": pl.Series([1], dtype=pl.Int64), - "u8": pl.Series([1], dtype=pl.UInt8), - "u16": pl.Series([1], dtype=pl.UInt16), - "u32": pl.Series([1], dtype=pl.UInt32), - "u64": pl.Series([1], dtype=pl.UInt64), - } - ) - result = dy.infer_schema(df, return_type="string", schema_name="IntSchema") - assert "dy.Int8()" in result - assert "dy.Int16()" in result - assert "dy.Int32()" in result - assert "dy.Int64()" in result - assert "dy.UInt8()" in result - assert "dy.UInt16()" in result - assert "dy.UInt32()" in result - assert "dy.UInt64()" in result - - def test_float_types(self) -> None: - df = pl.DataFrame( - { - "f32": pl.Series([1.0], dtype=pl.Float32), - "f64": pl.Series([1.0], dtype=pl.Float64), - } - ) - result = dy.infer_schema(df, return_type="string", schema_name="FloatSchema") - assert "dy.Float32()" in result - assert "dy.Float64()" in result - - -class TestInferSchemaReturnTypes: - """Test the different return_type options.""" - - def test_return_type_none_prints_to_stdout( - self, capsys: pytest.CaptureFixture[str] - ) -> None: - df = pl.DataFrame({"col": [1, 2, 3]}) - result = dy.infer_schema(df, "TestSchema") - assert result is None - captured = capsys.readouterr() - assert "class TestSchema(dy.Schema):" in captured.out - assert "col = dy.Int64()" in captured.out - - def test_return_type_string(self) -> None: - df = pl.DataFrame({"col": [1, 2, 3]}) - result = dy.infer_schema(df, "TestSchema", return_type="string") - assert isinstance(result, str) - assert "class TestSchema(dy.Schema):" in result - - def test_return_type_schema(self) -> None: - df = pl.DataFrame({"col": [1, 2, 3]}) - schema = dy.infer_schema(df, "TestSchema", return_type="schema") - assert schema.is_valid(df) - - def test_invalid_return_type_raises_error(self) -> None: - df = pl.DataFrame({"col": [1]}) - with pytest.raises(ValueError, match="Invalid return_type"): - dy.infer_schema(df, "Test", return_type="invalid") 
# type: ignore[call-overload] - - def test_invalid_schema_name_raises_error(self) -> None: - df = pl.DataFrame({"col": [1]}) - with pytest.raises( - ValueError, match="schema_name must be a valid Python identifier" - ): - dy.infer_schema(df, "Invalid Name") - - def test_default_schema_name(self) -> None: - df = pl.DataFrame({"col": [1]}) - result = dy.infer_schema(df, return_type="string") - assert "class Schema(dy.Schema):" in result - - -class TestSpecialTypes: - """Test special column types.""" - - def test_binary_type(self) -> None: - df = pl.DataFrame({"data": pl.Series([b"hello"], dtype=pl.Binary)}) - result = dy.infer_schema(df, return_type="string", schema_name="BinarySchema") - assert "dy.Binary()" in result - - def test_null_type(self) -> None: - df = pl.DataFrame({"null_col": pl.Series([None, None], dtype=pl.Null)}) - result = dy.infer_schema(df, return_type="string", schema_name="NullSchema") - assert "dy.Any()" in result - - def test_object_type(self) -> None: - df = pl.DataFrame({"obj": pl.Series([object()], dtype=pl.Object)}) - result = dy.infer_schema(df, return_type="string", schema_name="ObjectSchema") - assert "dy.Object()" in result - - def test_categorical_type(self) -> None: - df = pl.DataFrame({"cat": pl.Series(["a", "b"]).cast(pl.Categorical())}) - result = dy.infer_schema(df, return_type="string", schema_name="CatSchema") - assert "dy.Categorical()" in result - - def test_duration_type(self) -> None: - df = pl.DataFrame( - {"dur": pl.Series([datetime.timedelta(days=1)], dtype=pl.Duration)} - ) - result = dy.infer_schema(df, return_type="string", schema_name="DurSchema") - assert "dy.Duration()" in result - - def test_datetime_with_time_unit_ms(self) -> None: - df = pl.DataFrame( - {"dt": pl.Series([datetime.datetime(2024, 1, 1)]).cast(pl.Datetime("ms"))} - ) - result = dy.infer_schema(df, return_type="string", schema_name="DtSchema") - assert 'time_unit="ms"' in result - - def test_datetime_with_time_unit_ns(self) -> None: - df = 
pl.DataFrame( - {"dt": pl.Series([datetime.datetime(2024, 1, 1)]).cast(pl.Datetime("ns"))} - ) - result = dy.infer_schema(df, return_type="string", schema_name="DtSchema") - assert 'time_unit="ns"' in result - - def test_decimal_without_scale(self) -> None: - df = pl.DataFrame( - {"amount": pl.Series(["10"]).cast(pl.Decimal(precision=5, scale=0))} - ) - result = dy.infer_schema(df, return_type="string", schema_name="DecSchema") - assert "precision=5" in result - assert "scale=" not in result - - -class TestMakeValidIdentifier: - """Test edge cases of _make_valid_identifier.""" - - def test_column_with_special_chars_replaced(self) -> None: - df = pl.DataFrame({"!!!": ["test"]}) - result = dy.infer_schema(df, return_type="string", schema_name="SpecialSchema") - assert '___ = dy.String(alias="!!!")' in result - - def test_column_empty_after_sanitization(self) -> None: - # Empty string column name results in _column fallback - df = pl.DataFrame({"": ["test"]}) - result = dy.infer_schema(df, return_type="string", schema_name="EmptySchema") - # Empty string alias is not included (falsy), but _column is generated - assert "_column = dy.String()" in result - - def test_column_with_spaces(self) -> None: - df = pl.DataFrame({"col name": ["test"]}) - result = dy.infer_schema(df, return_type="string", schema_name="SpaceSchema") - assert 'col_name = dy.String(alias="col name")' in result - - -class TestInferSchemaReturnsSchema: - @pytest.mark.parametrize( - "df", - [ - # Basic types - pl.DataFrame( - { - "int_col": [1, 2, 3], - "float_col": [1.0, 2.0, 3.0], - "str_col": ["a", "b", "c"], - "bool_col": [True, False, True], - } - ), - # Nullable - pl.DataFrame({"nullable_int": [1, None, 3], "non_nullable_int": [1, 2, 3]}), - # Datetime types - pl.DataFrame( - { - "date_col": [datetime.date(2024, 1, 1)], - "time_col": [datetime.time(12, 0, 0)], - "datetime_col": [datetime.datetime(2024, 1, 1, 12, 0, 0)], - } - ), - # Enum - pl.DataFrame( - { - "status": pl.Series(["active", 
"pending"]).cast( - pl.Enum(["active", "pending", "inactive"]) - ) - } - ), - # List - pl.DataFrame({"tags": [["a", "b"], ["c"]]}), - # Struct - pl.DataFrame({"metadata": [{"key": "value"}]}), - # Array - pl.DataFrame({"vector": [[1.0, 2.0, 3.0]]}).cast( - {"vector": pl.Array(pl.Float64(), 3)} - ), - # Invalid identifiers and keywords - pl.DataFrame({"123invalid": ["test"], "class": ["test"]}), - # Decimal - pl.DataFrame( - {"amount": pl.Series(["10.50"]).cast(pl.Decimal(precision=10, scale=2))} - ), - # Nested types - pl.DataFrame({"nested_list": [[["a", "b"]]]}), - pl.DataFrame({"nested_struct": [{"outer": {"inner": "value"}}]}), - # Nullable inner types - pl.DataFrame({"list_with_nulls": [["a"], [None]]}), - pl.DataFrame({"struct_with_nulls": [{"key": "value"}, {"key": None}]}), - ], - ) - def test_inferred_schema_validates_dataframe(self, df: pl.DataFrame) -> None: - schema = dy.infer_schema(df, "TestSchema", return_type="schema") - assert schema.is_valid(df) + ), + # Invalid identifiers and keywords + pl.DataFrame({"123invalid": ["test"], "class": ["test"]}), + # Decimal + pl.DataFrame( + {"amount": pl.Series(["10.50"]).cast(pl.Decimal(precision=10, scale=2))} + ), + # Nested types + pl.DataFrame({"nested_list": [[["a", "b"]]]}), + pl.DataFrame({"nested_struct": [{"outer": {"inner": "value"}}]}), + # Nullable inner types + pl.DataFrame({"list_with_nulls": [["a"], [None]]}), + pl.DataFrame({"struct_with_nulls": [{"key": "value"}, {"key": None}]}), + ], +) +def test_inferred_schema_validates_dataframe(df: pl.DataFrame) -> None: + schema = dy.infer_schema(df, "TestSchema", return_type="schema") + assert schema.is_valid(df) From d6ee33c0ba3c9526fb1018b893842babf3540793 Mon Sep 17 00:00:00 2001 From: gabriel Date: Tue, 17 Mar 2026 15:11:45 +0100 Subject: [PATCH 06/13] fix duplicated names --- dataframely/__init__.py | 2 +- .../{_generate_schema.py => _infer_schema.py} | 124 +++++------------- tests/test_infer_schema.py | 114 ++++++---------- 3 files changed, 74 
insertions(+), 166 deletions(-) rename dataframely/{_generate_schema.py => _infer_schema.py} (62%) diff --git a/dataframely/__init__.py b/dataframely/__init__.py index ca2b586..c877dcc 100644 --- a/dataframely/__init__.py +++ b/dataframely/__init__.py @@ -12,7 +12,7 @@ from . import random from ._filter import filter -from ._generate_schema import infer_schema +from ._infer_schema import infer_schema from ._rule import rule from ._typing import DataFrame, LazyFrame, Validation from .collection import ( diff --git a/dataframely/_generate_schema.py b/dataframely/_infer_schema.py similarity index 62% rename from dataframely/_generate_schema.py rename to dataframely/_infer_schema.py index 0f6c0a9..4c19d02 100644 --- a/dataframely/_generate_schema.py +++ b/dataframely/_infer_schema.py @@ -6,14 +6,9 @@ import keyword import re -from typing import TYPE_CHECKING, Literal, overload import polars as pl -if TYPE_CHECKING: - from dataframely.schema import Schema - - _POLARS_DTYPE_MAP: dict[type[pl.DataType], str] = { pl.Boolean: "Bool", pl.Int8: "Int8", @@ -42,60 +37,21 @@ } -@overload -def infer_schema( - df: pl.DataFrame, - schema_name: str = ..., - *, - return_type: None = ..., -) -> None: ... - - -@overload -def infer_schema( - df: pl.DataFrame, - schema_name: str = ..., - *, - return_type: Literal["string"], -) -> str: ... - - -@overload -def infer_schema( - df: pl.DataFrame, - schema_name: str = ..., - *, - return_type: Literal["schema"], -) -> type[Schema]: ... - - def infer_schema( df: pl.DataFrame, schema_name: str = "Schema", - *, - return_type: Literal["string", "schema"] | None = None, -) -> str | type[Schema] | None: +) -> str: """Infer a dataframely schema from a Polars DataFrame. - This function inspects a DataFrame's schema and generates a corresponding - dataframely Schema. It can print the schema code, return it as a string, - or return an actual Schema class. 
+ This function inspects a DataFrame's schema and generates corresponding + dataframely Schema code as a string. Args: df: The Polars DataFrame to infer the schema from. schema_name: The name for the generated schema class. - return_type: Controls the return format: - - - ``None`` (default): Print the schema code to stdout, return ``None``. - - ``"string"``: Return the schema code as a string. - - ``"schema"``: Return an actual Schema class. Returns: - Depends on ``return_type``: - - - ``None``: Returns ``None`` (prints to stdout). - - ``"string"``: Returns the schema code as a string. - - ``"schema"``: Returns a Schema class that can be used directly. + The schema code as a string. Example: >>> import polars as pl @@ -105,14 +61,11 @@ def infer_schema( ... "age": [25, 30], ... "score": [95.5, None], ... }) - >>> dy.infer_schema(df, "PersonSchema") + >>> print(dy.infer_schema(df, "PersonSchema")) class PersonSchema(dy.Schema): name = dy.String() age = dy.Int64() score = dy.Float64(nullable=True) - >>> schema = dy.infer_schema(df, "PersonSchema", return_type="schema") - >>> schema.is_valid(df) - True Raises: ValueError: If ``schema_name`` is not a valid Python identifier. 
@@ -121,60 +74,47 @@ class PersonSchema(dy.Schema): msg = f"schema_name must be a valid Python identifier, got {schema_name!r}" raise ValueError(msg) - code = _generate_schema_code(df, schema_name) - - if return_type is None: - print(code) # noqa: T201 - return None - if return_type == "string": - return code - if return_type == "schema": - import dataframely as dy - - namespace: dict = {"dy": dy} - exec(code, namespace) # noqa: S102 - return namespace[schema_name] - - msg = f"Invalid return_type: {return_type!r}" - raise ValueError(msg) + return _generate_schema_code(df, schema_name) def _generate_schema_code(df: pl.DataFrame, schema_name: str) -> str: """Generate schema code string from a DataFrame.""" lines = [f"class {schema_name}(dy.Schema):"] - - for col_name, series in df.to_dict().items(): - if _is_valid_identifier(col_name): - attr_name = col_name - alias = None - else: - attr_name = _make_valid_identifier(col_name) - alias = col_name + used_identifiers: set[str] = set() + + for idx, (col_name, series) in enumerate(df.to_dict().items()): + attr_name = _make_valid_identifier(col_name) + # Make sure yes have no duplicates + if attr_name in used_identifiers: + # Remove trailing "_" if exists as it will be included in the suffix anyway + if attr_name.endswith("_"): + attr_name = attr_name[:-1] + idx = 1 + while f"{attr_name}_{idx}" in used_identifiers: + idx += 1 + attr_name = f"{attr_name}_{idx}" + used_identifiers.add(attr_name) + alias = col_name if attr_name != col_name else None col_code = _dtype_to_column_code(series, alias=alias) lines.append(f" {attr_name} = {col_code}") return "\n".join(lines) -def _is_valid_identifier(name: str) -> bool: - """Check if a string is a valid Python identifier and not a keyword.""" - return name.isidentifier() and not keyword.iskeyword(name) - - def _make_valid_identifier(name: str) -> str: """Convert a string to a valid Python identifier.""" # Replace invalid characters with underscores - result = 
re.sub(r"[^a-zA-Z0-9_]", "_", name) + valid_identifier = re.sub(r"[^a-zA-Z0-9_]", "_", name) + + # Handle empty name or name with only special characters ones with simple "_" + if set(valid_identifier).issubset({"_"}): + return "_" # Ensure it doesn't start with a digit - if result and result[0].isdigit(): - result = "_" + result - # Ensure it's not empty - if not result: - result = "_column" - # Handle keywords - if keyword.iskeyword(result): - result = result + "_" - return result + if valid_identifier[0].isdigit(): + return "_" + valid_identifier + if keyword.iskeyword(valid_identifier): + return valid_identifier + "_" + return valid_identifier def _get_dtype_args(dtype: pl.DataType, series: pl.Series) -> list[str]: @@ -223,7 +163,7 @@ def _format_args(*args: str, nullable: bool = False, alias: str | None = None) - all_args = list(args) if nullable: all_args.append("nullable=True") - if alias: + if alias is not None: all_args.append(f'alias="{alias}"') return ", ".join(all_args) diff --git a/tests/test_infer_schema.py b/tests/test_infer_schema.py index 0df76a4..4f74c3f 100644 --- a/tests/test_infer_schema.py +++ b/tests/test_infer_schema.py @@ -19,7 +19,7 @@ def test_basic_types() -> None: "bool_col": [True, False, True], } ) - result = dy.infer_schema(df, return_type="string", schema_name="BasicSchema") + result = dy.infer_schema(df, "BasicSchema") expected = textwrap.dedent("""\ class BasicSchema(dy.Schema): int_col = dy.Int64() @@ -36,7 +36,7 @@ def test_nullable_detection() -> None: "non_nullable_int": [1, 2, 3], } ) - result = dy.infer_schema(df, return_type="string", schema_name="NullableSchema") + result = dy.infer_schema(df, "NullableSchema") expected = textwrap.dedent("""\ class NullableSchema(dy.Schema): nullable_int = dy.Int64(nullable=True) @@ -52,7 +52,7 @@ def test_datetime_types() -> None: "datetime_col": [datetime.datetime(2024, 1, 1, 12, 0, 0)], } ) - result = dy.infer_schema(df, return_type="string", schema_name="DatetimeSchema") + result = 
dy.infer_schema(df, schema_name="DatetimeSchema") expected = textwrap.dedent("""\ class DatetimeSchema(dy.Schema): date_col = dy.Date() @@ -69,7 +69,7 @@ def test_datetime_with_timezone() -> None: ), } ) - result = dy.infer_schema(df, return_type="string", schema_name="TzSchema") + result = dy.infer_schema(df, schema_name="TzSchema") expected = textwrap.dedent("""\ class TzSchema(dy.Schema): utc_time = dy.Datetime(time_zone="UTC")""") @@ -84,7 +84,7 @@ def test_enum_type() -> None: ), } ) - result = dy.infer_schema(df, return_type="string", schema_name="EnumSchema") + result = dy.infer_schema(df, schema_name="EnumSchema") expected = textwrap.dedent("""\ class EnumSchema(dy.Schema): status = dy.Enum(['active', 'pending', 'inactive'])""") @@ -97,7 +97,7 @@ def test_decimal_type() -> None: "amount": pl.Series(["10.50"]).cast(pl.Decimal(precision=10, scale=2)), } ) - result = dy.infer_schema(df, return_type="string", schema_name="DecimalSchema") + result = dy.infer_schema(df, schema_name="DecimalSchema") expected = textwrap.dedent("""\ class DecimalSchema(dy.Schema): amount = dy.Decimal(precision=10, scale=2)""") @@ -110,7 +110,7 @@ def test_list_type() -> None: "tags": [["a", "b"], ["c"]], } ) - result = dy.infer_schema(df, return_type="string", schema_name="ListSchema") + result = dy.infer_schema(df, schema_name="ListSchema") expected = textwrap.dedent("""\ class ListSchema(dy.Schema): tags = dy.List(dy.String())""") @@ -123,7 +123,7 @@ def test_struct_type() -> None: "metadata": [{"key": "value"}, {"key": "other"}], } ) - result = dy.infer_schema(df, return_type="string", schema_name="StructSchema") + result = dy.infer_schema(df, schema_name="StructSchema") expected = textwrap.dedent("""\ class StructSchema(dy.Schema): metadata = dy.Struct({"key": dy.String()})""") @@ -132,9 +132,7 @@ class StructSchema(dy.Schema): def test_list_with_nullable_inner() -> None: df = pl.DataFrame({"names": [["Alice"], [None]]}) - result = dy.infer_schema( - df, return_type="string", 
schema_name="ListNullableInnerSchema" - ) + result = dy.infer_schema(df, schema_name="ListNullableInnerSchema") expected = textwrap.dedent("""\ class ListNullableInnerSchema(dy.Schema): names = dy.List(dy.String(nullable=True))""") @@ -143,9 +141,7 @@ class ListNullableInnerSchema(dy.Schema): def test_struct_with_nullable_field() -> None: df = pl.DataFrame({"data": [{"key": "value"}, {"key": None}]}) - result = dy.infer_schema( - df, return_type="string", schema_name="StructNullableFieldSchema" - ) + result = dy.infer_schema(df, schema_name="StructNullableFieldSchema") expected = textwrap.dedent("""\ class StructNullableFieldSchema(dy.Schema): data = dy.Struct({"key": dy.String(nullable=True)})""") @@ -156,7 +152,7 @@ def test_array_type() -> None: df = pl.DataFrame({"vector": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]}).cast( {"vector": pl.Array(pl.Float64(), 3)} ) - result = dy.infer_schema(df, return_type="string", schema_name="ArraySchema") + result = dy.infer_schema(df, schema_name="ArraySchema") expected = textwrap.dedent("""\ class ArraySchema(dy.Schema): vector = dy.Array(dy.Float64(), shape=3)""") @@ -169,7 +165,7 @@ def test_invalid_identifier() -> None: "123invalid": ["test"], } ) - result = dy.infer_schema(df, return_type="string", schema_name="InvalidIdSchema") + result = dy.infer_schema(df, schema_name="InvalidIdSchema") expected = textwrap.dedent("""\ class InvalidIdSchema(dy.Schema): _123invalid = dy.String(alias="123invalid")""") @@ -182,7 +178,7 @@ def test_python_keyword() -> None: "class": ["test"], } ) - result = dy.infer_schema(df, return_type="string", schema_name="KeywordSchema") + result = dy.infer_schema(df, schema_name="KeywordSchema") expected = textwrap.dedent("""\ class KeywordSchema(dy.Schema): class_ = dy.String(alias="class")""") @@ -202,7 +198,7 @@ def test_all_integer_types() -> None: "u64": pl.Series([1], dtype=pl.UInt64), } ) - result = dy.infer_schema(df, return_type="string", schema_name="IntSchema") + result = dy.infer_schema(df, 
schema_name="IntSchema") assert "dy.Int8()" in result assert "dy.Int16()" in result assert "dy.Int32()" in result @@ -220,39 +216,11 @@ def test_float_types() -> None: "f64": pl.Series([1.0], dtype=pl.Float64), } ) - result = dy.infer_schema(df, return_type="string", schema_name="FloatSchema") + result = dy.infer_schema(df, schema_name="FloatSchema") assert "dy.Float32()" in result assert "dy.Float64()" in result -def test_return_type_none_prints_to_stdout(capsys: pytest.CaptureFixture[str]) -> None: - df = pl.DataFrame({"col": [1, 2, 3]}) - result = dy.infer_schema(df, "TestSchema") - assert result is None - captured = capsys.readouterr() - assert "class TestSchema(dy.Schema):" in captured.out - assert "col = dy.Int64()" in captured.out - - -def test_return_type_string() -> None: - df = pl.DataFrame({"col": [1, 2, 3]}) - result = dy.infer_schema(df, "TestSchema", return_type="string") - assert isinstance(result, str) - assert "class TestSchema(dy.Schema):" in result - - -def test_return_type_schema() -> None: - df = pl.DataFrame({"col": [1, 2, 3]}) - schema = dy.infer_schema(df, "TestSchema", return_type="schema") - assert schema.is_valid(df) - - -def test_invalid_return_type_raises_error() -> None: - df = pl.DataFrame({"col": [1]}) - with pytest.raises(ValueError, match="Invalid return_type"): - dy.infer_schema(df, "Test", return_type="invalid") # type: ignore[call-overload] - - def test_invalid_schema_name_raises_error() -> None: df = pl.DataFrame({"col": [1]}) with pytest.raises( @@ -263,31 +231,31 @@ def test_invalid_schema_name_raises_error() -> None: def test_default_schema_name() -> None: df = pl.DataFrame({"col": [1]}) - result = dy.infer_schema(df, return_type="string") + result = dy.infer_schema(df) assert "class Schema(dy.Schema):" in result def test_binary_type() -> None: df = pl.DataFrame({"data": pl.Series([b"hello"], dtype=pl.Binary)}) - result = dy.infer_schema(df, return_type="string", schema_name="BinarySchema") + result = dy.infer_schema(df, 
schema_name="BinarySchema") assert "dy.Binary()" in result def test_null_type() -> None: df = pl.DataFrame({"null_col": pl.Series([None, None], dtype=pl.Null)}) - result = dy.infer_schema(df, return_type="string", schema_name="NullSchema") + result = dy.infer_schema(df, schema_name="NullSchema") assert "dy.Any()" in result def test_object_type() -> None: df = pl.DataFrame({"obj": pl.Series([object()], dtype=pl.Object)}) - result = dy.infer_schema(df, return_type="string", schema_name="ObjectSchema") + result = dy.infer_schema(df, schema_name="ObjectSchema") assert "dy.Object()" in result def test_categorical_type() -> None: df = pl.DataFrame({"cat": pl.Series(["a", "b"]).cast(pl.Categorical())}) - result = dy.infer_schema(df, return_type="string", schema_name="CatSchema") + result = dy.infer_schema(df, schema_name="CatSchema") assert "dy.Categorical()" in result @@ -295,7 +263,7 @@ def test_duration_type() -> None: df = pl.DataFrame( {"dur": pl.Series([datetime.timedelta(days=1)], dtype=pl.Duration)} ) - result = dy.infer_schema(df, return_type="string", schema_name="DurSchema") + result = dy.infer_schema(df, schema_name="DurSchema") assert "dy.Duration()" in result @@ -303,7 +271,7 @@ def test_duration_with_time_unit_ms() -> None: df = pl.DataFrame( {"dur": pl.Series([datetime.timedelta(days=1)]).cast(pl.Duration("ms"))} ) - result = dy.infer_schema(df, return_type="string", schema_name="DurSchema") + result = dy.infer_schema(df, schema_name="DurSchema") assert 'time_unit="ms"' in result @@ -311,7 +279,7 @@ def test_duration_with_time_unit_ns() -> None: df = pl.DataFrame( {"dur": pl.Series([datetime.timedelta(days=1)]).cast(pl.Duration("ns"))} ) - result = dy.infer_schema(df, return_type="string", schema_name="DurSchema") + result = dy.infer_schema(df, schema_name="DurSchema") assert 'time_unit="ns"' in result @@ -319,7 +287,7 @@ def test_datetime_with_time_unit_ms() -> None: df = pl.DataFrame( {"dt": pl.Series([datetime.datetime(2024, 1, 
1)]).cast(pl.Datetime("ms"))} ) - result = dy.infer_schema(df, return_type="string", schema_name="DtSchema") + result = dy.infer_schema(df, schema_name="DtSchema") assert 'time_unit="ms"' in result @@ -327,7 +295,7 @@ def test_datetime_with_time_unit_ns() -> None: df = pl.DataFrame( {"dt": pl.Series([datetime.datetime(2024, 1, 1)]).cast(pl.Datetime("ns"))} ) - result = dy.infer_schema(df, return_type="string", schema_name="DtSchema") + result = dy.infer_schema(df, schema_name="DtSchema") assert 'time_unit="ns"' in result @@ -335,29 +303,26 @@ def test_decimal_without_scale() -> None: df = pl.DataFrame( {"amount": pl.Series(["10"]).cast(pl.Decimal(precision=5, scale=0))} ) - result = dy.infer_schema(df, return_type="string", schema_name="DecSchema") + result = dy.infer_schema(df, schema_name="DecSchema") assert "precision=5" in result assert "scale=" not in result -def test_column_with_special_chars_replaced() -> None: - df = pl.DataFrame({"!!!": ["test"]}) - result = dy.infer_schema(df, return_type="string", schema_name="SpecialSchema") - assert '___ = dy.String(alias="!!!")' in result - - -def test_column_empty_after_sanitization() -> None: - # Empty string column name results in _column fallback - df = pl.DataFrame({"": ["test"]}) - result = dy.infer_schema(df, return_type="string", schema_name="EmptySchema") - # Empty string alias is not included (falsy), but _column is generated - assert "_column = dy.String()" in result +def test_column_sanitization() -> None: + df = pl.DataFrame({"$": ["test"], "valid": ["test"], "!!!": ["test"]}) + result = dy.infer_schema(df) + assert '_ = dy.String(alias="$")' in result + assert '_1 = dy.String(alias="!!!")' in result def test_column_with_spaces() -> None: - df = pl.DataFrame({"col name": ["test"]}) - result = dy.infer_schema(df, return_type="string", schema_name="SpaceSchema") + df = pl.DataFrame( + {"col name": ["test"], "col_name": ["test"], "col_name_1": ["test"]} + ) + result = dy.infer_schema(df, 
schema_name="SpaceSchema") assert 'col_name = dy.String(alias="col name")' in result + assert 'col_name_1 = dy.String(alias="col_name")' in result + assert 'col_name_1_1 = dy.String(alias="col_name_1")' in result @pytest.mark.parametrize( @@ -413,5 +378,8 @@ def test_column_with_spaces() -> None: ], ) def test_inferred_schema_validates_dataframe(df: pl.DataFrame) -> None: - schema = dy.infer_schema(df, "TestSchema", return_type="schema") + code = dy.infer_schema(df, "TestSchema") + namespace: dict = {"dy": dy} + exec(code, namespace) # noqa: S102 + schema = namespace["TestSchema"] assert schema.is_valid(df) From 142b036f221c81a8710881579087d22d56fd9e48 Mon Sep 17 00:00:00 2001 From: gabriel Date: Tue, 17 Mar 2026 15:29:45 +0100 Subject: [PATCH 07/13] code cov --- tests/test_infer_schema.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_infer_schema.py b/tests/test_infer_schema.py index 4f74c3f..55ec126 100644 --- a/tests/test_infer_schema.py +++ b/tests/test_infer_schema.py @@ -309,10 +309,11 @@ def test_decimal_without_scale() -> None: def test_column_sanitization() -> None: - df = pl.DataFrame({"$": ["test"], "valid": ["test"], "!!!": ["test"]}) + df = pl.DataFrame({"$": ["test"], "valid": ["test"], "!!!": ["test"], "": ["test"]}) result = dy.infer_schema(df) assert '_ = dy.String(alias="$")' in result assert '_1 = dy.String(alias="!!!")' in result + assert '_2 = dy.String(alias="")' in result def test_column_with_spaces() -> None: From 8009321c0293eee78cd034982997cfe1188ad62a Mon Sep 17 00:00:00 2001 From: gabriel Date: Wed, 18 Mar 2026 10:57:32 +0100 Subject: [PATCH 08/13] replace _ by columns_{index} --- dataframely/_infer_schema.py | 8 ++++---- tests/test_infer_schema.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/dataframely/_infer_schema.py b/dataframely/_infer_schema.py index 4c19d02..11aa5fc 100644 --- a/dataframely/_infer_schema.py +++ b/dataframely/_infer_schema.py @@ -82,8 +82,8 @@ def 
_generate_schema_code(df: pl.DataFrame, schema_name: str) -> str: lines = [f"class {schema_name}(dy.Schema):"] used_identifiers: set[str] = set() - for idx, (col_name, series) in enumerate(df.to_dict().items()): - attr_name = _make_valid_identifier(col_name) + for col_index, (col_name, series) in enumerate(df.to_dict().items()): + attr_name = _make_valid_identifier(col_name, col_index) # Make sure yes have no duplicates if attr_name in used_identifiers: # Remove trailing "_" if exists as it will be included in the suffix anyway @@ -101,14 +101,14 @@ def _generate_schema_code(df: pl.DataFrame, schema_name: str) -> str: return "\n".join(lines) -def _make_valid_identifier(name: str) -> str: +def _make_valid_identifier(name: str, col_index: int) -> str: """Convert a string to a valid Python identifier.""" # Replace invalid characters with underscores valid_identifier = re.sub(r"[^a-zA-Z0-9_]", "_", name) # Handle empty name or name with only special characters ones with simple "_" if set(valid_identifier).issubset({"_"}): - return "_" + return f"column_{col_index}" # Ensure it doesn't start with a digit if valid_identifier[0].isdigit(): return "_" + valid_identifier diff --git a/tests/test_infer_schema.py b/tests/test_infer_schema.py index 55ec126..ad63d1a 100644 --- a/tests/test_infer_schema.py +++ b/tests/test_infer_schema.py @@ -311,9 +311,9 @@ def test_decimal_without_scale() -> None: def test_column_sanitization() -> None: df = pl.DataFrame({"$": ["test"], "valid": ["test"], "!!!": ["test"], "": ["test"]}) result = dy.infer_schema(df) - assert '_ = dy.String(alias="$")' in result - assert '_1 = dy.String(alias="!!!")' in result - assert '_2 = dy.String(alias="")' in result + assert 'column_0 = dy.String(alias="$")' in result + assert 'column_2 = dy.String(alias="!!!")' in result + assert 'column_3 = dy.String(alias="")' in result def test_column_with_spaces() -> None: From ef573f8e2351d9032cd0ba9fe0e08902a77f1739 Mon Sep 17 00:00:00 2001 From: gabriel Date: Wed, 
18 Mar 2026 12:05:41 +0100 Subject: [PATCH 09/13] remove comment in string --- dataframely/_infer_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataframely/_infer_schema.py b/dataframely/_infer_schema.py index 11aa5fc..a1de589 100644 --- a/dataframely/_infer_schema.py +++ b/dataframely/_infer_schema.py @@ -175,7 +175,7 @@ def _dtype_to_column_code(series: pl.Series, *, alias: str | None = None) -> str dy_name = _POLARS_DTYPE_MAP.get(type(dtype)) if dy_name is None: - return f"dy.Any({_format_args(alias=alias)}) # Unknown dtype: {dtype}" + return f"dy.Any({_format_args(alias=alias)})" # Unknown dtype: {dtype} args = _get_dtype_args(dtype, series) return f"dy.{dy_name}({_format_args(*args, nullable=nullable, alias=alias)})" From b1526a2ac69b0c1a2caa2e58fcd1bb47c3594bc4 Mon Sep 17 00:00:00 2001 From: gabriel Date: Wed, 18 Mar 2026 14:14:04 +0100 Subject: [PATCH 10/13] code cov --- dataframely/_infer_schema.py | 2 +- tests/test_infer_schema.py | 18 +++++++++++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/dataframely/_infer_schema.py b/dataframely/_infer_schema.py index a1de589..3bff5f9 100644 --- a/dataframely/_infer_schema.py +++ b/dataframely/_infer_schema.py @@ -175,7 +175,7 @@ def _dtype_to_column_code(series: pl.Series, *, alias: str | None = None) -> str dy_name = _POLARS_DTYPE_MAP.get(type(dtype)) if dy_name is None: - return f"dy.Any({_format_args(alias=alias)})" # Unknown dtype: {dtype} + return f"dy.Any({_format_args(alias=alias)})" args = _get_dtype_args(dtype, series) return f"dy.{dy_name}({_format_args(*args, nullable=nullable, alias=alias)})" diff --git a/tests/test_infer_schema.py b/tests/test_infer_schema.py index ad63d1a..050a139 100644 --- a/tests/test_infer_schema.py +++ b/tests/test_infer_schema.py @@ -318,14 +318,30 @@ def test_column_sanitization() -> None: def test_column_with_spaces() -> None: df = pl.DataFrame( - {"col name": ["test"], "col_name": ["test"], "col_name_1": ["test"]} + { + 
"col name": ["test"], + "col_name": ["test"], + "col@name": ["test"], + "col_name_1": ["test"], + } ) result = dy.infer_schema(df, schema_name="SpaceSchema") assert 'col_name = dy.String(alias="col name")' in result assert 'col_name_1 = dy.String(alias="col_name")' in result + assert 'col_name_2 = dy.String(alias="col@name")' in result assert 'col_name_1_1 = dy.String(alias="col_name_1")' in result +def test_column_sanitization_conflict_with_existing() -> None: + df = pl.DataFrame({"col_": ["test"], "col!": ["test"]}) + result = dy.infer_schema(df, schema_name="ConflictSchema") + expected = textwrap.dedent("""\ + class ConflictSchema(dy.Schema): + col_ = dy.String() + col_1 = dy.String(alias="col!")""") + assert result == expected + + @pytest.mark.parametrize( "df", [ From a4545269860e84c7ddb1a522146bc89110c5bc48 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Mon, 23 Mar 2026 12:25:53 +0100 Subject: [PATCH 11/13] move to experimental, add docs --- dataframely/__init__.py | 1 - dataframely/experimental/__init__.py | 6 ++ .../infer_schema.py} | 8 ++- docs/api/experimental/index.rst | 10 +++ docs/api/index.rst | 9 +++ docs/api/schema/index.rst | 1 - docs/api/schema/inference.rst | 9 --- docs/guides/faq.md | 9 ++- docs/guides/migration/index.md | 6 ++ docs/guides/quickstart.md | 4 +- tests/{ => experimental}/test_infer_schema.py | 63 ++++++++++--------- 11 files changed, 79 insertions(+), 47 deletions(-) create mode 100644 dataframely/experimental/__init__.py rename dataframely/{_infer_schema.py => experimental/infer_schema.py} (95%) create mode 100644 docs/api/experimental/index.rst delete mode 100644 docs/api/schema/inference.rst rename tests/{ => experimental}/test_infer_schema.py (85%) diff --git a/dataframely/__init__.py b/dataframely/__init__.py index c877dcc..88cbbbf 100644 --- a/dataframely/__init__.py +++ b/dataframely/__init__.py @@ -12,7 +12,6 @@ from . 
import random from ._filter import filter -from ._infer_schema import infer_schema from ._rule import rule from ._typing import DataFrame, LazyFrame, Validation from .collection import ( diff --git a/dataframely/experimental/__init__.py b/dataframely/experimental/__init__.py new file mode 100644 index 0000000..5575e38 --- /dev/null +++ b/dataframely/experimental/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) QuantCo 2025-2026 +# SPDX-License-Identifier: BSD-3-Clause + +from .infer_schema import infer_schema + +__all__ = ["infer_schema"] diff --git a/dataframely/_infer_schema.py b/dataframely/experimental/infer_schema.py similarity index 95% rename from dataframely/_infer_schema.py rename to dataframely/experimental/infer_schema.py index 3bff5f9..b6dd5bf 100644 --- a/dataframely/_infer_schema.py +++ b/dataframely/experimental/infer_schema.py @@ -55,18 +55,22 @@ def infer_schema( Example: >>> import polars as pl - >>> import dataframely as dy + >>> from dataframely.experimental import infer_schema >>> df = pl.DataFrame({ ... "name": ["Alice", "Bob"], ... "age": [25, 30], ... "score": [95.5, None], ... }) - >>> print(dy.infer_schema(df, "PersonSchema")) + >>> print(infer_schema(df, "PersonSchema")) class PersonSchema(dy.Schema): name = dy.String() age = dy.Int64() score = dy.Float64(nullable=True) + Attention: + This functionality is considered unstable. It may be changed at any time + without it being considered a breaking change. + Raises: ValueError: If ``schema_name`` is not a valid Python identifier. """ diff --git a/docs/api/experimental/index.rst b/docs/api/experimental/index.rst new file mode 100644 index 0000000..63ffaa7 --- /dev/null +++ b/docs/api/experimental/index.rst @@ -0,0 +1,10 @@ +============= +Experimental +============= + +.. currentmodule:: dataframely +..
autosummary:: + :toctree: _gen/ + :nosignatures: + + experimental.infer_schema diff --git a/docs/api/index.rst b/docs/api/index.rst index 1c795da..e5c7861 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -47,3 +47,12 @@ API Reference :maxdepth: 1 misc/index + +.. grid:: + + .. grid-item-card:: + + .. toctree:: + :maxdepth: 1 + + experimental/index diff --git a/docs/api/schema/index.rst b/docs/api/schema/index.rst index 5ed25ba..77e0323 100644 --- a/docs/api/schema/index.rst +++ b/docs/api/schema/index.rst @@ -9,7 +9,6 @@ Schema validation io generation - inference conversion metadata diff --git a/docs/api/schema/inference.rst b/docs/api/schema/inference.rst deleted file mode 100644 index 29d335e..0000000 --- a/docs/api/schema/inference.rst +++ /dev/null @@ -1,9 +0,0 @@ -========= -Inference -========= - -.. currentmodule:: dataframely -.. autosummary:: - :toctree: _gen/ - - infer_schema diff --git a/docs/guides/faq.md b/docs/guides/faq.md index d5e24d8..83636fe 100644 --- a/docs/guides/faq.md +++ b/docs/guides/faq.md @@ -5,7 +5,7 @@ thinking, please add it here. ## How do I define additional unique keys in a {class}`~dataframely.Schema`? -By default, `dataframely` only supports defining a single non-nullable (composite) primary key in :class: +By default, `dataframely` only supports defining a single non-nullable (composite) primary key in {class} `~dataframely.Schema`. However, in some scenarios it may be useful to define additional unique keys (which support nullable fields and/or which are additionally unique). @@ -34,6 +34,13 @@ class UserSchema(dy.Schema): See our documentation on [group rules](./quickstart.md#group-rules). +## How do I get a {class}`~dataframely.Schema` for my dataframe? + +You can use {func}`dataframely.experimental.infer_schema` to get a basic {class}`~dataframely.Schema` definition for +your dataframe. 
The function infers column names, types and nullability from the dataframe and returns a code +representation of a {class}`~dataframely.Schema`. +Starting from this definition, you can then refine the schema by adding additional rules, adjusting types, etc. + ## What versions of `polars` does `dataframely` support? Our CI automatically tests `dataframely` for a minimal supported version of `polars`, which is currently `1.35.*`, diff --git a/docs/guides/migration/index.md b/docs/guides/migration/index.md index 3c06221..167d06c 100644 --- a/docs/guides/migration/index.md +++ b/docs/guides/migration/index.md @@ -37,3 +37,9 @@ Users can disable `FutureWarnings` either through builtins from tools like [pytest](https://docs.pytest.org/en/stable/how-to/capture-warnings.html#controlling-warnings), or by setting the `DATAFRAMELY_NO_FUTURE_WARNINGS` environment variable to `true` or `1`. + +## Experimental features + +Experimental features are published in a dedicated namespace `dataframely.experimental`. +The versioning policy above does not apply to this namespace, and we may introduce breaking changes to experimental +features in minor releases. diff --git a/docs/guides/quickstart.md b/docs/guides/quickstart.md index 4eb341a..51814ce 100644 --- a/docs/guides/quickstart.md +++ b/docs/guides/quickstart.md @@ -175,7 +175,7 @@ print(failure.counts()) ``` In this case, `good` remains to be a `dy.DataFrame[HouseSchema]`, albeit with potentially fewer rows than `df`. -The `failure` object is of type :class:`~dataframely.FailureInfo` and provides means to inspect +The `failure` object is of type {class}`~dataframely.FailureInfo` and provides means to inspect the reasons for validation failures for invalid rows.
Given the example data above and the schema that we defined, we know that rows 2, 3, 4, and 5 are invalid (0-indexed): @@ -185,7 +185,7 @@ Given the example data above and the schema that we defined, we know that rows 2 - Row 4 violates both of the rules above - Row 5 violates the reasonable bathroom to bedroom ratio -Using the `counts` method on the :class:`~dataframely.FailureInfo` object will result in the following dictionary: +Using the `counts` method on the {class}`~dataframely.FailureInfo` object will result in the following dictionary: ```python { diff --git a/tests/test_infer_schema.py b/tests/experimental/test_infer_schema.py similarity index 85% rename from tests/test_infer_schema.py rename to tests/experimental/test_infer_schema.py index 050a139..902a265 100644 --- a/tests/test_infer_schema.py +++ b/tests/experimental/test_infer_schema.py @@ -8,6 +8,7 @@ import pytest import dataframely as dy +from dataframely.experimental import infer_schema def test_basic_types() -> None: @@ -19,7 +20,7 @@ def test_basic_types() -> None: "bool_col": [True, False, True], } ) - result = dy.infer_schema(df, "BasicSchema") + result = infer_schema(df, "BasicSchema") expected = textwrap.dedent("""\ class BasicSchema(dy.Schema): int_col = dy.Int64() @@ -36,7 +37,7 @@ def test_nullable_detection() -> None: "non_nullable_int": [1, 2, 3], } ) - result = dy.infer_schema(df, "NullableSchema") + result = infer_schema(df, "NullableSchema") expected = textwrap.dedent("""\ class NullableSchema(dy.Schema): nullable_int = dy.Int64(nullable=True) @@ -52,7 +53,7 @@ def test_datetime_types() -> None: "datetime_col": [datetime.datetime(2024, 1, 1, 12, 0, 0)], } ) - result = dy.infer_schema(df, schema_name="DatetimeSchema") + result = infer_schema(df, schema_name="DatetimeSchema") expected = textwrap.dedent("""\ class DatetimeSchema(dy.Schema): date_col = dy.Date() @@ -69,7 +70,7 @@ def test_datetime_with_timezone() -> None: ), } ) - result = dy.infer_schema(df, schema_name="TzSchema") + 
result = infer_schema(df, schema_name="TzSchema") expected = textwrap.dedent("""\ class TzSchema(dy.Schema): utc_time = dy.Datetime(time_zone="UTC")""") @@ -84,7 +85,7 @@ def test_enum_type() -> None: ), } ) - result = dy.infer_schema(df, schema_name="EnumSchema") + result = infer_schema(df, schema_name="EnumSchema") expected = textwrap.dedent("""\ class EnumSchema(dy.Schema): status = dy.Enum(['active', 'pending', 'inactive'])""") @@ -97,7 +98,7 @@ def test_decimal_type() -> None: "amount": pl.Series(["10.50"]).cast(pl.Decimal(precision=10, scale=2)), } ) - result = dy.infer_schema(df, schema_name="DecimalSchema") + result = infer_schema(df, schema_name="DecimalSchema") expected = textwrap.dedent("""\ class DecimalSchema(dy.Schema): amount = dy.Decimal(precision=10, scale=2)""") @@ -110,7 +111,7 @@ def test_list_type() -> None: "tags": [["a", "b"], ["c"]], } ) - result = dy.infer_schema(df, schema_name="ListSchema") + result = infer_schema(df, schema_name="ListSchema") expected = textwrap.dedent("""\ class ListSchema(dy.Schema): tags = dy.List(dy.String())""") @@ -123,7 +124,7 @@ def test_struct_type() -> None: "metadata": [{"key": "value"}, {"key": "other"}], } ) - result = dy.infer_schema(df, schema_name="StructSchema") + result = infer_schema(df, schema_name="StructSchema") expected = textwrap.dedent("""\ class StructSchema(dy.Schema): metadata = dy.Struct({"key": dy.String()})""") @@ -132,7 +133,7 @@ class StructSchema(dy.Schema): def test_list_with_nullable_inner() -> None: df = pl.DataFrame({"names": [["Alice"], [None]]}) - result = dy.infer_schema(df, schema_name="ListNullableInnerSchema") + result = infer_schema(df, schema_name="ListNullableInnerSchema") expected = textwrap.dedent("""\ class ListNullableInnerSchema(dy.Schema): names = dy.List(dy.String(nullable=True))""") @@ -141,7 +142,7 @@ class ListNullableInnerSchema(dy.Schema): def test_struct_with_nullable_field() -> None: df = pl.DataFrame({"data": [{"key": "value"}, {"key": None}]}) - result = 
dy.infer_schema(df, schema_name="StructNullableFieldSchema") + result = infer_schema(df, schema_name="StructNullableFieldSchema") expected = textwrap.dedent("""\ class StructNullableFieldSchema(dy.Schema): data = dy.Struct({"key": dy.String(nullable=True)})""") @@ -152,7 +153,7 @@ def test_array_type() -> None: df = pl.DataFrame({"vector": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]}).cast( {"vector": pl.Array(pl.Float64(), 3)} ) - result = dy.infer_schema(df, schema_name="ArraySchema") + result = infer_schema(df, schema_name="ArraySchema") expected = textwrap.dedent("""\ class ArraySchema(dy.Schema): vector = dy.Array(dy.Float64(), shape=3)""") @@ -165,7 +166,7 @@ def test_invalid_identifier() -> None: "123invalid": ["test"], } ) - result = dy.infer_schema(df, schema_name="InvalidIdSchema") + result = infer_schema(df, schema_name="InvalidIdSchema") expected = textwrap.dedent("""\ class InvalidIdSchema(dy.Schema): _123invalid = dy.String(alias="123invalid")""") @@ -178,7 +179,7 @@ def test_python_keyword() -> None: "class": ["test"], } ) - result = dy.infer_schema(df, schema_name="KeywordSchema") + result = infer_schema(df, schema_name="KeywordSchema") expected = textwrap.dedent("""\ class KeywordSchema(dy.Schema): class_ = dy.String(alias="class")""") @@ -198,7 +199,7 @@ def test_all_integer_types() -> None: "u64": pl.Series([1], dtype=pl.UInt64), } ) - result = dy.infer_schema(df, schema_name="IntSchema") + result = infer_schema(df, schema_name="IntSchema") assert "dy.Int8()" in result assert "dy.Int16()" in result assert "dy.Int32()" in result @@ -216,7 +217,7 @@ def test_float_types() -> None: "f64": pl.Series([1.0], dtype=pl.Float64), } ) - result = dy.infer_schema(df, schema_name="FloatSchema") + result = infer_schema(df, schema_name="FloatSchema") assert "dy.Float32()" in result assert "dy.Float64()" in result @@ -226,36 +227,36 @@ def test_invalid_schema_name_raises_error() -> None: with pytest.raises( ValueError, match="schema_name must be a valid Python 
identifier" ): - dy.infer_schema(df, "Invalid Name") + infer_schema(df, "Invalid Name") def test_default_schema_name() -> None: df = pl.DataFrame({"col": [1]}) - result = dy.infer_schema(df) + result = infer_schema(df) assert "class Schema(dy.Schema):" in result def test_binary_type() -> None: df = pl.DataFrame({"data": pl.Series([b"hello"], dtype=pl.Binary)}) - result = dy.infer_schema(df, schema_name="BinarySchema") + result = infer_schema(df, schema_name="BinarySchema") assert "dy.Binary()" in result def test_null_type() -> None: df = pl.DataFrame({"null_col": pl.Series([None, None], dtype=pl.Null)}) - result = dy.infer_schema(df, schema_name="NullSchema") + result = infer_schema(df, schema_name="NullSchema") assert "dy.Any()" in result def test_object_type() -> None: df = pl.DataFrame({"obj": pl.Series([object()], dtype=pl.Object)}) - result = dy.infer_schema(df, schema_name="ObjectSchema") + result = infer_schema(df, schema_name="ObjectSchema") assert "dy.Object()" in result def test_categorical_type() -> None: df = pl.DataFrame({"cat": pl.Series(["a", "b"]).cast(pl.Categorical())}) - result = dy.infer_schema(df, schema_name="CatSchema") + result = infer_schema(df, schema_name="CatSchema") assert "dy.Categorical()" in result @@ -263,7 +264,7 @@ def test_duration_type() -> None: df = pl.DataFrame( {"dur": pl.Series([datetime.timedelta(days=1)], dtype=pl.Duration)} ) - result = dy.infer_schema(df, schema_name="DurSchema") + result = infer_schema(df, schema_name="DurSchema") assert "dy.Duration()" in result @@ -271,7 +272,7 @@ def test_duration_with_time_unit_ms() -> None: df = pl.DataFrame( {"dur": pl.Series([datetime.timedelta(days=1)]).cast(pl.Duration("ms"))} ) - result = dy.infer_schema(df, schema_name="DurSchema") + result = infer_schema(df, schema_name="DurSchema") assert 'time_unit="ms"' in result @@ -279,7 +280,7 @@ def test_duration_with_time_unit_ns() -> None: df = pl.DataFrame( {"dur": pl.Series([datetime.timedelta(days=1)]).cast(pl.Duration("ns"))} ) 
- result = dy.infer_schema(df, schema_name="DurSchema") + result = infer_schema(df, schema_name="DurSchema") assert 'time_unit="ns"' in result @@ -287,7 +288,7 @@ def test_datetime_with_time_unit_ms() -> None: df = pl.DataFrame( {"dt": pl.Series([datetime.datetime(2024, 1, 1)]).cast(pl.Datetime("ms"))} ) - result = dy.infer_schema(df, schema_name="DtSchema") + result = infer_schema(df, schema_name="DtSchema") assert 'time_unit="ms"' in result @@ -295,7 +296,7 @@ def test_datetime_with_time_unit_ns() -> None: df = pl.DataFrame( {"dt": pl.Series([datetime.datetime(2024, 1, 1)]).cast(pl.Datetime("ns"))} ) - result = dy.infer_schema(df, schema_name="DtSchema") + result = infer_schema(df, schema_name="DtSchema") assert 'time_unit="ns"' in result @@ -303,14 +304,14 @@ def test_decimal_without_scale() -> None: df = pl.DataFrame( {"amount": pl.Series(["10"]).cast(pl.Decimal(precision=5, scale=0))} ) - result = dy.infer_schema(df, schema_name="DecSchema") + result = infer_schema(df, schema_name="DecSchema") assert "precision=5" in result assert "scale=" not in result def test_column_sanitization() -> None: df = pl.DataFrame({"$": ["test"], "valid": ["test"], "!!!": ["test"], "": ["test"]}) - result = dy.infer_schema(df) + result = infer_schema(df) assert 'column_0 = dy.String(alias="$")' in result assert 'column_2 = dy.String(alias="!!!")' in result assert 'column_3 = dy.String(alias="")' in result @@ -325,7 +326,7 @@ def test_column_with_spaces() -> None: "col_name_1": ["test"], } ) - result = dy.infer_schema(df, schema_name="SpaceSchema") + result = infer_schema(df, schema_name="SpaceSchema") assert 'col_name = dy.String(alias="col name")' in result assert 'col_name_1 = dy.String(alias="col_name")' in result assert 'col_name_2 = dy.String(alias="col@name")' in result @@ -334,7 +335,7 @@ def test_column_with_spaces() -> None: def test_column_sanitization_conflict_with_existing() -> None: df = pl.DataFrame({"col_": ["test"], "col!": ["test"]}) - result = dy.infer_schema(df, 
schema_name="ConflictSchema") + result = infer_schema(df, schema_name="ConflictSchema") expected = textwrap.dedent("""\ class ConflictSchema(dy.Schema): col_ = dy.String() @@ -395,7 +396,7 @@ class ConflictSchema(dy.Schema): ], ) def test_inferred_schema_validates_dataframe(df: pl.DataFrame) -> None: - code = dy.infer_schema(df, "TestSchema") + code = infer_schema(df, "TestSchema") namespace: dict = {"dy": dy} exec(code, namespace) # noqa: S102 schema = namespace["TestSchema"] From 071ed892ef19a82870b9d93b360f9ff042b9cb32 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Mon, 23 Mar 2026 12:30:46 +0100 Subject: [PATCH 12/13] precommit --- dataframely/experimental/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataframely/experimental/__init__.py b/dataframely/experimental/__init__.py index 5575e38..757a998 100644 --- a/dataframely/experimental/__init__.py +++ b/dataframely/experimental/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) QuantCo 2025-2025 +# Copyright (c) QuantCo 2025-2026 # SPDX-License-Identifier: BSD-3-Clause from .infer_schema import infer_schema From c7123748ca61f19d807acb22ea532701f7b73053 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Mon, 23 Mar 2026 17:40:18 +0100 Subject: [PATCH 13/13] fix --- docs/api/experimental/index.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api/experimental/index.rst b/docs/api/experimental/index.rst index 63ffaa7..e2aad1c 100644 --- a/docs/api/experimental/index.rst +++ b/docs/api/experimental/index.rst @@ -7,4 +7,4 @@ Experimental :toctree: _gen/ :nosignatures: - experimental.infer_schema + experimental.infer_schema