diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 72c0145..1b7b9f8 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -14,7 +14,7 @@ jobs: runs-on: macos-26 strategy: matrix: - python-version: ["3.11", "3.12", "3.13", "3.14"] + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] steps: - uses: actions/checkout@v6 diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index 45085f2..dc51d14 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -130,7 +130,7 @@ class SEQopts: indicator_baseline: str = "_bas" indicator_squared: str = "_sq" km_curves: bool = False - ncores: int = multiprocessing.cpu_count() + ncores: int = max(1, multiprocessing.cpu_count() - 1) numerator: Optional[str] = None offload: bool = False offload_dir: str = "_seq_models" @@ -140,7 +140,7 @@ class SEQopts: ) plot_labels: List[str] = field(default_factory=lambda: []) plot_title: str = None - plot_type: Literal["risk", "survival", "incidence"] = "risk" + plot_type: Literal["risk", "survival", "incidence"] = "survival" seed: Optional[int] = None selection_first_trial: bool = False selection_sample: float = 0.8 @@ -155,7 +155,7 @@ class SEQopts: weight_max: float = None weight_lag_condition: bool = True weight_p99: bool = False - weight_preexpansion: bool = False + weight_preexpansion: bool = True weighted: bool = False def _validate_bools(self): @@ -184,12 +184,20 @@ def _validate_ranges(self): raise ValueError("bootstrap_nboot must be a positive integer.") if self.ncores < 1 or not isinstance(self.ncores, int): raise ValueError("ncores must be a positive integer.") - if not (0.0 <= self.bootstrap_sample <= 1.0): - raise ValueError("bootstrap_sample must be between 0 and 1.") + if not (0.0 < self.bootstrap_sample <= 1.0): + raise ValueError("bootstrap_sample must be between 0 (exclusive) and 1.") if not (0.0 < self.bootstrap_CI < 1.0): raise ValueError("bootstrap_CI must be between 0 and 1.") - if not (0.0 <= self.selection_sample <= 1.0): - raise ValueError("selection_sample must be between 0 and 1.") + if not (0.0 < self.selection_sample <= 1.0): + raise ValueError("selection_sample must be between 0 (exclusive) and 1.") + if self.weight_max is not None and self.weight_max <= self.weight_min: + raise ValueError( + f"weight_min ({self.weight_min}) must be less than weight_max ({self.weight_max})." + ) + if self.followup_max is not None and self.followup_max <= self.followup_min: + raise ValueError( + f"followup_min ({self.followup_min}) must be less than followup_max ({self.followup_max})." + ) def _validate_choices(self): if self.plot_type not in ["risk", "survival", "incidence"]: diff --git a/pySEQTarget/SEQoutput.py b/pySEQTarget/SEQoutput.py index d01a9a7..5587441 100644 --- a/pySEQTarget/SEQoutput.py +++ b/pySEQTarget/SEQoutput.py @@ -4,6 +4,7 @@ from typing import List, Literal, Optional import matplotlib.figure +import matplotlib.pyplot as plt import polars as pl from statsmodels.base.wrapper import ResultsWrapper @@ -61,9 +62,14 @@ class SEQoutput: def plot(self) -> None: """ - Prints the kaplan-meier graph + Displays the kaplan-meier graph """ - print(self.km_graph) + if self.km_graph is None: + raise ValueError( + "No plot available. Ensure km_curves=True and run SEQuential.plot() before collect()." + ) + plt.figure(self.km_graph) + plt.show() def summary( self, type=Optional[Literal["numerator", "denominator", "outcome", "compevent"]] diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index 22efef8..4984c7e 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -4,6 +4,7 @@ from dataclasses import asdict from typing import List, Literal, Optional +import matplotlib.pyplot as plt import numpy as np import polars as pl @@ -317,6 +318,7 @@ def plot(self, **kwargs) -> None: else: raise ValueError(f"Unknown or misplaced argument: {key}") self.km_graph = _survival_plot(self) + plt.show() def collect(self) -> SEQoutput: """ diff --git a/pySEQTarget/error/_data_checker.py b/pySEQTarget/error/_data_checker.py index 1162a09..e044233 100644 --- a/pySEQTarget/error/_data_checker.py +++ b/pySEQTarget/error/_data_checker.py @@ -1,7 +1,25 @@ import polars as pl +def _check_binary(data, col): + unique_vals = set(data[col].drop_nulls().unique().to_list()) + if not unique_vals.issubset({0, 1}): + raise ValueError( + f"Column '{col}' must be binary (0/1) but contains values: {sorted(unique_vals)}" + ) + + def _data_checker(self): + _check_binary(self.data, self.eligible_col) + _check_binary(self.data, self.outcome_col) + + if self.cense_eligible_colname is not None: + _check_binary(self.data, self.cense_eligible_colname) + + for col in self.weight_eligible_colnames: + if col is not None: + _check_binary(self.data, col) + check = self.data.group_by(self.id_col).agg( [pl.len().alias("row_count"), pl.col(self.time_col).max().alias("max_time")] ) @@ -33,6 +51,6 @@ def _data_checker(self): if len(violations) > 0: raise ValueError( - f"Column '{col}' violates 'once one, always one' rule for excusing treatment " + f"Column '{col}' violates the 'once one, always one' rule: " f"{len(violations)} ID(s) have zeros after ones." ) diff --git a/pySEQTarget/error/_param_checker.py b/pySEQTarget/error/_param_checker.py index c048583..a926569 100644 --- a/pySEQTarget/error/_param_checker.py +++ b/pySEQTarget/error/_param_checker.py @@ -2,6 +2,19 @@ def _param_checker(self): + overlap = set(self.time_varying_cols) & set(self.fixed_cols) + if overlap: + raise ValueError( + f"Columns cannot appear in both time_varying_cols and fixed_cols: {sorted(overlap)}" + ) + + actual_levels = set(self.data[self.treatment_col].unique().to_list()) + missing_levels = set(self.treatment_level) - actual_levels + if missing_levels: + raise ValueError( + f"treatment_level contains values not found in '{self.treatment_col}': {sorted(missing_levels)}" + ) + if ( self.subgroup_colname is not None and self.subgroup_colname not in self.fixed_cols diff --git a/pyproject.toml b/pyproject.toml index 323721d..059b744 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,16 +4,17 @@ build-backend = "setuptools.build_meta" [project] name = "pySEQTarget" -version = "0.12.5" +version = "0.12.6" description = "Sequentially Nested Target Trial Emulation" readme = "README.md" license = {text = "MIT"} keywords = ["causal inference", "sequential trial emulation", "target trial", "observational studies"] -requires-python = ">=3.11" +requires-python = ">=3.10" classifiers = [ "Development Status :: 4 - Beta", "Intended Audience :: Science/Research", "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", diff --git a/tests/test_coefficients.py b/tests/test_coefficients.py index 9debc39..7aecd9b 100644 --- a/tests/test_coefficients.py +++ b/tests/test_coefficients.py @@ -81,7 +81,7 @@ def test_PostE_dose_response_coefs(): time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], method="dose-response", - parameters=SEQopts(weighted=True), + parameters=SEQopts(weighted=True, weight_preexpansion=False), ) s.expand() @@ -147,7 +147,7 @@ def test_PostE_censoring_coefs(): time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], method="censoring", - parameters=SEQopts(weighted=True), + parameters=SEQopts(weighted=True, weight_preexpansion=False), ) s.expand() s.fit() @@ -216,6 +216,7 @@ def test_PostE_censoring_excused_coefs(): method="censoring", parameters=SEQopts( weighted=True, + weight_preexpansion=False, excused=True, excused_colnames=["excusedZero", "excusedOne"], weight_max=1, @@ -287,7 +288,7 @@ def test_PostE_LTFU_ITT(): time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], method="ITT", - parameters=SEQopts(weighted=True, cense_colname="LTFU"), + parameters=SEQopts(weighted=True, weight_preexpansion=False, cense_colname="LTFU"), ) s.expand() s.fit() diff --git a/tests/test_covariates.py b/tests/test_covariates.py index 9863f8f..1bd03c9 100644 --- a/tests/test_covariates.py +++ b/tests/test_covariates.py @@ -62,7 +62,7 @@ def test_PostE_dose_response_covariates(): time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], method="dose-response", - parameters=SEQopts(weighted=True), + parameters=SEQopts(weighted=True, weight_preexpansion=False), ) assert ( s.covariates @@ -110,7 +110,7 @@ def test_PostE_censoring_covariates(): time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], method="censoring", - parameters=SEQopts(weighted=True), + parameters=SEQopts(weighted=True, weight_preexpansion=False), ) assert ( s.covariates @@ -165,7 +165,7 @@ def test_PostE_censoring_excused_covariates(): fixed_cols=["sex"], method="censoring", parameters=SEQopts( - weighted=True, excused=True, excused_colnames=["excusedZero", "excusedOne"] + weighted=True, weight_preexpansion=False, excused=True, excused_colnames=["excusedZero", "excusedOne"] ), ) assert ( diff --git a/tests/test_plot.py b/tests/test_plot.py new file mode 100644 index 0000000..ce94eef --- /dev/null +++ b/tests/test_plot.py @@ -0,0 +1,127 @@ +import unittest.mock as mock + +import matplotlib +import matplotlib.figure +import matplotlib.pyplot as plt +import pytest + +matplotlib.use("Agg") # non-interactive backend — no windows opened + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +@pytest.fixture(autouse=True) +def close_figures(): + yield + plt.close("all") + + +@pytest.fixture +def base_seq(): + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(km_curves=True), + ) + s.expand() + s.fit() + s.survival() + return s + + +def test_sequential_plot_returns_figure(base_seq): + """SEQuential.plot() should store a Figure in km_graph, not print its repr.""" + base_seq.plot() + assert isinstance(base_seq.km_graph, matplotlib.figure.Figure) + + +def test_sequential_plot_calls_show(base_seq): + """SEQuential.plot() must call plt.show() to actually display the figure.""" + with mock.patch("matplotlib.pyplot.show") as mock_show: + base_seq.plot() + mock_show.assert_called_once() + + +def test_seqoutput_plot_shows_figure(base_seq): + """SEQoutput.plot() should display the figure without raising.""" + base_seq.plot() + result = base_seq.collect() + result.plot() # must not raise; previously printed Figure repr instead + + +def test_seqoutput_plot_raises_without_km_graph(): + """SEQoutput.plot() raises ValueError when no figure was generated.""" + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(km_curves=False), + ) + s.expand() + s.fit() + result = s.collect() + with pytest.raises(ValueError): + result.plot() + + +def test_sequential_plot_risk(base_seq): + base_seq.plot(plot_type="risk") + assert isinstance(base_seq.km_graph, matplotlib.figure.Figure) + + +def test_sequential_plot_survival(): + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(km_curves=True, plot_type="survival"), + ) + s.expand() + s.fit() + s.survival() + s.plot() + assert isinstance(s.km_graph, matplotlib.figure.Figure) + + +def test_sequential_plot_subgroups(): + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(km_curves=True, subgroup_colname="sex"), + ) + s.expand() + s.fit() + s.survival() + s.plot() + assert isinstance(s.km_graph, matplotlib.figure.Figure)