Skip to content

Allow fit methods to accept pd.Series and pd.DataFrame (#62)#92

Merged
TomeHirata merged 5 commits intomainfrom
feat/input/arrylike_to_ndarray
Mar 9, 2026
Merged

Allow fit methods to accept pd.Series and pd.DataFrame (#62)#92
TomeHirata merged 5 commits intomainfrom
feat/input/arrylike_to_ndarray

Conversation

@okiner-3
Copy link
Collaborator

@okiner-3 okiner-3 commented Feb 28, 2026

close #62

Summary

  • Allow fit methods of all estimators to accept the following types:
    • pd.Series
    • pd.DataFrame
    • pl.Series
    • pl.DataFrame
    • list
    • tuple
  • Inputs are automatically converted to np.ndarray inside the fit method
  • Add ArrayLike type alias and update type hints for _convert_to_ndarray and all fit methods

Changes

  • Add _convert_to_ndarray helper and ArrayLike type in dte_adj/util.py
  • Apply conversion in fit methods of all 6 estimator classes
  • Update type hints from np.ndarray to ArrayLike for fit method parameters
  • Add polars to dev-dependencies in pyproject.toml
  • Add tests/test_utils.py with test cases for _convert_to_ndarray

@okiner-3 okiner-3 self-assigned this Feb 28, 2026
@okiner-3 okiner-3 requested a review from TomeHirata February 28, 2026 02:54
dte_adj/util.py Outdated
)


def _convert_to_ndarray(data: object) -> np.ndarray:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we update the type hint to Dataframe, ndarray and Series only?

@@ -0,0 +1,161 @@
import unittest
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use test_utils.py?

Copy link
Collaborator

@TomeHirata TomeHirata Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine to test _convert_to_ndarray only. Alternativel, we can add a test case for each estimator class.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we update the type hint for fit?

"""
super().__init__()

def fit(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

"""
super().__init__()

def fit(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

@okiner-3 okiner-3 requested a review from TomeHirata March 9, 2026 09:51
dte_adj/local.py Outdated
treatment_indicator (np.ndarray): Treatment indicator variable (D).
outcomes (np.ndarray): Scalar-valued observed outcome.
strata (np.ndarray): Stratum indicators.
covariates (ArrayLike): Pre-treatment covariates.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ArrayLike is not a public API; so let's probably remove these types from the docstring, as it's already shown in the type hints.

TomeHirata
TomeHirata previously approved these changes Mar 9, 2026
Copy link
Collaborator

@TomeHirata TomeHirata left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left a minor comment, otherwise LGTM

Copy link
Collaborator

@TomeHirata TomeHirata left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@TomeHirata TomeHirata merged commit 29de8a0 into main Mar 9, 2026
9 checks passed
@TomeHirata TomeHirata deleted the feat/input/arrylike_to_ndarray branch March 9, 2026 11:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Allow to pass pd.Series and pd.DataFrame

2 participants