Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/vignettes/exploring_results.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Exploring Results

Recall our previous example, {doc}`~vignettes/more_advanced_models`, where we finalized and collected our results with
Recall our previous example, {doc}`More Advanced Analysis <more_advanced_models>`, where we finalized and collected our results with

```python
my_output = my_analysis.collect()
Expand Down Expand Up @@ -185,4 +185,4 @@ Because we have an excused-censoring analysis, we are also provided with informa
| 1 | False | 1 | 1256 |
| 1 | False | 0 | 527107 |
| 1 | True | 0 | 18508 |
| 0 | False | 0 | 91300 |
| 0 | False | 0 | 91300 |
4 changes: 2 additions & 2 deletions docs/vignettes/more_advanced_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ from pySEQTarget.data import load_data

data = load_data("SEQdata_LTFU")
my_options = SEQopts(
bootstrap_nboot = 20, # 20 bootstrap iterations
bootstrap_nboot = 20, # 20 bootstrap iterations (for demonstration only — use 500+ in practice)
cense_colname = "LTFU", # control for losses-to-followup as a censor
excused = True, # allow excused treatment swapping
excused_colnames = ["excusedZero", "excusedOne"],
Expand All @@ -25,7 +25,7 @@ my_options = SEQopts(
weighted = True, # enables the weighting
weight_lag_condition=False, # turn off lag condition when weighting for adherence
weight_p99 = True, # bounds weights by the 1st and 99th percentile
weight_preexpansion = False # weights are predicted using post-expansion data as a stabilizer
weight_preexpansion = True # weights are predicted using pre-expansion data
)
```

Expand Down
18 changes: 10 additions & 8 deletions pySEQTarget/SEQopts.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ class SEQopts:
:type denominator: Optional[str] or None
:param excused: Boolean to allow excused conditions when method is censoring
:type excused: bool
:param excused_colnames: Column names (at the same length of treatment_level) specifying excused conditions
:type excused_colnames: List[str] or []
:param excused_colnames: Column names (at the same length of treatment_level) specifying excused conditions, default ``[]``
:type excused_colnames: List[str]
:param followup_class: Boolean to force followup values to be treated as classes
:type followup_class: bool
:param followup_include: Boolean to force regular followup values into model covariates
Expand All @@ -54,7 +54,7 @@ class SEQopts:
:type indicator_squared: str
:param km_curves: Boolean to create survival, risk, and incidence (if applicable) estimates
:type km_curves: bool
:param ncores: Number of cores to use if running in parallel
:param ncores: Number of cores to use if running in parallel, default ``max(1, cpu_count() - 1)``
:type ncores: int
:param numerator: Override to specify the outcome patsy formula for
numerator models; "1" or "" indicate intercept only model
Expand All @@ -65,9 +65,9 @@ class SEQopts:
:type offload_dir: str
:param parallel: Boolean to run model fitting in parallel
:type parallel: bool
:param plot_colors: List of colors for KM plots, if applicable
:param plot_colors: List of colors for KM plots, if applicable, default ``["#F8766D", "#00BFC4", "#555555"]``
:type plot_colors: List[str]
:param plot_labels: List of length treat_level to specify treatment labeling
:param plot_labels: List of length treat_level to specify treatment labeling, default ``[]``
:type plot_labels: List[str]
:param plot_title: Plot title
:type plot_title: str
Expand All @@ -83,14 +83,14 @@ class SEQopts:
:type selection_random: bool
:param subgroup_colname: Column name for subgroups to share the same weighting but different outcome model fits
:type subgroup_colname: str
:param treatment_level: List of eligible treatment levels within treatment_col
:param treatment_level: List of eligible treatment levels within treatment_col, default ``[0, 1]``
:type treatment_level: List[int]
:param trial_include: Boolean to force trial values into model covariates
:type trial_include: bool
:param visit_colname: Column name specifying visit number
:type visit_colname: str
:param weight_eligible_colnames: List of column names of length
treatment_level to identify which rows are eligible for weight fitting
treatment_level to identify which rows are eligible for weight fitting, default ``[]``
:type weight_eligible_colnames: List[str]
:param weight_fit_method: The fitting method to be used ["newton", "bfgs", "lbfgs", "nm"], default "newton"
:type weight_fit_method: str
Expand Down Expand Up @@ -130,7 +130,7 @@ class SEQopts:
indicator_baseline: str = "_bas"
indicator_squared: str = "_sq"
km_curves: bool = False
ncores: int = max(1, multiprocessing.cpu_count() - 1)
ncores: Optional[int] = None
numerator: Optional[str] = None
offload: bool = False
offload_dir: str = "_seq_models"
Expand Down Expand Up @@ -220,6 +220,8 @@ def _normalize_formulas(self):
setattr(self, i, "".join(attr.split()))

def __post_init__(self):
if self.ncores is None:
self.ncores = max(1, multiprocessing.cpu_count() - 1)
self._validate_bools()
self._validate_ranges()
self._validate_choices()
Expand Down
2 changes: 2 additions & 0 deletions pySEQTarget/SEQoutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def summary(
) -> List:
"""
Returns a list of model summaries of either the numerator, denominator, outcome, or competing event models

:param type: Indicator for which model list you would like returned
:type type: str
"""
Expand Down Expand Up @@ -108,6 +109,7 @@ def retrieve_data(
) -> pl.DataFrame:
"""
Getter for data stored within ``SEQoutput``

:param type: Data which you would like to access, ['km_data', 'hazard',
'risk_ratio', 'risk_difference', 'unique_outcomes',
'nonunique_outcomes', 'unique_switches', 'nonunique_switches']
Expand Down
123 changes: 97 additions & 26 deletions pySEQTarget/analysis/_risk_estimates.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import math

import polars as pl
from scipy import stats


def _compute_rd_rr(comp, has_bootstrap, z=None, group_cols=None):
"""
Compute Risk Difference and Risk Ratio from a comparison dataframe.
Consolidates the repeated calculation logic.
Fallback used when paired bootstrap data is unavailable (e.g. subgroups).
"""
if group_cols is None:
group_cols = []
Expand Down Expand Up @@ -80,20 +82,30 @@ def _risk_estimates(self):
group_cols = [self.subgroup_colname] if self.subgroup_colname else []
has_bootstrap = self.bootstrap_nboot > 0

# Paired approach: compute RD_i and RR_i per bootstrap iteration then take
# SE or percentile of those — correct because arms share the same bootstrap
# samples, so pairing captures their correlation. Falls back to the
# independent delta method for subgroups (not yet supported) or when
# _boot_risks is unavailable.
use_paired = (
has_bootstrap
and not group_cols
and hasattr(self, "_boot_risks")
and all(tx in self._boot_risks for tx in self.treatment_level)
)

if has_bootstrap:
alpha = 1 - self.bootstrap_CI
z = stats.norm.ppf(1 - alpha / 2)
else:
z = None
alpha = None

# Pre-extract data for each treatment level once (avoid repeated filtering)
risk_by_level = {}
for tx in self.treatment_level:
level_data = risk.filter(pl.col(self.treatment_col) == tx)
risk_by_level[tx] = {
"pred": level_data.select(group_cols + ["pred"]),
}
if has_bootstrap:
risk_by_level[tx] = {"pred": level_data.select(group_cols + ["pred"])}
if has_bootstrap and not use_paired:
risk_by_level[tx]["SE"] = level_data.select(group_cols + ["SE"])

rd_comparisons = []
Expand All @@ -104,31 +116,90 @@ def _risk_estimates(self):
if tx_x == tx_y:
continue

# Use pre-extracted data instead of filtering again
risk_x = risk_by_level[tx_x]["pred"].rename({"pred": "risk_x"})
risk_y = risk_by_level[tx_y]["pred"].rename({"pred": "risk_y"})

if group_cols:
comp = risk_x.join(risk_y, on=group_cols, how="left")
if use_paired:
boot_x = (
self._boot_risks[tx_x]
.filter(pl.col("followup") == last_followup)
.select(["boot_idx", pl.col("risk").alias("risk_x")])
)
boot_y = (
self._boot_risks[tx_y]
.filter(pl.col("followup") == last_followup)
.select(["boot_idx", pl.col("risk").alias("risk_y")])
)
paired = boot_x.join(boot_y, on="boot_idx").with_columns(
(pl.col("risk_x") - pl.col("risk_y")).alias("RD")
)

risk_x_val = float(risk_by_level[tx_x]["pred"]["pred"][0])
risk_y_val = float(risk_by_level[tx_y]["pred"]["pred"][0])
rd_point = risk_x_val - risk_y_val
rr_point = risk_x_val / risk_y_val if risk_y_val != 0 else float("inf")

# Filter degenerate RR bootstrap values (risk_y == 0 or negative)
valid_rr = paired.filter(
(pl.col("risk_y") > 0) & (pl.col("risk_x") >= 0)
).with_columns(
(pl.col("risk_x") / pl.col("risk_y")).alias("RR")
)

if self.bootstrap_CI_method == "percentile":
rd_lci = float(paired["RD"].quantile(alpha / 2))
rd_uci = float(paired["RD"].quantile(1 - alpha / 2))
rr_lci = float(valid_rr["RR"].quantile(alpha / 2))
rr_uci = float(valid_rr["RR"].quantile(1 - alpha / 2))
else:
rd_se = float(paired["RD"].std())
rd_lci = rd_point - z * rd_se
rd_uci = rd_point + z * rd_se
log_rr_se = float(valid_rr["RR"].log().std())
rr_lci = math.exp(math.log(rr_point) - z * log_rr_se)
rr_uci = math.exp(math.log(rr_point) + z * log_rr_se)

rd_comp = pl.DataFrame(
{
"A_x": [tx_x],
"A_y": [tx_y],
"Risk Difference": [rd_point],
"RD 95% LCI": [rd_lci],
"RD 95% UCI": [rd_uci],
}
)
rr_comp = pl.DataFrame(
{
"A_x": [tx_x],
"A_y": [tx_y],
"Risk Ratio": [rr_point],
"RR 95% LCI": [rr_lci],
"RR 95% UCI": [rr_uci],
}
)
else:
comp = risk_x.join(risk_y, how="cross")

comp = comp.with_columns(
[pl.lit(tx_x).alias("A_x"), pl.lit(tx_y).alias("A_y")]
)

if has_bootstrap:
se_x = risk_by_level[tx_x]["SE"].rename({"SE": "se_x"})
se_y = risk_by_level[tx_y]["SE"].rename({"SE": "se_y"})
# Fall back to independent delta method
risk_x = risk_by_level[tx_x]["pred"].rename({"pred": "risk_x"})
risk_y = risk_by_level[tx_y]["pred"].rename({"pred": "risk_y"})

if group_cols:
comp = comp.join(se_x, on=group_cols, how="left")
comp = comp.join(se_y, on=group_cols, how="left")
comp = risk_x.join(risk_y, on=group_cols, how="left")
else:
comp = comp.join(se_x, how="cross")
comp = comp.join(se_y, how="cross")
comp = risk_x.join(risk_y, how="cross")

comp = comp.with_columns(
[pl.lit(tx_x).alias("A_x"), pl.lit(tx_y).alias("A_y")]
)

if has_bootstrap:
se_x = risk_by_level[tx_x]["SE"].rename({"SE": "se_x"})
se_y = risk_by_level[tx_y]["SE"].rename({"SE": "se_y"})
if group_cols:
comp = comp.join(se_x, on=group_cols, how="left")
comp = comp.join(se_y, on=group_cols, how="left")
else:
comp = comp.join(se_x, how="cross")
comp = comp.join(se_y, how="cross")

rd_comp, rr_comp = _compute_rd_rr(comp, has_bootstrap, z, group_cols)

rd_comp, rr_comp = _compute_rd_rr(comp, has_bootstrap, z, group_cols)
rd_comparisons.append(rd_comp)
rr_comparisons.append(rr_comp)

Expand Down
27 changes: 27 additions & 0 deletions pySEQTarget/analysis/_survival_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,27 @@
from ._outcome_fit import _cast_categories


def _store_boot_risks(obj, treatment_val, TxDT, boot_cols, is_survival=False):
    """Store per-bootstrap mean risks per followup for paired RD/RR CI computation.

    Writes a long-format frame (followup, risk, boot_idx) into
    ``obj._boot_risks[treatment_val]``; no-op when there are no bootstrap columns.
    """
    if not boot_cols:
        return

    # Keep only the followup index plus the per-bootstrap prediction columns.
    wide = TxDT.select(["followup"] + boot_cols)

    # Survival columns hold S(t); convert each in place to risk = 1 - S(t).
    if is_survival:
        wide = wide.with_columns([(1 - pl.col(name)).alias(name) for name in boot_cols])

    # Reshape wide -> long so each row is (followup, source column, risk).
    long_frame = wide.unpivot(
        index="followup",
        on=boot_cols,
        variable_name="_col",
        value_name="risk",
    )

    # The bootstrap iteration number is the trailing run of digits in the
    # source column name (e.g. "surv_12" -> 12).
    idx_expr = pl.col("_col").str.extract(r"(\d+)$").cast(pl.Int32).alias("boot_idx")
    obj._boot_risks[treatment_val] = long_frame.with_columns(idx_expr).drop("_col")


def _get_outcome_predictions(self, TxDT, idx=None):
data = _cast_categories(self, TxDT.to_pandas())
predictions = {"outcome": []}
Expand All @@ -27,6 +48,8 @@ def _pred_risk(self):
isinstance(self.outcome_model[0], list) if self.outcome_model else False
)

self._boot_risks = {}

if not has_subgroups:
return _calculate_risk(self, self.DT, idx=None, val=None)

Expand Down Expand Up @@ -138,6 +161,8 @@ def _calculate_risk(self, data, idx=None, val=None):
)
main_col = "surv"
boot_cols = [col for col in surv_names if col != "surv"]
if val is None:
_store_boot_risks(self, treatment_val, TxDT, boot_cols, is_survival=True)
else:
TxDT = (
TxDT.with_columns(
Expand All @@ -153,6 +178,8 @@ def _calculate_risk(self, data, idx=None, val=None):
)
main_col = "pred_outcome"
boot_cols = [col for col in outcome_names if col != "pred_outcome"]
if val is None:
_store_boot_risks(self, treatment_val, TxDT, boot_cols)

if boot_cols:
risk = (
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "pySEQTarget"
version = "0.12.6"
version = "0.12.7"
description = "Sequentially Nested Target Trial Emulation"
readme = "README.md"
license = {text = "MIT"}
Expand Down