Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pySEQTarget/analysis/_hazard.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from lifelines import CoxPHFitter

from ..helpers._predict_model import _safe_predict
from ._outcome_fit import _cast_categories


def _calculate_hazard(self):
Expand Down Expand Up @@ -111,14 +112,14 @@ def _hazard_handler(self, data, idx, boot_idx, rng):
[pl.lit(val).alias(f"{self.treatment_col}{self.indicator_baseline}")]
)

tmp_pd = tmp.to_pandas()
tmp_pd = _cast_categories(self, tmp.to_pandas())
outcome_prob = _safe_predict(outcome_model, tmp_pd)
outcome_sim = rng.binomial(1, outcome_prob)

tmp = tmp.with_columns([pl.Series("outcome", outcome_sim)])

if ce_model is not None:
ce_tmp_pd = tmp.to_pandas()
ce_tmp_pd = _cast_categories(self, tmp.to_pandas())
ce_prob = _safe_predict(ce_model, ce_tmp_pd)
ce_sim = rng.binomial(1, ce_prob)
tmp = tmp.with_columns([pl.Series("ce", ce_sim)])
Expand Down
6 changes: 4 additions & 2 deletions pySEQTarget/analysis/_outcome_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ def _apply_spline_formula(formula, indicator_squared):


def _cast_categories(self, df_pd):
df_pd[self.treatment_col] = df_pd[self.treatment_col].astype("category")
if self.treatment_col in df_pd.columns:
df_pd[self.treatment_col] = df_pd[self.treatment_col].astype("category")
tx_bas = f"{self.treatment_col}{self.indicator_baseline}"
df_pd[tx_bas] = df_pd[tx_bas].astype("category")
if tx_bas in df_pd.columns:
df_pd[tx_bas] = df_pd[tx_bas].astype("category")

if self.followup_class and not self.followup_spline:
df_pd["followup"] = df_pd["followup"].astype("category")
Expand Down
3 changes: 2 additions & 1 deletion pySEQTarget/analysis/_survival_pred.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import polars as pl

from ..helpers._predict_model import _safe_predict
from ._outcome_fit import _cast_categories


def _get_outcome_predictions(self, TxDT, idx=None):
data = TxDT.to_pandas()
data = _cast_categories(self, TxDT.to_pandas())
predictions = {"outcome": []}
if self.compevent_colname is not None:
predictions["compevent"] = []
Expand Down
12 changes: 8 additions & 4 deletions pySEQTarget/helpers/_fix_categories.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import pandas as pd


def _fix_categories_for_predict(model, newdata):
"""
Fix categorical column ordering in newdata to match what the model expects.
Expand All @@ -13,9 +16,10 @@ def _fix_categories_for_predict(model, newdata):
col_name = factor.name()
if col_name in newdata.columns:
expected_categories = list(factor_info.categories)
newdata[col_name] = newdata[col_name].astype(str)
newdata[col_name] = newdata[col_name].astype("category")
newdata[col_name] = newdata[col_name].cat.set_categories(
expected_categories
cat_type = pd.CategoricalDtype(
categories=expected_categories
)
newdata[col_name] = newdata[col_name].astype(
type(expected_categories[0])
).astype(cat_type)
return newdata
8 changes: 6 additions & 2 deletions pySEQTarget/helpers/_predict_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,12 @@ def _safe_predict(model, data, clip_probs=True):
if clip_probs:
probs = np.array(probs)
if np.any(np.isnan(probs)):
warnings.warn("NaN values in predicted probabilities, replacing with 0.5")
probs = np.where(np.isnan(probs), 0.5, probs)
raise ValueError(
"NaN values in predicted probabilities. This typically indicates "
"a mismatch between the model's training data types and the "
"prediction data (e.g. missing categorical casting), or numerical "
"overflow in the model coefficients."
)
probs = np.clip(probs, 0, 1)

return probs
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "pySEQTarget"
version = "0.12.4"
version = "0.12.5"
description = "Sequentially Nested Target Trial Emulation"
readme = "README.md"
license = {text = "MIT"}
Expand Down
30 changes: 30 additions & 0 deletions tests/test_fix_categories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import numpy as np
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf

from pySEQTarget.helpers._fix_categories import _fix_categories_for_predict


def test_fix_categories_preserves_integer_dtype():
"""Regression test: integer categorical columns must not be coerced to strings.
The old code did .astype(str).astype('category') which caused category-level
mismatches and NaN predictions."""
# Mirror actual usage: column is pre-cast to categorical dtype, plain name in formula
df = pd.DataFrame({
"y": [0, 1, 0, 1, 0, 1, 0, 1],
"tx": pd.Categorical([0, 0, 0, 0, 1, 1, 1, 1]),
})
model = smf.glm("y ~ tx", data=df, family=sm.families.Binomial()).fit(disp=False)

# Plain integers — what Polars→pandas produces before _cast_categories runs
newdata = pd.DataFrame({"tx": [0, 1, 0, 1]})
fixed = _fix_categories_for_predict(model, newdata)

# Categories must remain integer-typed, not strings
assert isinstance(fixed["tx"].dtype, pd.CategoricalDtype)
assert fixed["tx"].dtype.categories.dtype == df["tx"].cat.categories.dtype

# No NaNs in predictions — this would fail on the old code
probs = model.predict(fixed)
assert not np.any(np.isnan(probs))