From ff3a9e91d550bca638e2639b194a0bcf80d7273f Mon Sep 17 00:00:00 2001 From: Nardi Lam Date: Mon, 23 Mar 2026 16:59:45 +0100 Subject: [PATCH] feat: Allow specifying Duration time unit --- dataframely/columns/datetime.py | 8 ++++++-- dataframely/random.py | 4 +++- tests/columns/test_pyarrow.py | 8 ++++++++ 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/dataframely/columns/datetime.py b/dataframely/columns/datetime.py index 2b8d1465..030eee55 100644 --- a/dataframely/columns/datetime.py +++ b/dataframely/columns/datetime.py @@ -440,6 +440,7 @@ def __init__( max: dt.timedelta | None = None, max_exclusive: dt.timedelta | None = None, resolution: str | None = None, + time_unit: TimeUnit = "us", check: Check | None = None, alias: str | None = None, metadata: dict[str, Any] | None = None, @@ -462,6 +463,7 @@ def __init__( the formatting language used by :mod:`polars` datetime `truncate` method. For example, a value `1h` expects all durations to be full hours. Note that this setting does *not* affect the storage resolution. + time_unit: Unit of time. Defaults to `us` (microseconds). check: A custom rule or multiple rules to run for this column. This can be: - A single callable that returns a non-aggregated boolean expression. The name of the rule is derived from the callable name, or defaults to @@ -504,10 +506,11 @@ def __init__( metadata=metadata, ) self.resolution = resolution + self.time_unit = time_unit @property def dtype(self) -> pl.DataType: - return pl.Duration() + return pl.Duration(time_unit=self.time_unit) def validation_rules(self, expr: pl.Expr) -> dict[str, pl.Expr]: result = super().validation_rules(expr) @@ -526,7 +529,7 @@ def sqlalchemy_dtype(self, dialect: sa.Dialect) -> sa_TypeEngine: @property def pyarrow_dtype(self) -> pa.DataType: - return pa.duration("us") + return pa.duration(self.time_unit) def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: # NOTE: If no duration is specified, we default to 100 years @@ -543,6 +546,7 @@ def _sample_unchecked(self, generator: Generator, n: int) -> pl.Series: default=dt.timedelta(days=365 * 100), ), resolution=self.resolution, + time_unit=self.time_unit, null_probability=self._null_probability, ) diff --git a/dataframely/random.py b/dataframely/random.py index 67a00280..4fe038b5 100644 --- a/dataframely/random.py +++ b/dataframely/random.py @@ -376,6 +376,7 @@ def sample_duration( min: dt.timedelta, max: dt.timedelta, resolution: str | None = None, + time_unit: TimeUnit = "us", null_probability: float = 0.0, ) -> pl.Series: """Sample a list of durations in the provided range. @@ -386,6 +387,7 @@ def sample_duration( max: The maximum duration to sample (exclusive). resolution: The resolution that durations in the column must have. This uses the formatting language used by :mod:`polars` datetime `round` method. + time_unit: The time unit of the duration column. Defaults to `us` (microseconds). null_probability: The probability of an element being `null`. Returns: @@ -410,7 +412,7 @@ def sample_duration( max=max_microseconds, null_probability=null_probability, ) - ).cast(pl.Duration) + ).cast(pl.Duration(time_unit=time_unit)) if resolution is not None: ref_dt = pl.lit(EPOCH_DATETIME) diff --git a/tests/columns/test_pyarrow.py b/tests/columns/test_pyarrow.py index 4d697152..02e7af3a 100644 --- a/tests/columns/test_pyarrow.py +++ b/tests/columns/test_pyarrow.py @@ -266,3 +266,11 @@ def test_datetime_time_unit(time_unit: TimeUnit) -> None: "test", {"a": dy.Datetime(time_unit=time_unit, nullable=True)} ) assert str(schema.to_pyarrow_schema()) == f"a: timestamp[{time_unit}]" + + +@pytest.mark.parametrize("time_unit", ["ns", "us", "ms"]) +def test_duration_time_unit(time_unit: TimeUnit) -> None: + schema = create_schema( + "test", {"a": dy.Duration(time_unit=time_unit, nullable=True)} + ) + assert str(schema.to_pyarrow_schema()) == f"a: duration[{time_unit}]"