From 104862be1ccf485b4caebd4676707fe833d8cd78 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 20 Mar 2026 10:42:14 +0000 Subject: [PATCH 01/15] Fix SEQoutput.plot() to display figure instead of printing repr --- pySEQTarget/SEQoutput.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pySEQTarget/SEQoutput.py b/pySEQTarget/SEQoutput.py index d01a9a7..5587441 100644 --- a/pySEQTarget/SEQoutput.py +++ b/pySEQTarget/SEQoutput.py @@ -4,6 +4,7 @@ from typing import List, Literal, Optional import matplotlib.figure +import matplotlib.pyplot as plt import polars as pl from statsmodels.base.wrapper import ResultsWrapper @@ -61,9 +62,14 @@ class SEQoutput: def plot(self) -> None: """ - Prints the kaplan-meier graph + Displays the kaplan-meier graph """ - print(self.km_graph) + if self.km_graph is None: + raise ValueError( + "No plot available. Ensure km_curves=True and run SEQuential.plot() before collect()." + ) + plt.figure(self.km_graph) + plt.show() def summary( self, type=Optional[Literal["numerator", "denominator", "outcome", "compevent"]] From 64aa91f2ae4fed8dffd919c4f86a873c691062b0 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 20 Mar 2026 10:43:26 +0000 Subject: [PATCH 02/15] Add plot tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - `test_sequential_plot_returns_figure` — `SEQuential.plot()` stores a real `Figure`, not just a repr - `test_seqoutput_plot_shows_figure` — `SEQoutput.plot()` runs without error after `collect()` - `test_seqoutput_plot_raises_without_km_graph` — graceful error when `km_curves=False` - `test_sequential_plot_risk` — risk plot type - `test_sequential_plot_survival` — survival plot type - `test_sequential_plot_subgroups` — subgroup faceted plot --- tests/test_plot.py | 118 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 tests/test_plot.py diff --git a/tests/test_plot.py b/tests/test_plot.py new file mode 
100644 index 0000000..c8c828c --- /dev/null +++ b/tests/test_plot.py @@ -0,0 +1,118 @@ +import matplotlib +import matplotlib.figure +import matplotlib.pyplot as plt +import pytest + +matplotlib.use("Agg") # non-interactive backend — no windows opened + +from pySEQTarget import SEQopts, SEQuential +from pySEQTarget.data import load_data + + +@pytest.fixture(autouse=True) +def close_figures(): + yield + plt.close("all") + + +@pytest.fixture +def base_seq(): + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(km_curves=True), + ) + s.expand() + s.fit() + s.survival() + return s + + +def test_sequential_plot_returns_figure(base_seq): + """SEQuential.plot() should store a Figure in km_graph, not print its repr.""" + base_seq.plot() + assert isinstance(base_seq.km_graph, matplotlib.figure.Figure) + + +def test_seqoutput_plot_shows_figure(base_seq): + """SEQoutput.plot() should display the figure without raising.""" + base_seq.plot() + result = base_seq.collect() + result.plot() # must not raise; previously printed Figure repr instead + + +def test_seqoutput_plot_raises_without_km_graph(): + """SEQoutput.plot() raises ValueError when no figure was generated.""" + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(km_curves=False), + ) + s.expand() + s.fit() + result = s.collect() + with pytest.raises(ValueError): + result.plot() + + +def test_sequential_plot_risk(base_seq): + base_seq.plot(plot_type="risk") + assert isinstance(base_seq.km_graph, matplotlib.figure.Figure) + + +def test_sequential_plot_survival(): + data = load_data("SEQdata") + s = 
SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(km_curves=True, plot_type="survival"), + ) + s.expand() + s.fit() + s.survival() + s.plot() + assert isinstance(s.km_graph, matplotlib.figure.Figure) + + +def test_sequential_plot_subgroups(): + data = load_data("SEQdata") + s = SEQuential( + data, + id_col="ID", + time_col="time", + eligible_col="eligible", + treatment_col="tx_init", + outcome_col="outcome", + time_varying_cols=["N", "L", "P"], + fixed_cols=["sex"], + method="ITT", + parameters=SEQopts(km_curves=True, subgroup_colname="sex"), + ) + s.expand() + s.fit() + s.survival() + s.plot() + assert isinstance(s.km_graph, matplotlib.figure.Figure) From 1906e3afd107f980507bd99ac8b40ab20eeaf70e Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 20 Mar 2026 10:44:59 +0000 Subject: [PATCH 03/15] Bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 323721d..2806fa8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pySEQTarget" -version = "0.12.5" +version = "0.12.6" description = "Sequentially Nested Target Trial Emulation" readme = "README.md" license = {text = "MIT"} From d80109b2782a51594a01f31fbed7297e1654eab8 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 20 Mar 2026 10:58:15 +0000 Subject: [PATCH 04/15] Fix SEQuential.plot() to display figure by calling plt.show() --- pySEQTarget/SEQuential.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index 22efef8..4984c7e 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -4,6 +4,7 @@ from dataclasses import asdict from typing import List, Literal, Optional +import matplotlib.pyplot as plt import numpy 
as np import polars as pl @@ -317,6 +318,7 @@ def plot(self, **kwargs) -> None: else: raise ValueError(f"Unknown or misplaced argument: {key}") self.km_graph = _survival_plot(self) + plt.show() def collect(self) -> SEQoutput: """ From 06bcd330d38eaf7770c6229b518173a2164638ca Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 20 Mar 2026 10:59:29 +0000 Subject: [PATCH 05/15] Add test verifying SEQuential.plot() calls plt.show() --- tests/test_plot.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test_plot.py b/tests/test_plot.py index c8c828c..ce94eef 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -1,3 +1,5 @@ +import unittest.mock as mock + import matplotlib import matplotlib.figure import matplotlib.pyplot as plt @@ -42,6 +44,13 @@ def test_sequential_plot_returns_figure(base_seq): assert isinstance(base_seq.km_graph, matplotlib.figure.Figure) +def test_sequential_plot_calls_show(base_seq): + """SEQuential.plot() must call plt.show() to actually display the figure.""" + with mock.patch("matplotlib.pyplot.show") as mock_show: + base_seq.plot() + mock_show.assert_called_once() + + def test_seqoutput_plot_shows_figure(base_seq): """SEQoutput.plot() should display the figure without raising.""" base_seq.plot() From 0d3032948cdd380a3625ef9dfc91d87aa6b4bdc8 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 20 Mar 2026 12:09:37 +0000 Subject: [PATCH 06/15] Re-enable Python 3.10 The lifelines package seems to allow Python 3.10 again. 
--- .github/workflows/python-app.yml | 2 +- pyproject.toml | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index 72c0145..1b7b9f8 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -14,7 +14,7 @@ jobs: runs-on: macos-26 strategy: matrix: - python-version: ["3.11", "3.12", "3.13", "3.14"] + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] steps: - uses: actions/checkout@v6 diff --git a/pyproject.toml b/pyproject.toml index 2806fa8..059b744 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,11 +9,12 @@ description = "Sequentially Nested Target Trial Emulation" readme = "README.md" license = {text = "MIT"} keywords = ["causal inference", "sequential trial emulation", "target trial", "observational studies"] -requires-python = ">=3.11" +requires-python = ">=3.10" classifiers = [ "Development Status :: 4 - Beta", "Intended Audience :: Science/Research", "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", From fbbf400bf8f88c6b04cc200c5ddf5da7e4679dee Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 20 Mar 2026 12:25:59 +0000 Subject: [PATCH 07/15] Tighten SEQopts range validation Fix bootstrap_sample and selection_sample lower bounds Add weight_min/max and followup_min/max ordering checks --- pySEQTarget/SEQopts.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index 45085f2..4a84001 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -184,12 +184,20 @@ def _validate_ranges(self): raise ValueError("bootstrap_nboot must be a positive integer.") if self.ncores < 1 or not isinstance(self.ncores, int): raise ValueError("ncores must be a positive integer.") - if not (0.0 <= 
self.bootstrap_sample <= 1.0): - raise ValueError("bootstrap_sample must be between 0 and 1.") + if not (0.0 < self.bootstrap_sample <= 1.0): + raise ValueError("bootstrap_sample must be between 0 (exclusive) and 1.") if not (0.0 < self.bootstrap_CI < 1.0): raise ValueError("bootstrap_CI must be between 0 and 1.") - if not (0.0 <= self.selection_sample <= 1.0): - raise ValueError("selection_sample must be between 0 and 1.") + if not (0.0 < self.selection_sample <= 1.0): + raise ValueError("selection_sample must be between 0 (exclusive) and 1.") + if self.weight_max is not None and self.weight_max <= self.weight_min: + raise ValueError( + f"weight_min ({self.weight_min}) must be less than weight_max ({self.weight_max})." + ) + if self.followup_max is not None and self.followup_max <= self.followup_min: + raise ValueError( + f"followup_min ({self.followup_min}) must be less than followup_max ({self.followup_max})." + ) def _validate_choices(self): if self.plot_type not in ["risk", "survival", "incidence"]: From 1b06c154e3ccb64fb597bddf61582b8e90988b74 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 20 Mar 2026 12:28:44 +0000 Subject: [PATCH 08/15] Add validation that columns cannot appear in both time_varying_cols and fixed_cols --- pySEQTarget/error/_param_checker.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pySEQTarget/error/_param_checker.py b/pySEQTarget/error/_param_checker.py index c048583..e7368da 100644 --- a/pySEQTarget/error/_param_checker.py +++ b/pySEQTarget/error/_param_checker.py @@ -2,6 +2,11 @@ def _param_checker(self): + overlap = set(self.time_varying_cols) & set(self.fixed_cols) + if overlap: + raise ValueError( + f"Columns cannot appear in both time_varying_cols and fixed_cols: {sorted(overlap)}" + ) if ( self.subgroup_colname is not None and self.subgroup_colname not in self.fixed_cols From 7136bcaa41840904641d0d6d521976c3204cd325 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 20 Mar 2026 12:29:13 +0000 Subject: [PATCH 
09/15] Add validation that all treatment_level values exist in the treatment column --- pySEQTarget/error/_param_checker.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pySEQTarget/error/_param_checker.py b/pySEQTarget/error/_param_checker.py index e7368da..a926569 100644 --- a/pySEQTarget/error/_param_checker.py +++ b/pySEQTarget/error/_param_checker.py @@ -7,6 +7,14 @@ def _param_checker(self): raise ValueError( f"Columns cannot appear in both time_varying_cols and fixed_cols: {sorted(overlap)}" ) + + actual_levels = set(self.data[self.treatment_col].unique().to_list()) + missing_levels = set(self.treatment_level) - actual_levels + if missing_levels: + raise ValueError( + f"treatment_level contains values not found in '{self.treatment_col}': {sorted(missing_levels)}" + ) + if ( self.subgroup_colname is not None and self.subgroup_colname not in self.fixed_cols From 4117386d8896d16a8f5a411f76b8bf35dc59b8ca Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 20 Mar 2026 12:31:04 +0000 Subject: [PATCH 10/15] Add binary (0/1) validation for eligible_col, outcome_col, cense_eligible_colname, and weight_eligible_colnames --- pySEQTarget/error/_data_checker.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/pySEQTarget/error/_data_checker.py b/pySEQTarget/error/_data_checker.py index 1162a09..9e9ca76 100644 --- a/pySEQTarget/error/_data_checker.py +++ b/pySEQTarget/error/_data_checker.py @@ -1,7 +1,25 @@ import polars as pl +def _check_binary(data, col): + unique_vals = set(data[col].drop_nulls().unique().to_list()) + if not unique_vals.issubset({0, 1}): + raise ValueError( + f"Column '{col}' must be binary (0/1) but contains values: {sorted(unique_vals)}" + ) + + def _data_checker(self): + _check_binary(self.data, self.eligible_col) + _check_binary(self.data, self.outcome_col) + + if self.cense_eligible_colname is not None: + _check_binary(self.data, self.cense_eligible_colname) + + for col in self.weight_eligible_colnames: + if col is 
not None: + _check_binary(self.data, col) + check = self.data.group_by(self.id_col).agg( [pl.len().alias("row_count"), pl.col(self.time_col).max().alias("max_time")] ) From 45a04c65b5b9cf2698a094143d90010a4c9cd94d Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Fri, 20 Mar 2026 12:32:19 +0000 Subject: [PATCH 11/15] Amend single-transition check error message to be generic --- pySEQTarget/error/_data_checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pySEQTarget/error/_data_checker.py b/pySEQTarget/error/_data_checker.py index 9e9ca76..e044233 100644 --- a/pySEQTarget/error/_data_checker.py +++ b/pySEQTarget/error/_data_checker.py @@ -51,6 +51,6 @@ def _data_checker(self): if len(violations) > 0: raise ValueError( - f"Column '{col}' violates 'once one, always one' rule for excusing treatment " + f"Column '{col}' violates the 'once one, always one' rule: " f"{len(violations)} ID(s) have zeros after ones." ) From 0376fe537847ce3914557afaa2613b6964db1341 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Sun, 22 Mar 2026 03:47:44 +0000 Subject: [PATCH 12/15] Amend weight_preexpansion default to True to match R package default --- pySEQTarget/SEQopts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index 4a84001..5e03437 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -155,7 +155,7 @@ class SEQopts: weight_max: float = None weight_lag_condition: bool = True weight_p99: bool = False - weight_preexpansion: bool = False + weight_preexpansion: bool = True weighted: bool = False def _validate_bools(self): From b9741e258f926163629fceaa7fd58fdd84d3a6a8 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Sun, 22 Mar 2026 07:42:00 +0000 Subject: [PATCH 13/15] Explicitly set weight_preexpansion=False in PostE tests to preserve behaviour after default change --- tests/test_coefficients.py | 7 ++++--- tests/test_covariates.py | 6 +++--- 2 files changed, 7 insertions(+), 6 
deletions(-) diff --git a/tests/test_coefficients.py b/tests/test_coefficients.py index 9debc39..7aecd9b 100644 --- a/tests/test_coefficients.py +++ b/tests/test_coefficients.py @@ -81,7 +81,7 @@ def test_PostE_dose_response_coefs(): time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], method="dose-response", - parameters=SEQopts(weighted=True), + parameters=SEQopts(weighted=True, weight_preexpansion=False), ) s.expand() @@ -147,7 +147,7 @@ def test_PostE_censoring_coefs(): time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], method="censoring", - parameters=SEQopts(weighted=True), + parameters=SEQopts(weighted=True, weight_preexpansion=False), ) s.expand() s.fit() @@ -216,6 +216,7 @@ def test_PostE_censoring_excused_coefs(): method="censoring", parameters=SEQopts( weighted=True, + weight_preexpansion=False, excused=True, excused_colnames=["excusedZero", "excusedOne"], weight_max=1, @@ -287,7 +288,7 @@ def test_PostE_LTFU_ITT(): time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], method="ITT", - parameters=SEQopts(weighted=True, cense_colname="LTFU"), + parameters=SEQopts(weighted=True, weight_preexpansion=False, cense_colname="LTFU"), ) s.expand() s.fit() diff --git a/tests/test_covariates.py b/tests/test_covariates.py index 9863f8f..1bd03c9 100644 --- a/tests/test_covariates.py +++ b/tests/test_covariates.py @@ -62,7 +62,7 @@ def test_PostE_dose_response_covariates(): time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], method="dose-response", - parameters=SEQopts(weighted=True), + parameters=SEQopts(weighted=True, weight_preexpansion=False), ) assert ( s.covariates @@ -110,7 +110,7 @@ def test_PostE_censoring_covariates(): time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], method="censoring", - parameters=SEQopts(weighted=True), + parameters=SEQopts(weighted=True, weight_preexpansion=False), ) assert ( s.covariates @@ -165,7 +165,7 @@ def test_PostE_censoring_excused_covariates(): fixed_cols=["sex"], method="censoring", parameters=SEQopts( - 
weighted=True, excused=True, excused_colnames=["excusedZero", "excusedOne"] + weighted=True, weight_preexpansion=False, excused=True, excused_colnames=["excusedZero", "excusedOne"] ), ) assert ( From 4d41d305d710a36edc2e82b348b01e73d6031b20 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Sun, 22 Mar 2026 03:49:43 +0000 Subject: [PATCH 14/15] Amend plot_type default to "survival" to match R package default --- pySEQTarget/SEQopts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index 5e03437..a974298 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -140,7 +140,7 @@ class SEQopts: ) plot_labels: List[str] = field(default_factory=lambda: []) plot_title: str = None - plot_type: Literal["risk", "survival", "incidence"] = "risk" + plot_type: Literal["risk", "survival", "incidence"] = "survival" seed: Optional[int] = None selection_first_trial: bool = False selection_sample: float = 0.8 From 35763c6acf036c2756e35c0e1bc229ccada40f5f Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Sun, 22 Mar 2026 03:51:34 +0000 Subject: [PATCH 15/15] Amend ncores default to cpu_count - 1 to match R package --- pySEQTarget/SEQopts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pySEQTarget/SEQopts.py b/pySEQTarget/SEQopts.py index a974298..dc51d14 100644 --- a/pySEQTarget/SEQopts.py +++ b/pySEQTarget/SEQopts.py @@ -130,7 +130,7 @@ class SEQopts: indicator_baseline: str = "_bas" indicator_squared: str = "_sq" km_curves: bool = False - ncores: int = multiprocessing.cpu_count() + ncores: int = max(1, multiprocessing.cpu_count() - 1) numerator: Optional[str] = None offload: bool = False offload_dir: str = "_seq_models"