Expose CATE estimation and inference methods on DRPolicyForest/Tree#1024
Open
jcharit1 wants to merge 1 commit intopy-why:mainfrom
Open
Expose CATE estimation and inference methods on DRPolicyForest/Tree#1024jcharit1 wants to merge 1 commit intopy-why:mainfrom
jcharit1 wants to merge 1 commit intopy-why:mainfrom
Conversation
DRPolicyForest and DRPolicyTree internally fit a DRLearner but do not expose its CATE estimation and inference API. Users who need both optimal policy assignments and CATE confidence intervals may consider strategies such as a 3-way data split: 1. Data for learning CATEs (ForestDRLearner) 2. Data for learning policies (DRPolicyForest using out-of-sample rewards) 3. Data for out-of-sample evaluation with CIs While statistically valid, this approach can be less data-efficient. EconML's cross-fitting within DRPolicyForest already provides noise separation between CATE estimation and policy learning, which could allow users to consolidate the first two splits. Exposing the underlying CATE inference methods enables this more data-efficient workflow. Changes: - Add delegation methods on _BaseDRPolicyLearner for effect(), effect_interval(), effect_inference(), const_marginal_effect(), const_marginal_effect_interval(), const_marginal_effect_inference(), marginal_effect(), ate(), ate_interval(), ate_inference(), shap_values(), score(), and model_final_ - Pass inference parameter through _BaseDRPolicyLearner.fit() to the underlying DRLearner - Fit per-treatment RegressionForest CATE models alongside the policy model in _PolicyModelFinal to support GenericModelFinalInferenceDiscrete - Override _get_inference_options in _DRLearnerWrapper to enable automatic inference when CATE models are available Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Jimmy Charité <jimmy.charite@gmail.com>
d65213c to
af02ab9
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
effect(),effect_inference(),effect_interval(),shap_values(),score())Changes
_BaseDRPolicyLearnerforeffect(),effect_interval(),effect_inference(),const_marginal_effect(),const_marginal_effect_interval(),const_marginal_effect_inference(),marginal_effect(),ate(),ate_interval(),ate_inference(),shap_values(),score(), andmodel_final_inferenceparameter through_BaseDRPolicyLearner.fit()to the underlying DRLearnerRegressionForestCATE models alongside the policy model in_PolicyModelFinalto supportGenericModelFinalInferenceDiscrete_get_inference_optionsin_DRLearnerWrapperto enable automatic inference when CATE models are availableMotivation
When learning personalization policies from A/B test data, a natural workflow is:
DRPolicyForestWithout this change, users must manually split data across separate
ForestDRLearnerandDRPolicyForestfits, losing the data efficiency that EconML's built-in cross-fitting provides.Test plan
test_policy_forest.pytests passeffect(),effect_inference(),effect_interval(),const_marginal_effect_inference(),shap_values(),score()all work onDRPolicyForestafter the change🤖 Generated with Claude Code