From 1d278b152d21324f1ac42e7dfb61c5b46fc5df8a Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Thu, 19 Mar 2026 15:34:38 +0100 Subject: [PATCH 1/6] raise valueError rather than warning --- pySEQTarget/helpers/_predict_model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pySEQTarget/helpers/_predict_model.py b/pySEQTarget/helpers/_predict_model.py index fd3ed30..385cf4e 100644 --- a/pySEQTarget/helpers/_predict_model.py +++ b/pySEQTarget/helpers/_predict_model.py @@ -32,8 +32,12 @@ def _safe_predict(model, data, clip_probs=True): if clip_probs: probs = np.array(probs) if np.any(np.isnan(probs)): - warnings.warn("NaN values in predicted probabilities, replacing with 0.5") - probs = np.where(np.isnan(probs), 0.5, probs) + raise ValueError( + "NaN values in predicted probabilities. This typically indicates " + "a mismatch between the model's training data types and the " + "prediction data (e.g. missing categorical casting), or numerical " + "overflow in the model coefficients." + ) probs = np.clip(probs, 0, 1) return probs From 798caff46ba4628a07da677293dfb1984983b1c2 Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Thu, 19 Mar 2026 15:35:05 +0100 Subject: [PATCH 2/6] store type and cast back --- pySEQTarget/helpers/_fix_categories.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pySEQTarget/helpers/_fix_categories.py b/pySEQTarget/helpers/_fix_categories.py index ea9155c..e99d15e 100644 --- a/pySEQTarget/helpers/_fix_categories.py +++ b/pySEQTarget/helpers/_fix_categories.py @@ -1,3 +1,6 @@ +import pandas as pd + + def _fix_categories_for_predict(model, newdata): """ Fix categorical column ordering in newdata to match what the model expects. @@ -13,9 +16,10 @@ def _fix_categories_for_predict(model, newdata): col_name = factor.name() if col_name in newdata.columns: expected_categories = list(factor_info.categories) - newdata[col_name] = newdata[col_name].astype(str) - newdata[col_name] = newdata[col_name].astype("category") - newdata[col_name] = newdata[col_name].cat.set_categories( - expected_categories + cat_type = pd.CategoricalDtype( + categories=expected_categories ) + newdata[col_name] = newdata[col_name].astype( + type(expected_categories[0]) + ).astype(cat_type) return newdata From 3557e9c1386abd689d6ba158a517d31499df84ba Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Thu, 19 Mar 2026 15:35:44 +0100 Subject: [PATCH 3/6] update cast_categories to handle missing treatment cols --- pySEQTarget/analysis/_outcome_fit.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pySEQTarget/analysis/_outcome_fit.py b/pySEQTarget/analysis/_outcome_fit.py index e211bb4..1bfc045 100644 --- a/pySEQTarget/analysis/_outcome_fit.py +++ b/pySEQTarget/analysis/_outcome_fit.py @@ -23,9 +23,11 @@ def _apply_spline_formula(formula, indicator_squared): def _cast_categories(self, df_pd): - df_pd[self.treatment_col] = df_pd[self.treatment_col].astype("category") + if self.treatment_col in df_pd.columns: + df_pd[self.treatment_col] = df_pd[self.treatment_col].astype("category") tx_bas = f"{self.treatment_col}{self.indicator_baseline}" - df_pd[tx_bas] = df_pd[tx_bas].astype("category") + if tx_bas in df_pd.columns: + df_pd[tx_bas] = df_pd[tx_bas].astype("category") if self.followup_class and not self.followup_spline: df_pd["followup"] = df_pd["followup"].astype("category") From 06bc41489e6c258d0fbbf506321148b01157510e Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Thu, 19 Mar 2026 15:36:02 +0100 Subject: [PATCH 4/6] add category casting to other predictions --- pySEQTarget/analysis/_hazard.py | 5 +++-- pySEQTarget/analysis/_survival_pred.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pySEQTarget/analysis/_hazard.py b/pySEQTarget/analysis/_hazard.py index 4f39a6c..41039ee 100644 --- a/pySEQTarget/analysis/_hazard.py +++ b/pySEQTarget/analysis/_hazard.py @@ -5,6 +5,7 @@ from lifelines import CoxPHFitter from ..helpers._predict_model import _safe_predict +from ._outcome_fit import _cast_categories def _calculate_hazard(self): @@ -111,14 +112,14 @@ def _hazard_handler(self, data, idx, boot_idx, rng): [pl.lit(val).alias(f"{self.treatment_col}{self.indicator_baseline}")] ) - tmp_pd = tmp.to_pandas() + tmp_pd = _cast_categories(self, tmp.to_pandas()) outcome_prob = _safe_predict(outcome_model, tmp_pd) outcome_sim = rng.binomial(1, outcome_prob) tmp = tmp.with_columns([pl.Series("outcome", outcome_sim)]) if ce_model is not None: - ce_tmp_pd = tmp.to_pandas() + ce_tmp_pd = _cast_categories(self, tmp.to_pandas()) ce_prob = _safe_predict(ce_model, ce_tmp_pd) ce_sim = rng.binomial(1, ce_prob) tmp = tmp.with_columns([pl.Series("ce", ce_sim)]) diff --git a/pySEQTarget/analysis/_survival_pred.py b/pySEQTarget/analysis/_survival_pred.py index 6e6579c..8ec8dbd 100644 --- a/pySEQTarget/analysis/_survival_pred.py +++ b/pySEQTarget/analysis/_survival_pred.py @@ -1,10 +1,11 @@ import polars as pl from ..helpers._predict_model import _safe_predict +from ._outcome_fit import _cast_categories def _get_outcome_predictions(self, TxDT, idx=None): - data = TxDT.to_pandas() + data = _cast_categories(self, TxDT.to_pandas()) predictions = {"outcome": []} if self.compevent_colname is not None: predictions["compevent"] = [] From 2d3492c8d7f07bdc1062c41dce71f8cdb8e1a04e Mon Sep 17 00:00:00 2001 From: Ryan O'Dea <70209371+ryan-odea@users.noreply.github.com> Date: Thu, 19 Mar 2026 15:46:37 +0100 Subject: [PATCH 5/6] bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 830ec6c..323721d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pySEQTarget" -version = "0.12.4" +version = "0.12.5" description = "Sequentially Nested Target Trial Emulation" readme = "README.md" license = {text = "MIT"} From 5114a23ceacf8899a53a46c20ffcda3b2d3cfb78 Mon Sep 17 00:00:00 2001 From: Tom Palmer Date: Thu, 19 Mar 2026 15:25:21 +0000 Subject: [PATCH 6/6] Add regression test for integer categorical dtype preservation in _fix_categories_for_predict Add isinstance check to silence pyright warning on CategoricalDtype --- tests/test_fix_categories.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 tests/test_fix_categories.py diff --git a/tests/test_fix_categories.py b/tests/test_fix_categories.py new file mode 100644 index 0000000..fdfca10 --- /dev/null +++ b/tests/test_fix_categories.py @@ -0,0 +1,30 @@ +import numpy as np +import pandas as pd +import statsmodels.api as sm +import statsmodels.formula.api as smf + +from pySEQTarget.helpers._fix_categories import _fix_categories_for_predict + + +def test_fix_categories_preserves_integer_dtype(): + """Regression test: integer categorical columns must not be coerced to strings. + The old code did .astype(str).astype('category') which caused category-level + mismatches and NaN predictions.""" + # Mirror actual usage: column is pre-cast to categorical dtype, plain name in formula + df = pd.DataFrame({ + "y": [0, 1, 0, 1, 0, 1, 0, 1], + "tx": pd.Categorical([0, 0, 0, 0, 1, 1, 1, 1]), + }) + model = smf.glm("y ~ tx", data=df, family=sm.families.Binomial()).fit(disp=False) + + # Plain integers — what Polars→pandas produces before _cast_categories runs + newdata = pd.DataFrame({"tx": [0, 1, 0, 1]}) + fixed = _fix_categories_for_predict(model, newdata) + + # Categories must remain integer-typed, not strings + assert isinstance(fixed["tx"].dtype, pd.CategoricalDtype) + assert fixed["tx"].dtype.categories.dtype == df["tx"].cat.categories.dtype + + # No NaNs in predictions — this would fail on the old code + probs = model.predict(fixed) + assert not np.any(np.isnan(probs))