Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:
uv run pytest tests/ -v --cov=pySEQTarget --cov-report=xml

- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
uses: codecov/codecov-action@v6
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: CausalInference/pySEQTarget
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
templates_path = ["_templates"]
exclude_patterns = []

autodoc_class_signature = "separated"


# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
Expand Down
3 changes: 3 additions & 0 deletions pySEQTarget/SEQuential.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,9 @@ def survival(self, **kwargs) -> None:
self.km_data = _clamp(pl.concat([risk_data, surv_data]))
self.risk_estimates = _risk_estimates(self)

if hasattr(self, "_boot_risks"):
del self._boot_risks

end = time.perf_counter()
self._survival_time = _format_time(start, end)

Expand Down
18 changes: 11 additions & 7 deletions pySEQTarget/analysis/_hazard.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,17 @@ def _calculate_hazard_single(self, data, idx=None, val=None):
self._rng = np.random.RandomState(self.seed + boot_idx + 1)
id_counts = self._boot_samples[boot_idx]

boot_data_list = []
for id_val, count in id_counts.items():
id_data = data.filter(pl.col(self.id_col) == id_val)
for _ in range(count):
boot_data_list.append(id_data)

boot_data = pl.concat(boot_data_list)
counts = pl.DataFrame(
{self.id_col: list(id_counts.keys()), "_count": list(id_counts.values())}
)
boot_data = (
data.lazy()
.join(counts.lazy(), on=self.id_col, how="inner")
.with_columns(pl.int_ranges(0, pl.col("_count")).alias("_rep"))
.explode("_rep")
.drop("_count", "_rep")
.collect()
)

boot_log_hr = _hazard_handler(self, boot_data, idx, boot_idx + 1, self._rng)
if boot_log_hr is not None and not np.isnan(boot_log_hr):
Expand Down
20 changes: 15 additions & 5 deletions pySEQTarget/analysis/_risk_estimates.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,18 +141,28 @@ def _risk_estimates(self):
(pl.col("risk_y") > 0) & (pl.col("risk_x") >= 0)
).with_columns((pl.col("risk_x") / pl.col("risk_y")).alias("RR"))

n_valid_rr = len(valid_rr)

if self.bootstrap_CI_method == "percentile":
rd_lci = float(paired["RD"].quantile(alpha / 2))
rd_uci = float(paired["RD"].quantile(1 - alpha / 2))
rr_lci = float(valid_rr["RR"].quantile(alpha / 2))
rr_uci = float(valid_rr["RR"].quantile(1 - alpha / 2))
if n_valid_rr >= 2:
rr_lci = float(valid_rr["RR"].quantile(alpha / 2))
rr_uci = float(valid_rr["RR"].quantile(1 - alpha / 2))
else:
rr_lci = float("nan")
rr_uci = float("nan")
else:
rd_se = float(paired["RD"].std())
rd_lci = rd_point - z * rd_se
rd_uci = rd_point + z * rd_se
log_rr_se = float(valid_rr["RR"].log().std())
rr_lci = math.exp(math.log(rr_point) - z * log_rr_se)
rr_uci = math.exp(math.log(rr_point) + z * log_rr_se)
if n_valid_rr >= 2 and rr_point > 0:
log_rr_se = float(valid_rr["RR"].log().std())
rr_lci = math.exp(math.log(rr_point) - z * log_rr_se)
rr_uci = math.exp(math.log(rr_point) + z * log_rr_se)
else:
rr_lci = float("nan")
rr_uci = float("nan")

rd_comp = pl.DataFrame(
{
Expand Down
9 changes: 6 additions & 3 deletions pySEQTarget/analysis/_survival_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ def _get_outcome_predictions(self, TxDT, idx=None):
for boot_model in self.outcome_model:
model_dict = boot_model[idx] if idx is not None else boot_model
outcome_model = self._offloader.load_model(model_dict["outcome"])
predictions["outcome"].append(_safe_predict(outcome_model, data.copy()))
predictions["outcome"].append(_safe_predict(outcome_model, data))

if self.compevent_colname is not None:
compevent_model = self._offloader.load_model(model_dict["compevent"])
predictions["compevent"].append(_safe_predict(compevent_model, data.copy()))
predictions["compevent"].append(_safe_predict(compevent_model, data))

return predictions

Expand Down Expand Up @@ -79,7 +79,6 @@ def _calculate_risk(self, data, idx=None, val=None):
)
.group_by("TID")
.first()
.sort("TID")
.drop(["followup", f"followup{self.indicator_squared}"])
.with_columns([pl.lit(followup_range).alias("followup")])
.explode("followup")
Expand Down Expand Up @@ -112,6 +111,10 @@ def _calculate_risk(self, data, idx=None, val=None):
)

preds = _get_outcome_predictions(self, TxDT, idx=idx)

# Drop original data columns — only followup and TID needed from here
TxDT = TxDT.select(["followup", "TID"])

pred_series = [pl.Series("pred_outcome", preds["outcome"][0])]

if self.bootstrap_nboot > 0:
Expand Down
26 changes: 11 additions & 15 deletions pySEQTarget/expansion/_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,17 @@ def _dynamic(self):
).sort([self.id_col, "trial", "followup"])

if self.excused:
DT = (
DT.with_columns(
pl.col("isExcused")
.cast(pl.Int8)
.cum_sum()
.over([self.id_col, "trial"])
.alias("_excused_tmp")
)
.with_columns(
pl.when(pl.col("_excused_tmp") > 0)
.then(pl.lit(False))
.otherwise(pl.col("switch"))
.alias("switch")
)
.drop("_excused_tmp")
excused_cumsum = (
pl.col("isExcused")
.cast(pl.Int8)
.cum_sum()
.over([self.id_col, "trial"])
)
DT = DT.with_columns(
pl.when(excused_cumsum > 0)
.then(pl.lit(False))
.otherwise(pl.col("switch"))
.alias("switch")
)

DT = DT.filter(
Expand Down
6 changes: 4 additions & 2 deletions pySEQTarget/helpers/_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ def _prepare_boot_data(self, data, boot_id):
{self.id_col: list(id_counts.keys()), "count": list(id_counts.values())}
)

bootstrapped = data.join(counts, on=self.id_col, how="inner")
bootstrapped = (
bootstrapped.with_columns(pl.int_ranges(0, pl.col("count")).alias("replicate"))
data.lazy()
.join(counts.lazy(), on=self.id_col, how="inner")
.with_columns(pl.int_ranges(0, pl.col("count")).alias("replicate"))
.explode("replicate")
.with_columns(
(
Expand All @@ -29,6 +30,7 @@ def _prepare_boot_data(self, data, boot_id):
).alias(self.id_col)
)
.drop("count", "replicate")
.collect()
)

return bootstrapped
Expand Down
4 changes: 1 addition & 3 deletions pySEQTarget/helpers/_predict_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,11 @@ def _safe_predict(model, data, clip_probs=True):
clip_probs : bool
If True, clip probabilities to [0, 1] and replace NaN with 0.5
"""
data = data.copy()

try:
probs = model.predict(data)
except Exception as e:
if "mismatching levels" in str(e):
data = _fix_categories_for_predict(model, data)
data = _fix_categories_for_predict(model, data.copy())
probs = model.predict(data)
else:
raise
Expand Down
40 changes: 24 additions & 16 deletions pySEQTarget/weighting/_weight_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,19 @@ def _weight_predict(self, WDT):
if cense_num_model is not None and cense_denom_model is not None:
p_num = _predict_model(self, cense_num_model, WDT).flatten()
p_denom = _predict_model(self, cense_denom_model, WDT).flatten()
WDT = WDT.with_columns(
[
pl.Series("cense_numerator", p_num),
pl.Series("cense_denominator", p_denom),
]
).with_columns(
(pl.col("cense_numerator") / pl.col("cense_denominator")).alias(
"_cense"
WDT = (
WDT.with_columns(
[
pl.Series("cense_numerator", p_num),
pl.Series("cense_denominator", p_denom),
]
)
.with_columns(
(pl.col("cense_numerator") / pl.col("cense_denominator")).alias(
"_cense"
)
)
.drop(["cense_numerator", "cense_denominator"])
)
else:
WDT = WDT.with_columns(pl.lit(1.0).alias("_cense"))
Expand All @@ -194,15 +198,19 @@ def _weight_predict(self, WDT):
if visit_num_model is not None and visit_denom_model is not None:
p_num = _predict_model(self, visit_num_model, WDT).flatten()
p_denom = _predict_model(self, visit_denom_model, WDT).flatten()
WDT = WDT.with_columns(
[
pl.Series("visit_numerator", p_num),
pl.Series("visit_denominator", p_denom),
]
).with_columns(
(pl.col("visit_numerator") / pl.col("visit_denominator")).alias(
"_visit"
WDT = (
WDT.with_columns(
[
pl.Series("visit_numerator", p_num),
pl.Series("visit_denominator", p_denom),
]
)
.with_columns(
(pl.col("visit_numerator") / pl.col("visit_denominator")).alias(
"_visit"
)
)
.drop(["visit_numerator", "visit_denominator"])
)
else:
WDT = WDT.with_columns(pl.lit(1.0).alias("_visit"))
Expand Down
4 changes: 2 additions & 2 deletions pySEQTarget/weighting/_weight_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def _weight_stats(self):
)

if self.weight_p99:
self.weight_min = stats.select("weight_p01").item()
self.weight_max = stats.select("weight_p99").item()
self.weight_min = float(stats["weight_p01"][0])
self.weight_max = float(stats["weight_p99"][0])

return stats
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "pySEQTarget"
version = "0.12.7"
version = "0.12.8"
description = "Sequentially Nested Target Trial Emulation"
readme = "README.md"
license = {text = "MIT"}
Expand Down
2 changes: 1 addition & 1 deletion tests/test_hazard.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_bootstrap_hazard():
time_varying_cols=["N", "L", "P"],
fixed_cols=["sex"],
method="ITT",
parameters=SEQopts(hazard_estimate=True, bootstrap_nboot=2),
parameters=SEQopts(hazard_estimate=True, bootstrap_nboot=2, seed=42),
)
s.expand()
s.bootstrap()
Expand Down
1 change: 1 addition & 0 deletions tests/test_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_compevent_offload():
weight_p99=True,
weight_preexpansion=True,
offload=True,
seed=42,
)

model = SEQuential(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_parallel_ITT():
time_varying_cols=["N", "L", "P"],
fixed_cols=["sex"],
method="ITT",
parameters=SEQopts(parallel=True, bootstrap_nboot=2, ncores=1),
parameters=SEQopts(parallel=True, bootstrap_nboot=2, ncores=1, seed=42),
)
s.expand()
s.bootstrap()
Expand Down
5 changes: 3 additions & 2 deletions tests/test_survival.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_bootstrapped_survival():
time_varying_cols=["N", "L", "P"],
fixed_cols=["sex"],
method="ITT",
parameters=SEQopts(km_curves=True, bootstrap_nboot=2),
parameters=SEQopts(km_curves=True, bootstrap_nboot=2, seed=42),
)
s.expand()
s.bootstrap()
Expand Down Expand Up @@ -83,7 +83,7 @@ def test_subgroup_bootstrapped_survival():
time_varying_cols=["N", "L", "P"],
fixed_cols=["sex"],
method="ITT",
parameters=SEQopts(km_curves=True, subgroup_colname="sex", bootstrap_nboot=2),
parameters=SEQopts(km_curves=True, subgroup_colname="sex", bootstrap_nboot=2, seed=42),
)
s.expand()
s.bootstrap()
Expand Down Expand Up @@ -139,6 +139,7 @@ def test_bootstrapped_compevent():
compevent_colname="LTFU",
plot_type="incidence",
bootstrap_nboot=2,
seed=42,
),
)
s.expand()
Expand Down
Loading