Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
runs-on: macos-26
strategy:
matrix:
python-version: ["3.11", "3.12", "3.13", "3.14"]
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]

steps:
- uses: actions/checkout@v6
Expand Down
22 changes: 15 additions & 7 deletions pySEQTarget/SEQopts.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class SEQopts:
indicator_baseline: str = "_bas"
indicator_squared: str = "_sq"
km_curves: bool = False
ncores: int = multiprocessing.cpu_count()
ncores: int = max(1, multiprocessing.cpu_count() - 1)
numerator: Optional[str] = None
offload: bool = False
offload_dir: str = "_seq_models"
Expand All @@ -140,7 +140,7 @@ class SEQopts:
)
plot_labels: List[str] = field(default_factory=lambda: [])
plot_title: str = None
plot_type: Literal["risk", "survival", "incidence"] = "risk"
plot_type: Literal["risk", "survival", "incidence"] = "survival"
seed: Optional[int] = None
selection_first_trial: bool = False
selection_sample: float = 0.8
Expand All @@ -155,7 +155,7 @@ class SEQopts:
weight_max: float = None
weight_lag_condition: bool = True
weight_p99: bool = False
weight_preexpansion: bool = False
weight_preexpansion: bool = True
weighted: bool = False

def _validate_bools(self):
Expand Down Expand Up @@ -184,12 +184,20 @@ def _validate_ranges(self):
raise ValueError("bootstrap_nboot must be a positive integer.")
if self.ncores < 1 or not isinstance(self.ncores, int):
raise ValueError("ncores must be a positive integer.")
if not (0.0 <= self.bootstrap_sample <= 1.0):
raise ValueError("bootstrap_sample must be between 0 and 1.")
if not (0.0 < self.bootstrap_sample <= 1.0):
raise ValueError("bootstrap_sample must be between 0 (exclusive) and 1.")
if not (0.0 < self.bootstrap_CI < 1.0):
raise ValueError("bootstrap_CI must be between 0 and 1.")
if not (0.0 <= self.selection_sample <= 1.0):
raise ValueError("selection_sample must be between 0 and 1.")
if not (0.0 < self.selection_sample <= 1.0):
raise ValueError("selection_sample must be between 0 (exclusive) and 1.")
if self.weight_max is not None and self.weight_max <= self.weight_min:
raise ValueError(
f"weight_min ({self.weight_min}) must be less than weight_max ({self.weight_max})."
)
if self.followup_max is not None and self.followup_max <= self.followup_min:
raise ValueError(
f"followup_min ({self.followup_min}) must be less than followup_max ({self.followup_max})."
)

def _validate_choices(self):
if self.plot_type not in ["risk", "survival", "incidence"]:
Expand Down
10 changes: 8 additions & 2 deletions pySEQTarget/SEQoutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List, Literal, Optional

import matplotlib.figure
import matplotlib.pyplot as plt
import polars as pl
from statsmodels.base.wrapper import ResultsWrapper

Expand Down Expand Up @@ -61,9 +62,14 @@ class SEQoutput:

def plot(self) -> None:
"""
Prints the kaplan-meier graph
Displays the kaplan-meier graph
"""
print(self.km_graph)
if self.km_graph is None:
raise ValueError(
"No plot available. Ensure km_curves=True and run SEQuential.plot() before collect()."
)
plt.figure(self.km_graph)
plt.show()

def summary(
self, type=Optional[Literal["numerator", "denominator", "outcome", "compevent"]]
Expand Down
2 changes: 2 additions & 0 deletions pySEQTarget/SEQuential.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import asdict
from typing import List, Literal, Optional

import matplotlib.pyplot as plt
import numpy as np
import polars as pl

Expand Down Expand Up @@ -317,6 +318,7 @@ def plot(self, **kwargs) -> None:
else:
raise ValueError(f"Unknown or misplaced argument: {key}")
self.km_graph = _survival_plot(self)
plt.show()

def collect(self) -> SEQoutput:
"""
Expand Down
20 changes: 19 additions & 1 deletion pySEQTarget/error/_data_checker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,25 @@
import polars as pl


def _check_binary(data, col):
unique_vals = set(data[col].drop_nulls().unique().to_list())
if not unique_vals.issubset({0, 1}):
raise ValueError(
f"Column '{col}' must be binary (0/1) but contains values: {sorted(unique_vals)}"
)


def _data_checker(self):
_check_binary(self.data, self.eligible_col)
_check_binary(self.data, self.outcome_col)

if self.cense_eligible_colname is not None:
_check_binary(self.data, self.cense_eligible_colname)

for col in self.weight_eligible_colnames:
if col is not None:
_check_binary(self.data, col)

check = self.data.group_by(self.id_col).agg(
[pl.len().alias("row_count"), pl.col(self.time_col).max().alias("max_time")]
)
Expand Down Expand Up @@ -33,6 +51,6 @@ def _data_checker(self):

if len(violations) > 0:
raise ValueError(
f"Column '{col}' violates 'once one, always one' rule for excusing treatment "
f"Column '{col}' violates the 'once one, always one' rule: "
f"{len(violations)} ID(s) have zeros after ones."
)
13 changes: 13 additions & 0 deletions pySEQTarget/error/_param_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,19 @@


def _param_checker(self):
overlap = set(self.time_varying_cols) & set(self.fixed_cols)
if overlap:
raise ValueError(
f"Columns cannot appear in both time_varying_cols and fixed_cols: {sorted(overlap)}"
)

actual_levels = set(self.data[self.treatment_col].unique().to_list())
missing_levels = set(self.treatment_level) - actual_levels
if missing_levels:
raise ValueError(
f"treatment_level contains values not found in '{self.treatment_col}': {sorted(missing_levels)}"
)

if (
self.subgroup_colname is not None
and self.subgroup_colname not in self.fixed_cols
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@ build-backend = "setuptools.build_meta"

[project]
name = "pySEQTarget"
version = "0.12.5"
version = "0.12.6"
description = "Sequentially Nested Target Trial Emulation"
readme = "README.md"
license = {text = "MIT"}
keywords = ["causal inference", "sequential trial emulation", "target trial", "observational studies"]
requires-python = ">=3.11"
requires-python = ">=3.10"
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Science/Research",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
Expand Down
7 changes: 4 additions & 3 deletions tests/test_coefficients.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_PostE_dose_response_coefs():
time_varying_cols=["N", "L", "P"],
fixed_cols=["sex"],
method="dose-response",
parameters=SEQopts(weighted=True),
parameters=SEQopts(weighted=True, weight_preexpansion=False),
)

s.expand()
Expand Down Expand Up @@ -147,7 +147,7 @@ def test_PostE_censoring_coefs():
time_varying_cols=["N", "L", "P"],
fixed_cols=["sex"],
method="censoring",
parameters=SEQopts(weighted=True),
parameters=SEQopts(weighted=True, weight_preexpansion=False),
)
s.expand()
s.fit()
Expand Down Expand Up @@ -216,6 +216,7 @@ def test_PostE_censoring_excused_coefs():
method="censoring",
parameters=SEQopts(
weighted=True,
weight_preexpansion=False,
excused=True,
excused_colnames=["excusedZero", "excusedOne"],
weight_max=1,
Expand Down Expand Up @@ -287,7 +288,7 @@ def test_PostE_LTFU_ITT():
time_varying_cols=["N", "L", "P"],
fixed_cols=["sex"],
method="ITT",
parameters=SEQopts(weighted=True, cense_colname="LTFU"),
parameters=SEQopts(weighted=True, weight_preexpansion=False, cense_colname="LTFU"),
)
s.expand()
s.fit()
Expand Down
6 changes: 3 additions & 3 deletions tests/test_covariates.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_PostE_dose_response_covariates():
time_varying_cols=["N", "L", "P"],
fixed_cols=["sex"],
method="dose-response",
parameters=SEQopts(weighted=True),
parameters=SEQopts(weighted=True, weight_preexpansion=False),
)
assert (
s.covariates
Expand Down Expand Up @@ -110,7 +110,7 @@ def test_PostE_censoring_covariates():
time_varying_cols=["N", "L", "P"],
fixed_cols=["sex"],
method="censoring",
parameters=SEQopts(weighted=True),
parameters=SEQopts(weighted=True, weight_preexpansion=False),
)
assert (
s.covariates
Expand Down Expand Up @@ -165,7 +165,7 @@ def test_PostE_censoring_excused_covariates():
fixed_cols=["sex"],
method="censoring",
parameters=SEQopts(
weighted=True, excused=True, excused_colnames=["excusedZero", "excusedOne"]
weighted=True, weight_preexpansion=False, excused=True, excused_colnames=["excusedZero", "excusedOne"]
),
)
assert (
Expand Down
127 changes: 127 additions & 0 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import unittest.mock as mock

import matplotlib
import matplotlib.figure
import matplotlib.pyplot as plt
import pytest

matplotlib.use("Agg") # non-interactive backend — no windows opened

from pySEQTarget import SEQopts, SEQuential
from pySEQTarget.data import load_data


@pytest.fixture(autouse=True)
def close_figures():
yield
plt.close("all")


@pytest.fixture
def base_seq():
data = load_data("SEQdata")
s = SEQuential(
data,
id_col="ID",
time_col="time",
eligible_col="eligible",
treatment_col="tx_init",
outcome_col="outcome",
time_varying_cols=["N", "L", "P"],
fixed_cols=["sex"],
method="ITT",
parameters=SEQopts(km_curves=True),
)
s.expand()
s.fit()
s.survival()
return s


def test_sequential_plot_returns_figure(base_seq):
"""SEQuential.plot() should store a Figure in km_graph, not print its repr."""
base_seq.plot()
assert isinstance(base_seq.km_graph, matplotlib.figure.Figure)


def test_sequential_plot_calls_show(base_seq):
"""SEQuential.plot() must call plt.show() to actually display the figure."""
with mock.patch("matplotlib.pyplot.show") as mock_show:
base_seq.plot()
mock_show.assert_called_once()


def test_seqoutput_plot_shows_figure(base_seq):
"""SEQoutput.plot() should display the figure without raising."""
base_seq.plot()
result = base_seq.collect()
result.plot() # must not raise; previously printed Figure repr instead


def test_seqoutput_plot_raises_without_km_graph():
"""SEQoutput.plot() raises ValueError when no figure was generated."""
data = load_data("SEQdata")
s = SEQuential(
data,
id_col="ID",
time_col="time",
eligible_col="eligible",
treatment_col="tx_init",
outcome_col="outcome",
time_varying_cols=["N", "L", "P"],
fixed_cols=["sex"],
method="ITT",
parameters=SEQopts(km_curves=False),
)
s.expand()
s.fit()
result = s.collect()
with pytest.raises(ValueError):
result.plot()


def test_sequential_plot_risk(base_seq):
base_seq.plot(plot_type="risk")
assert isinstance(base_seq.km_graph, matplotlib.figure.Figure)


def test_sequential_plot_survival():
data = load_data("SEQdata")
s = SEQuential(
data,
id_col="ID",
time_col="time",
eligible_col="eligible",
treatment_col="tx_init",
outcome_col="outcome",
time_varying_cols=["N", "L", "P"],
fixed_cols=["sex"],
method="ITT",
parameters=SEQopts(km_curves=True, plot_type="survival"),
)
s.expand()
s.fit()
s.survival()
s.plot()
assert isinstance(s.km_graph, matplotlib.figure.Figure)


def test_sequential_plot_subgroups():
data = load_data("SEQdata")
s = SEQuential(
data,
id_col="ID",
time_col="time",
eligible_col="eligible",
treatment_col="tx_init",
outcome_col="outcome",
time_varying_cols=["N", "L", "P"],
fixed_cols=["sex"],
method="ITT",
parameters=SEQopts(km_curves=True, subgroup_colname="sex"),
)
s.expand()
s.fit()
s.survival()
s.plot()
assert isinstance(s.km_graph, matplotlib.figure.Figure)