diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index cba4298..6138c85 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -42,7 +42,7 @@ jobs: uv run pytest tests/ -v --cov=pySEQTarget --cov-report=xml - name: Upload coverage reports to Codecov - uses: codecov/codecov-action@v5 + uses: codecov/codecov-action@v6 with: token: ${{ secrets.CODECOV_TOKEN }} slug: CausalInference/pySEQTarget diff --git a/docs/conf.py b/docs/conf.py index 4a3648d..eaa3ba4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -42,6 +42,8 @@ templates_path = ["_templates"] exclude_patterns = [] +autodoc_class_signature = "separated" + # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output diff --git a/pySEQTarget/SEQuential.py b/pySEQTarget/SEQuential.py index 4984c7e..a439ac7 100644 --- a/pySEQTarget/SEQuential.py +++ b/pySEQTarget/SEQuential.py @@ -289,6 +289,9 @@ def survival(self, **kwargs) -> None: self.km_data = _clamp(pl.concat([risk_data, surv_data])) self.risk_estimates = _risk_estimates(self) + if hasattr(self, "_boot_risks"): + del self._boot_risks + end = time.perf_counter() self._survival_time = _format_time(start, end) diff --git a/pySEQTarget/analysis/_hazard.py b/pySEQTarget/analysis/_hazard.py index 41039ee..96f62ac 100644 --- a/pySEQTarget/analysis/_hazard.py +++ b/pySEQTarget/analysis/_hazard.py @@ -40,13 +40,17 @@ def _calculate_hazard_single(self, data, idx=None, val=None): self._rng = np.random.RandomState(self.seed + boot_idx + 1) id_counts = self._boot_samples[boot_idx] - boot_data_list = [] - for id_val, count in id_counts.items(): - id_data = data.filter(pl.col(self.id_col) == id_val) - for _ in range(count): - boot_data_list.append(id_data) - - boot_data = pl.concat(boot_data_list) + counts = pl.DataFrame( + {self.id_col: list(id_counts.keys()), "_count": list(id_counts.values())} + ) + boot_data = ( + data.lazy() + .join(counts.lazy(), on=self.id_col, how="inner") + .with_columns(pl.int_ranges(0, pl.col("_count")).alias("_rep")) + .explode("_rep") + .drop("_count", "_rep") + .collect() + ) boot_log_hr = _hazard_handler(self, boot_data, idx, boot_idx + 1, self._rng) if boot_log_hr is not None and not np.isnan(boot_log_hr): diff --git a/pySEQTarget/analysis/_risk_estimates.py b/pySEQTarget/analysis/_risk_estimates.py index fc19585..561179c 100644 --- a/pySEQTarget/analysis/_risk_estimates.py +++ b/pySEQTarget/analysis/_risk_estimates.py @@ -141,18 +141,28 @@ def _risk_estimates(self): (pl.col("risk_y") > 0) & (pl.col("risk_x") >= 0) ).with_columns((pl.col("risk_x") / pl.col("risk_y")).alias("RR")) + n_valid_rr = len(valid_rr) + if self.bootstrap_CI_method == "percentile": rd_lci = float(paired["RD"].quantile(alpha / 2)) rd_uci = float(paired["RD"].quantile(1 - alpha / 2)) - rr_lci = float(valid_rr["RR"].quantile(alpha / 2)) - rr_uci = float(valid_rr["RR"].quantile(1 - alpha / 2)) + if n_valid_rr >= 2: + rr_lci = float(valid_rr["RR"].quantile(alpha / 2)) + rr_uci = float(valid_rr["RR"].quantile(1 - alpha / 2)) + else: + rr_lci = float("nan") + rr_uci = float("nan") else: rd_se = float(paired["RD"].std()) rd_lci = rd_point - z * rd_se rd_uci = rd_point + z * rd_se - log_rr_se = float(valid_rr["RR"].log().std()) - rr_lci = math.exp(math.log(rr_point) - z * log_rr_se) - rr_uci = math.exp(math.log(rr_point) + z * log_rr_se) + if n_valid_rr >= 2 and rr_point > 0: + log_rr_se = float(valid_rr["RR"].log().std()) + rr_lci = math.exp(math.log(rr_point) - z * log_rr_se) + rr_uci = math.exp(math.log(rr_point) + z * log_rr_se) + else: + rr_lci = float("nan") + rr_uci = float("nan") rd_comp = pl.DataFrame( { diff --git a/pySEQTarget/analysis/_survival_pred.py b/pySEQTarget/analysis/_survival_pred.py index 712a265..48b7e56 100644 --- a/pySEQTarget/analysis/_survival_pred.py +++ b/pySEQTarget/analysis/_survival_pred.py @@ -34,11 +34,11 @@ def _get_outcome_predictions(self, TxDT, idx=None): for boot_model in self.outcome_model: model_dict = boot_model[idx] if idx is not None else boot_model outcome_model = self._offloader.load_model(model_dict["outcome"]) - predictions["outcome"].append(_safe_predict(outcome_model, data.copy())) + predictions["outcome"].append(_safe_predict(outcome_model, data)) if self.compevent_colname is not None: compevent_model = self._offloader.load_model(model_dict["compevent"]) - predictions["compevent"].append(_safe_predict(compevent_model, data.copy())) + predictions["compevent"].append(_safe_predict(compevent_model, data)) return predictions @@ -79,7 +79,6 @@ def _calculate_risk(self, data, idx=None, val=None): ) .group_by("TID") .first() - .sort("TID") .drop(["followup", f"followup{self.indicator_squared}"]) .with_columns([pl.lit(followup_range).alias("followup")]) .explode("followup") @@ -112,6 +111,10 @@ def _calculate_risk(self, data, idx=None, val=None): ) preds = _get_outcome_predictions(self, TxDT, idx=idx) + + # Drop original data columns — only followup and TID needed from here + TxDT = TxDT.select(["followup", "TID"]) + pred_series = [pl.Series("pred_outcome", preds["outcome"][0])] if self.bootstrap_nboot > 0: diff --git a/pySEQTarget/expansion/_dynamic.py b/pySEQTarget/expansion/_dynamic.py index 932ba75..0ab61a9 100644 --- a/pySEQTarget/expansion/_dynamic.py +++ b/pySEQTarget/expansion/_dynamic.py @@ -46,21 +46,17 @@ def _dynamic(self): ).sort([self.id_col, "trial", "followup"]) if self.excused: - DT = ( - DT.with_columns( - pl.col("isExcused") - .cast(pl.Int8) - .cum_sum() - .over([self.id_col, "trial"]) - .alias("_excused_tmp") - ) - .with_columns( - pl.when(pl.col("_excused_tmp") > 0) - .then(pl.lit(False)) - .otherwise(pl.col("switch")) - .alias("switch") - ) - .drop("_excused_tmp") + excused_cumsum = ( + pl.col("isExcused") + .cast(pl.Int8) + .cum_sum() + .over([self.id_col, "trial"]) + ) + DT = DT.with_columns( + pl.when(excused_cumsum > 0) + .then(pl.lit(False)) + .otherwise(pl.col("switch")) + .alias("switch") ) DT = DT.filter( diff --git a/pySEQTarget/helpers/_bootstrap.py b/pySEQTarget/helpers/_bootstrap.py index b828176..2a1b319 100644 --- a/pySEQTarget/helpers/_bootstrap.py +++ b/pySEQTarget/helpers/_bootstrap.py @@ -17,9 +17,10 @@ def _prepare_boot_data(self, data, boot_id): {self.id_col: list(id_counts.keys()), "count": list(id_counts.values())} ) - bootstrapped = data.join(counts, on=self.id_col, how="inner") bootstrapped = ( - bootstrapped.with_columns(pl.int_ranges(0, pl.col("count")).alias("replicate")) + data.lazy() + .join(counts.lazy(), on=self.id_col, how="inner") + .with_columns(pl.int_ranges(0, pl.col("count")).alias("replicate")) .explode("replicate") .with_columns( ( @@ -29,6 +30,7 @@ def _prepare_boot_data(self, data, boot_id): ).alias(self.id_col) ) .drop("count", "replicate") + .collect() ) return bootstrapped diff --git a/pySEQTarget/helpers/_predict_model.py b/pySEQTarget/helpers/_predict_model.py index 385cf4e..41a35ba 100644 --- a/pySEQTarget/helpers/_predict_model.py +++ b/pySEQTarget/helpers/_predict_model.py @@ -18,13 +18,11 @@ def _safe_predict(model, data, clip_probs=True): clip_probs : bool If True, clip probabilities to [0, 1] and replace NaN with 0.5 """ - data = data.copy() - try: probs = model.predict(data) except Exception as e: if "mismatching levels" in str(e): - data = _fix_categories_for_predict(model, data) + data = _fix_categories_for_predict(model, data.copy()) probs = model.predict(data) else: raise diff --git a/pySEQTarget/weighting/_weight_pred.py b/pySEQTarget/weighting/_weight_pred.py index c865e2e..6a2bca9 100644 --- a/pySEQTarget/weighting/_weight_pred.py +++ b/pySEQTarget/weighting/_weight_pred.py @@ -173,15 +173,19 @@ def _weight_predict(self, WDT): if cense_num_model is not None and cense_denom_model is not None: p_num = _predict_model(self, cense_num_model, WDT).flatten() p_denom = _predict_model(self, cense_denom_model, WDT).flatten() - WDT = WDT.with_columns( - [ - pl.Series("cense_numerator", p_num), - pl.Series("cense_denominator", p_denom), - ] - ).with_columns( - (pl.col("cense_numerator") / pl.col("cense_denominator")).alias( - "_cense" + WDT = ( + WDT.with_columns( + [ + pl.Series("cense_numerator", p_num), + pl.Series("cense_denominator", p_denom), + ] + ) + .with_columns( + (pl.col("cense_numerator") / pl.col("cense_denominator")).alias( + "_cense" + ) ) + .drop(["cense_numerator", "cense_denominator"]) ) else: WDT = WDT.with_columns(pl.lit(1.0).alias("_cense")) @@ -194,15 +198,19 @@ def _weight_predict(self, WDT): if visit_num_model is not None and visit_denom_model is not None: p_num = _predict_model(self, visit_num_model, WDT).flatten() p_denom = _predict_model(self, visit_denom_model, WDT).flatten() - WDT = WDT.with_columns( - [ - pl.Series("visit_numerator", p_num), - pl.Series("visit_denominator", p_denom), - ] - ).with_columns( - (pl.col("visit_numerator") / pl.col("visit_denominator")).alias( - "_visit" + WDT = ( + WDT.with_columns( + [ + pl.Series("visit_numerator", p_num), + pl.Series("visit_denominator", p_denom), + ] + ) + .with_columns( + (pl.col("visit_numerator") / pl.col("visit_denominator")).alias( + "_visit" + ) ) + .drop(["visit_numerator", "visit_denominator"]) ) else: WDT = WDT.with_columns(pl.lit(1.0).alias("_visit")) diff --git a/pySEQTarget/weighting/_weight_stats.py b/pySEQTarget/weighting/_weight_stats.py index b6da331..6004884 100644 --- a/pySEQTarget/weighting/_weight_stats.py +++ b/pySEQTarget/weighting/_weight_stats.py @@ -17,7 +17,7 @@ def _weight_stats(self): ) if self.weight_p99: - self.weight_min = stats.select("weight_p01").item() - self.weight_max = stats.select("weight_p99").item() + self.weight_min = float(stats["weight_p01"][0]) + self.weight_max = float(stats["weight_p99"][0]) return stats diff --git a/pyproject.toml b/pyproject.toml index e0e73dc..1e7dac8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pySEQTarget" -version = "0.12.7" +version = "0.12.8" description = "Sequentially Nested Target Trial Emulation" readme = "README.md" license = {text = "MIT"} diff --git a/tests/test_hazard.py b/tests/test_hazard.py index f057dfd..4ba98b1 100644 --- a/tests/test_hazard.py +++ b/tests/test_hazard.py @@ -35,7 +35,7 @@ def test_bootstrap_hazard(): time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], method="ITT", - parameters=SEQopts(hazard_estimate=True, bootstrap_nboot=2), + parameters=SEQopts(hazard_estimate=True, bootstrap_nboot=2, seed=42), ) s.expand() s.bootstrap() diff --git a/tests/test_offload.py b/tests/test_offload.py index 6c46daf..1e82858 100644 --- a/tests/test_offload.py +++ b/tests/test_offload.py @@ -19,6 +19,7 @@ def test_compevent_offload(): weight_p99=True, weight_preexpansion=True, offload=True, + seed=42, ) model = SEQuential( diff --git a/tests/test_parallel.py b/tests/test_parallel.py index 1033db5..907121d 100644 --- a/tests/test_parallel.py +++ b/tests/test_parallel.py @@ -22,7 +22,7 @@ def test_parallel_ITT(): time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], method="ITT", - parameters=SEQopts(parallel=True, bootstrap_nboot=2, ncores=1), + parameters=SEQopts(parallel=True, bootstrap_nboot=2, ncores=1, seed=42), ) s.expand() s.bootstrap() diff --git a/tests/test_survival.py b/tests/test_survival.py index 36f4261..cd2b18f 100644 --- a/tests/test_survival.py +++ b/tests/test_survival.py @@ -40,7 +40,7 @@ def test_bootstrapped_survival(): time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], method="ITT", - parameters=SEQopts(km_curves=True, bootstrap_nboot=2), + parameters=SEQopts(km_curves=True, bootstrap_nboot=2, seed=42), ) s.expand() s.bootstrap() @@ -83,7 +83,7 @@ def test_subgroup_bootstrapped_survival(): time_varying_cols=["N", "L", "P"], fixed_cols=["sex"], method="ITT", - parameters=SEQopts(km_curves=True, subgroup_colname="sex", bootstrap_nboot=2), + parameters=SEQopts(km_curves=True, subgroup_colname="sex", bootstrap_nboot=2, seed=42), ) s.expand() s.bootstrap() @@ -139,6 +139,7 @@ def test_bootstrapped_compevent(): compevent_colname="LTFU", plot_type="incidence", bootstrap_nboot=2, + seed=42, ), ) s.expand()