From 33364d92cc24a0e27971281a94f956ee74b6dec1 Mon Sep 17 00:00:00 2001 From: igerber Date: Wed, 18 Mar 2026 08:58:57 -0400 Subject: [PATCH 01/14] Extend power analysis to all estimators and add simulation-based MDE/sample size - Add estimator registry mapping all 13 estimators to appropriate DGP, fit kwargs, and result extraction profiles - Refactor simulate_power() to use registry lookup instead of hardcoded if/elif chain (was 4 estimators, now 13 + custom fallback) - Add simulate_mde() for bisection search over effect sizes - Add simulate_sample_size() for bisection search over n_units - Add SimulationMDEResults and SimulationSampleSizeResults dataclasses - 34 new tests: registry, estimator coverage, MDE search, sample size search Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/__init__.py | 8 + diff_diff/power.py | 1125 ++++++++++++++++++++++++++++++++++++----- tests/test_power.py | 702 +++++++++++++++++++++---- 3 files changed, 1600 insertions(+), 235 deletions(-) diff --git a/diff_diff/__init__.py b/diff_diff/__init__.py index 02bdf775..126ad239 100644 --- a/diff_diff/__init__.py +++ b/diff_diff/__init__.py @@ -53,11 +53,15 @@ from diff_diff.power import ( PowerAnalysis, PowerResults, + SimulationMDEResults, SimulationPowerResults, + SimulationSampleSizeResults, compute_mde, compute_power, compute_sample_size, + simulate_mde, simulate_power, + simulate_sample_size, ) from diff_diff.pretrends import ( PreTrendsPower, @@ -291,11 +295,15 @@ # Power analysis "PowerAnalysis", "PowerResults", + "SimulationMDEResults", "SimulationPowerResults", + "SimulationSampleSizeResults", "compute_mde", "compute_power", "compute_sample_size", + "simulate_mde", "simulate_power", + "simulate_sample_size", "plot_power_curve", # Pre-trends power analysis "PreTrendsPower", diff --git a/diff_diff/power.py b/diff_diff/power.py index b8aa10dd..2b1a3a74 100644 --- a/diff_diff/power.py +++ b/diff_diff/power.py @@ -33,6 +33,374 @@ MAX_SAMPLE_SIZE = 2**31 - 1 +# 
--------------------------------------------------------------------------- +# Estimator registry — maps estimator class names to DGP/fit/extract profiles +# --------------------------------------------------------------------------- + + +@dataclass +class _EstimatorProfile: + """Internal profile describing how to run power simulations for an estimator.""" + + default_dgp: Callable + dgp_kwargs_builder: Callable + fit_kwargs_builder: Callable + result_extractor: Callable + min_n: int = 20 + + +# -- DGP kwargs adapters ----------------------------------------------------- + + +def _basic_dgp_kwargs( + n_units: int, + n_periods: int, + treatment_effect: float, + treatment_fraction: float, + treatment_period: int, + sigma: float, +) -> Dict[str, Any]: + return dict( + n_units=n_units, + n_periods=n_periods, + treatment_effect=treatment_effect, + treatment_fraction=treatment_fraction, + treatment_period=treatment_period, + noise_sd=sigma, + ) + + +def _staggered_dgp_kwargs( + n_units: int, + n_periods: int, + treatment_effect: float, + treatment_fraction: float, + treatment_period: int, + sigma: float, +) -> Dict[str, Any]: + return dict( + n_units=n_units, + n_periods=n_periods, + treatment_effect=treatment_effect, + never_treated_frac=1 - treatment_fraction, + cohort_periods=[treatment_period], + dynamic_effects=False, + noise_sd=sigma, + ) + + +def _factor_dgp_kwargs( + n_units: int, + n_periods: int, + treatment_effect: float, + treatment_fraction: float, + treatment_period: int, + sigma: float, +) -> Dict[str, Any]: + n_pre = treatment_period + n_post = n_periods - treatment_period + return dict( + n_units=n_units, + n_pre=n_pre, + n_post=n_post, + n_treated=max(1, int(n_units * treatment_fraction)), + treatment_effect=treatment_effect, + noise_sd=sigma, + ) + + +def _ddd_dgp_kwargs( + n_units: int, + n_periods: int, + treatment_effect: float, + treatment_fraction: float, + treatment_period: int, + sigma: float, +) -> Dict[str, Any]: + return dict( + 
n_per_cell=max(2, n_units // 8), + treatment_effect=treatment_effect, + noise_sd=sigma, + ) + + +def _continuous_dgp_kwargs( + n_units: int, + n_periods: int, + treatment_effect: float, + treatment_fraction: float, + treatment_period: int, + sigma: float, +) -> Dict[str, Any]: + return dict( + n_units=n_units, + n_periods=n_periods, + att_slope=treatment_effect, + never_treated_frac=1 - treatment_fraction, + cohort_periods=[treatment_period], + noise_sd=sigma, + ) + + +# -- Fit kwargs builders ------------------------------------------------------ + + +def _basic_fit_kwargs( + data: pd.DataFrame, + n_units: int, + n_periods: int, + treatment_period: int, +) -> Dict[str, Any]: + return dict(outcome="outcome", treatment="treated", time="post") + + +def _twfe_fit_kwargs( + data: pd.DataFrame, + n_units: int, + n_periods: int, + treatment_period: int, +) -> Dict[str, Any]: + return dict(outcome="outcome", treatment="treated", time="period", unit="unit") + + +def _multiperiod_fit_kwargs( + data: pd.DataFrame, + n_units: int, + n_periods: int, + treatment_period: int, +) -> Dict[str, Any]: + return dict( + outcome="outcome", + treatment="treated", + time="period", + post_periods=list(range(treatment_period, n_periods)), + ) + + +def _staggered_fit_kwargs( + data: pd.DataFrame, + n_units: int, + n_periods: int, + treatment_period: int, +) -> Dict[str, Any]: + return dict(outcome="outcome", unit="unit", time="period", first_treat="first_treat") + + +def _ddd_fit_kwargs( + data: pd.DataFrame, + n_units: int, + n_periods: int, + treatment_period: int, +) -> Dict[str, Any]: + return dict(outcome="outcome", group="group", partition="partition", time="time") + + +def _trop_fit_kwargs( + data: pd.DataFrame, + n_units: int, + n_periods: int, + treatment_period: int, +) -> Dict[str, Any]: + return dict(outcome="outcome", treatment="treated", unit="unit", time="period") + + +def _sdid_fit_kwargs( + data: pd.DataFrame, + n_units: int, + n_periods: int, + treatment_period: int, +) -> 
Dict[str, Any]: + periods = sorted(data["period"].unique()) + post_periods = [p for p in periods if p >= treatment_period] + return dict( + outcome="outcome", + treatment="treat", + unit="unit", + time="period", + post_periods=post_periods, + ) + + +def _continuous_fit_kwargs( + data: pd.DataFrame, + n_units: int, + n_periods: int, + treatment_period: int, +) -> Dict[str, Any]: + return dict( + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + dose="dose", + ) + + +# -- Result extractors -------------------------------------------------------- + + +def _extract_simple(result: Any) -> Tuple[float, float, float, Tuple[float, float]]: + return (result.att, result.se, result.p_value, result.conf_int) + + +def _extract_multiperiod( + result: Any, +) -> Tuple[float, float, float, Tuple[float, float]]: + return (result.avg_att, result.avg_se, result.avg_p_value, result.avg_conf_int) + + +def _extract_staggered( + result: Any, +) -> Tuple[float, float, float, Tuple[float, float]]: + _nan = float("nan") + _nan_ci = (_nan, _nan) + + def _first(r: Any, *attrs: str, default: Any = _nan) -> Any: + for a in attrs: + v = getattr(r, a, None) + if v is not None: + return v + return default + + return ( + result.overall_att, + _first(result, "overall_se", "overall_att_se"), + _first(result, "overall_p_value", "overall_att_p_value"), + _first(result, "overall_conf_int", "overall_att_ci", default=_nan_ci), + ) + + +def _extract_continuous( + result: Any, +) -> Tuple[float, float, float, Tuple[float, float]]: + return ( + result.overall_att, + result.overall_att_se, + result.overall_att_p_value, + result.overall_att_conf_int, + ) + + +# -- Registry construction (deferred to avoid import-time cost) --------------- + +_ESTIMATOR_REGISTRY: Optional[Dict[str, _EstimatorProfile]] = None + + +def _get_registry() -> Dict[str, _EstimatorProfile]: + """Lazily build and return the estimator registry.""" + global _ESTIMATOR_REGISTRY # noqa: PLW0603 + if 
_ESTIMATOR_REGISTRY is not None: + return _ESTIMATOR_REGISTRY + + from diff_diff.prep import ( + generate_continuous_did_data, + generate_ddd_data, + generate_did_data, + generate_factor_data, + generate_staggered_data, + ) + + _ESTIMATOR_REGISTRY = { + # --- Basic DiD group --- + "DifferenceInDifferences": _EstimatorProfile( + default_dgp=generate_did_data, + dgp_kwargs_builder=_basic_dgp_kwargs, + fit_kwargs_builder=_basic_fit_kwargs, + result_extractor=_extract_simple, + min_n=20, + ), + "TwoWayFixedEffects": _EstimatorProfile( + default_dgp=generate_did_data, + dgp_kwargs_builder=_basic_dgp_kwargs, + fit_kwargs_builder=_twfe_fit_kwargs, + result_extractor=_extract_simple, + min_n=20, + ), + "MultiPeriodDiD": _EstimatorProfile( + default_dgp=generate_did_data, + dgp_kwargs_builder=_basic_dgp_kwargs, + fit_kwargs_builder=_multiperiod_fit_kwargs, + result_extractor=_extract_multiperiod, + min_n=20, + ), + # --- Staggered group --- + "CallawaySantAnna": _EstimatorProfile( + default_dgp=generate_staggered_data, + dgp_kwargs_builder=_staggered_dgp_kwargs, + fit_kwargs_builder=_staggered_fit_kwargs, + result_extractor=_extract_staggered, + min_n=40, + ), + "SunAbraham": _EstimatorProfile( + default_dgp=generate_staggered_data, + dgp_kwargs_builder=_staggered_dgp_kwargs, + fit_kwargs_builder=_staggered_fit_kwargs, + result_extractor=_extract_staggered, + min_n=40, + ), + "ImputationDiD": _EstimatorProfile( + default_dgp=generate_staggered_data, + dgp_kwargs_builder=_staggered_dgp_kwargs, + fit_kwargs_builder=_staggered_fit_kwargs, + result_extractor=_extract_staggered, + min_n=40, + ), + "TwoStageDiD": _EstimatorProfile( + default_dgp=generate_staggered_data, + dgp_kwargs_builder=_staggered_dgp_kwargs, + fit_kwargs_builder=_staggered_fit_kwargs, + result_extractor=_extract_staggered, + min_n=40, + ), + "StackedDiD": _EstimatorProfile( + default_dgp=generate_staggered_data, + dgp_kwargs_builder=_staggered_dgp_kwargs, + fit_kwargs_builder=_staggered_fit_kwargs, + 
result_extractor=_extract_staggered, + min_n=40, + ), + "EfficientDiD": _EstimatorProfile( + default_dgp=generate_staggered_data, + dgp_kwargs_builder=_staggered_dgp_kwargs, + fit_kwargs_builder=_staggered_fit_kwargs, + result_extractor=_extract_staggered, + min_n=40, + ), + # --- Factor model group --- + "TROP": _EstimatorProfile( + default_dgp=generate_factor_data, + dgp_kwargs_builder=_factor_dgp_kwargs, + fit_kwargs_builder=_trop_fit_kwargs, + result_extractor=_extract_simple, + min_n=30, + ), + "SyntheticDiD": _EstimatorProfile( + default_dgp=generate_factor_data, + dgp_kwargs_builder=_factor_dgp_kwargs, + fit_kwargs_builder=_sdid_fit_kwargs, + result_extractor=_extract_simple, + min_n=30, + ), + # --- Triple difference --- + "TripleDifference": _EstimatorProfile( + default_dgp=generate_ddd_data, + dgp_kwargs_builder=_ddd_dgp_kwargs, + fit_kwargs_builder=_ddd_fit_kwargs, + result_extractor=_extract_simple, + min_n=64, + ), + # --- Continuous DiD --- + "ContinuousDiD": _EstimatorProfile( + default_dgp=generate_continuous_did_data, + dgp_kwargs_builder=_continuous_dgp_kwargs, + fit_kwargs_builder=_continuous_fit_kwargs, + result_extractor=_extract_continuous, + min_n=40, + ), + } + return _ESTIMATOR_REGISTRY + + @dataclass class PowerResults: """ @@ -332,10 +700,7 @@ def power_curve_df(self) -> pd.DataFrame: pd.DataFrame DataFrame with effect_size and power columns. 
""" - return pd.DataFrame({ - "effect_size": self.effect_sizes, - "power": self.powers - }) + return pd.DataFrame({"effect_size": self.effect_sizes, "power": self.powers}) class PowerAnalysis: @@ -463,9 +828,7 @@ def _compute_variance( n_c_pre = n_control n_c_post = n_control - variance = sigma**2 * ( - 1 / n_t_post + 1 / n_t_pre + 1 / n_c_post + 1 / n_c_pre - ) + variance = sigma**2 * (1 / n_t_post + 1 / n_t_pre + 1 / n_c_post + 1 / n_c_pre) elif design == "panel": # Panel DiD with multiple periods # Account for serial correlation via ICC @@ -528,9 +891,7 @@ def power( T = n_pre + n_post design = "panel" if T > 2 else "basic_did" - variance = self._compute_variance( - n_treated, n_control, n_pre, n_post, sigma, rho, design - ) + variance = self._compute_variance(n_treated, n_control, n_pre, n_post, sigma, rho, design) se = np.sqrt(variance) # Calculate power @@ -538,7 +899,8 @@ def power( z_alpha = stats.norm.ppf(1 - self.alpha / 2) # Power = P(reject | effect) = P(|Z| > z_alpha | effect) power_val = ( - 1 - stats.norm.cdf(z_alpha - effect_size / se) + 1 + - stats.norm.cdf(z_alpha - effect_size / se) + stats.norm.cdf(-z_alpha - effect_size / se) ) elif self.alternative == "greater": @@ -551,8 +913,7 @@ def power( # Also compute MDE and required N for reference mde = self._compute_mde_from_se(se) required_n = self._compute_required_n( - effect_size, sigma, n_pre, n_post, rho, design, - n_treated / (n_treated + n_control) + effect_size, sigma, n_pre, n_post, rho, design, n_treated / (n_treated + n_control) ) return PowerResults( @@ -620,9 +981,7 @@ def mde( T = n_pre + n_post design = "panel" if T > 2 else "basic_did" - variance = self._compute_variance( - n_treated, n_control, n_pre, n_post, sigma, rho, design - ) + variance = self._compute_variance(n_treated, n_control, n_pre, n_post, sigma, rho, design) se = np.sqrt(variance) mde = self._compute_mde_from_se(se) @@ -674,7 +1033,9 @@ def _compute_required_n( # = 2 * sigma^2 / N * (1/(p*(1-p))) n_total = ( - 2 * 
sigma**2 * (z_alpha + z_beta)**2 + 2 + * sigma**2 + * (z_alpha + z_beta) ** 2 / (effect_size**2 * treat_frac * (1 - treat_frac)) ) else: # panel @@ -684,7 +1045,10 @@ def _compute_required_n( # For balanced: Var = 2 * sigma^2 / N * design_effect / T n_total = ( - 2 * sigma**2 * (z_alpha + z_beta)**2 * design_effect + 2 + * sigma**2 + * (z_alpha + z_beta) ** 2 + * design_effect / (effect_size**2 * treat_frac * (1 - treat_frac) * T) ) @@ -744,9 +1108,7 @@ def sample_size( n_total = n_treated + n_control # Compute actual power achieved - variance = self._compute_variance( - n_treated, n_control, n_pre, n_post, sigma, rho, design - ) + variance = self._compute_variance(n_treated, n_control, n_pre, n_post, sigma, rho, design) se = np.sqrt(variance) mde = self._compute_mde_from_se(se) @@ -865,9 +1227,7 @@ def sample_size_curve( DataFrame with columns 'sample_size' and 'power'. """ # Get required N to determine default range - required = self.sample_size( - effect_size, sigma, n_pre, n_post, rho, treat_frac - ) + required = self.sample_size(effect_size, sigma, n_pre, n_post, rho, treat_frac) if sample_sizes is None: min_n = max(10, required.required_n // 4) @@ -914,7 +1274,8 @@ def simulate_power( This function simulates datasets with known treatment effects and estimates power as the fraction of simulations where the null hypothesis is rejected. - This is the recommended approach for complex designs like staggered adoption. + All built-in estimators are supported via an internal registry that selects + the appropriate data-generating process and fit signature automatically. Parameters ---------- @@ -942,8 +1303,9 @@ def simulate_power( seed : int, optional Random seed for reproducibility. data_generator : callable, optional - Custom data generation function. Should accept same signature as - generate_did_data(). If None, uses generate_did_data(). + Custom data generation function. 
When provided, bypasses the + registry DGP and calls this function with the standard kwargs + (n_units, n_periods, treatment_effect, etc.). data_generator_kwargs : dict, optional Additional keyword arguments for data generator. estimator_kwargs : dict, optional @@ -982,15 +1344,11 @@ def simulate_power( ... ) >>> print(results.power_curve_df()) - With Callaway-Sant'Anna for staggered designs: + With Callaway-Sant'Anna (auto-detected, no custom DGP needed): >>> from diff_diff import CallawaySantAnna >>> cs = CallawaySantAnna() - >>> # Custom data generator for staggered adoption - >>> def staggered_data(n_units, n_periods, treatment_effect, **kwargs): - ... # Your staggered data generation logic - ... ... - >>> results = simulate_power(cs, data_generator=staggered_data, ...) + >>> results = simulate_power(cs, n_simulations=200, seed=42) Notes ----- @@ -1000,20 +1358,25 @@ def simulate_power( 3. Repeat n_simulations times 4. Power = fraction of simulations where p-value < alpha - For staggered designs, you'll need to provide a custom data_generator - that creates appropriate staggered treatment timing. - References ---------- Burlig, F., Preonas, L., & Woerman, M. (2020). "Panel Data and Experimental Design." """ - from diff_diff.prep import generate_did_data - rng = np.random.default_rng(seed) - # Use default data generator if none provided - if data_generator is None: - data_generator = generate_did_data + estimator_name = type(estimator).__name__ + registry = _get_registry() + profile = registry.get(estimator_name) + + # If no profile and no custom data_generator, raise + if profile is None and data_generator is None: + raise ValueError( + f"Estimator '{estimator_name}' not in registry. " + f"Provide a custom data_generator and estimator_kwargs." 
+ ) + + # When a custom data_generator is provided, bypass registry DGP + use_custom_dgp = data_generator is not None data_gen_kwargs = data_generator_kwargs or {} est_kwargs = estimator_kwargs or {} @@ -1024,30 +1387,35 @@ def simulate_power( all_powers = [] - # For the primary effect (last in list), collect detailed results - # Use index-based comparison to avoid float precision issues + # For the primary effect, collect detailed results if len(effect_sizes) == 1: primary_idx = 0 else: - # Find index of treatment_effect in effect_sizes primary_idx = -1 for i, es in enumerate(effect_sizes): if np.isclose(es, treatment_effect): primary_idx = i break if primary_idx == -1: - primary_idx = len(effect_sizes) - 1 # Default to last + primary_idx = len(effect_sizes) - 1 primary_effect = effect_sizes[primary_idx] + # Initialize so they are always bound + primary_estimates: List[float] = [] + primary_ses: List[float] = [] + primary_p_values: List[float] = [] + primary_rejections: List[bool] = [] + primary_ci_contains: List[bool] = [] + for effect_idx, effect in enumerate(effect_sizes): - is_primary = (effect_idx == primary_idx) + is_primary = effect_idx == primary_idx - estimates = [] - ses = [] - p_values = [] - rejections = [] - ci_contains_true = [] + estimates: List[float] = [] + ses: List[float] = [] + p_values: List[float] = [] + rejections: List[bool] = [] + ci_contains_true: List[bool] = [] n_failures = 0 for sim in range(n_simulations): @@ -1055,90 +1423,75 @@ def simulate_power( pct = (sim + effect_idx * n_simulations) / (len(effect_sizes) * n_simulations) print(f" Simulation progress: {pct:.0%}") - # Generate data sim_seed = rng.integers(0, 2**31) - data = data_generator( - n_units=n_units, - n_periods=n_periods, - treatment_effect=effect, - treatment_fraction=treatment_fraction, - treatment_period=treatment_period, - noise_sd=sigma, - seed=sim_seed, - **data_gen_kwargs - ) + + # --- Generate data --- + if use_custom_dgp: + assert data_generator is not None + 
data = data_generator( + n_units=n_units, + n_periods=n_periods, + treatment_effect=effect, + treatment_fraction=treatment_fraction, + treatment_period=treatment_period, + noise_sd=sigma, + seed=sim_seed, + **data_gen_kwargs, + ) + else: + assert profile is not None + dgp_kwargs = profile.dgp_kwargs_builder( + n_units=n_units, + n_periods=n_periods, + treatment_effect=effect, + treatment_fraction=treatment_fraction, + treatment_period=treatment_period, + sigma=sigma, + ) + dgp_kwargs.update(data_gen_kwargs) + dgp_kwargs.pop("seed", None) + data = profile.default_dgp(seed=sim_seed, **dgp_kwargs) try: - # Fit estimator - # Try to determine the right arguments based on estimator type - estimator_name = type(estimator).__name__ - - if estimator_name == "DifferenceInDifferences": - result = estimator.fit( - data, - outcome="outcome", - treatment="treated", - time="post", - **est_kwargs - ) - elif estimator_name == "TwoWayFixedEffects": - result = estimator.fit( - data, - outcome="outcome", - treatment="treated", - time="period", - unit="unit", - **est_kwargs - ) - elif estimator_name == "MultiPeriodDiD": - post_periods = list(range(treatment_period, n_periods)) - result = estimator.fit( - data, - outcome="outcome", - treatment="treated", - time="period", - post_periods=post_periods, - **est_kwargs - ) - elif estimator_name == "CallawaySantAnna": - # Need to create first_treat column for staggered - # For standard generate_did_data, convert to first_treat format - data = data.copy() - data["first_treat"] = np.where( - data["treated"] == 1, treatment_period, 0 - ) - result = estimator.fit( - data, - outcome="outcome", - unit="unit", - time="period", - first_treat="first_treat", - **est_kwargs + # --- Fit estimator --- + if profile is not None and not use_custom_dgp: + fit_kwargs = profile.fit_kwargs_builder( + data, n_units, n_periods, treatment_period ) + fit_kwargs.update(est_kwargs) else: - # Generic fallback - try common signature - result = estimator.fit( - data, - 
outcome="outcome", - treatment="treated", - time="post", - **est_kwargs - ) + # Custom DGP fallback: use registry fit kwargs if available, + # otherwise use basic DiD signature + if profile is not None: + fit_kwargs = profile.fit_kwargs_builder( + data, n_units, n_periods, treatment_period + ) + fit_kwargs.update(est_kwargs) + else: + fit_kwargs = dict(outcome="outcome", treatment="treated", time="post") + fit_kwargs.update(est_kwargs) + + result = estimator.fit(data, **fit_kwargs) + + # --- Extract results --- + if profile is not None: + att, se, p_val, ci = profile.result_extractor(result) + else: + att = result.att if hasattr(result, "att") else result.avg_att + se = result.se if hasattr(result, "se") else result.avg_se + p_val = result.p_value if hasattr(result, "p_value") else result.avg_p_value + ci = result.conf_int if hasattr(result, "conf_int") else result.avg_conf_int - # Extract results - att = result.att if hasattr(result, 'att') else result.avg_att - se = result.se if hasattr(result, 'se') else result.avg_se - p_val = result.p_value if hasattr(result, 'p_value') else result.avg_p_value - ci = result.conf_int if hasattr(result, 'conf_int') else result.avg_conf_int + # NaN p-value → treat as non-rejection + rejected = bool(p_val < alpha) if not np.isnan(p_val) else False estimates.append(att) ses.append(se) p_values.append(p_val) - rejections.append(p_val < alpha) + rejections.append(rejected) ci_contains_true.append(ci[0] <= effect <= ci[1]) except Exception as e: - # Track failed simulations n_failures += 1 if progress: print(f" Warning: Simulation {sim} failed: {e}") @@ -1148,21 +1501,18 @@ def simulate_power( failure_rate = n_failures / n_simulations if failure_rate > 0.1: warnings.warn( - f"{n_failures}/{n_simulations} simulations ({failure_rate:.1%}) failed " - f"for effect_size={effect}. Check estimator and data generator.", - UserWarning + f"{n_failures}/{n_simulations} simulations ({failure_rate:.1%}) " + f"failed for effect_size={effect}. 
" + f"Check estimator and data generator.", + UserWarning, ) if len(estimates) == 0: raise RuntimeError("All simulations failed. Check estimator and data generator.") - # Compute power and SE power_val = np.mean(rejections) - power_se = np.sqrt(power_val * (1 - power_val) / len(rejections)) - all_powers.append(power_val) - # Store detailed results for primary effect if is_primary: primary_estimates = estimates primary_ses = ses @@ -1177,10 +1527,9 @@ def simulate_power( z = stats.norm.ppf(0.975) power_ci = ( max(0.0, power_val - z * power_se), - min(1.0, power_val + z * power_se) + min(1.0, power_val + z * power_se), ) - # Compute summary statistics mean_estimate = np.mean(primary_estimates) std_estimate = np.std(primary_estimates, ddof=1) mean_se = np.mean(primary_ses) @@ -1200,15 +1549,523 @@ def simulate_power( powers=all_powers, true_effect=primary_effect, alpha=alpha, - estimator_name=type(estimator).__name__, + estimator_name=estimator_name, simulation_results=[ {"estimate": e, "se": s, "p_value": p, "rejected": r} - for e, s, p, r in zip(primary_estimates, primary_ses, - primary_p_values, primary_rejections) + for e, s, p, r in zip( + primary_estimates, + primary_ses, + primary_p_values, + primary_rejections, + ) ], ) +# --------------------------------------------------------------------------- +# Simulation-based MDE and sample-size search +# --------------------------------------------------------------------------- + + +@dataclass +class SimulationMDEResults: + """ + Results from simulation-based minimum detectable effect search. + + Attributes + ---------- + mde : float + Minimum detectable effect (smallest effect achieving target power). + power_at_mde : float + Power achieved at the MDE. + target_power : float + Target power used in the search. + alpha : float + Significance level. + n_units : int + Sample size used. + n_simulations_per_step : int + Number of simulations per bisection step. + n_steps : int + Number of bisection steps performed. 
+ search_path : list of dict + Diagnostic trace of ``{effect_size, power}`` at each step. + estimator_name : str + Name of the estimator used. + """ + + mde: float + power_at_mde: float + target_power: float + alpha: float + n_units: int + n_simulations_per_step: int + n_steps: int + search_path: List[Dict[str, float]] + estimator_name: str + + def __repr__(self) -> str: + return ( + f"SimulationMDEResults(mde={self.mde:.4f}, " + f"power_at_mde={self.power_at_mde:.3f}, " + f"n_steps={self.n_steps})" + ) + + def summary(self) -> str: + """Generate a formatted summary.""" + lines = [ + "=" * 65, + "Simulation-Based MDE Results".center(65), + "=" * 65, + "", + f"{'Estimator:':<35} {self.estimator_name}", + f"{'Significance level (alpha):':<35} {self.alpha:.3f}", + f"{'Target power:':<35} {self.target_power:.1%}", + f"{'Sample size (n_units):':<35} {self.n_units}", + f"{'Simulations per step:':<35} {self.n_simulations_per_step}", + "", + "-" * 65, + "Search Results".center(65), + "-" * 65, + f"{'Minimum detectable effect:':<35} {self.mde:.4f}", + f"{'Power at MDE:':<35} {self.power_at_mde:.1%}", + f"{'Bisection steps:':<35} {self.n_steps}", + "=" * 65, + ] + return "\n".join(lines) + + def to_dict(self) -> Dict[str, Any]: + """Convert results to a dictionary.""" + return { + "mde": self.mde, + "power_at_mde": self.power_at_mde, + "target_power": self.target_power, + "alpha": self.alpha, + "n_units": self.n_units, + "n_simulations_per_step": self.n_simulations_per_step, + "n_steps": self.n_steps, + "estimator_name": self.estimator_name, + } + + def to_dataframe(self) -> pd.DataFrame: + """Convert results to a single-row DataFrame.""" + return pd.DataFrame([self.to_dict()]) + + +@dataclass +class SimulationSampleSizeResults: + """ + Results from simulation-based sample size search. + + Attributes + ---------- + required_n : int + Required number of units to achieve target power. + power_at_n : float + Power achieved at the required N. 
+ target_power : float + Target power used in the search. + alpha : float + Significance level. + effect_size : float + Effect size used in the search. + n_simulations_per_step : int + Number of simulations per bisection step. + n_steps : int + Number of bisection steps performed. + search_path : list of dict + Diagnostic trace of ``{n_units, power}`` at each step. + estimator_name : str + Name of the estimator used. + """ + + required_n: int + power_at_n: float + target_power: float + alpha: float + effect_size: float + n_simulations_per_step: int + n_steps: int + search_path: List[Dict[str, float]] + estimator_name: str + + def __repr__(self) -> str: + return ( + f"SimulationSampleSizeResults(required_n={self.required_n}, " + f"power_at_n={self.power_at_n:.3f}, " + f"n_steps={self.n_steps})" + ) + + def summary(self) -> str: + """Generate a formatted summary.""" + lines = [ + "=" * 65, + "Simulation-Based Sample Size Results".center(65), + "=" * 65, + "", + f"{'Estimator:':<35} {self.estimator_name}", + f"{'Significance level (alpha):':<35} {self.alpha:.3f}", + f"{'Target power:':<35} {self.target_power:.1%}", + f"{'Effect size:':<35} {self.effect_size:.4f}", + f"{'Simulations per step:':<35} {self.n_simulations_per_step}", + "", + "-" * 65, + "Search Results".center(65), + "-" * 65, + f"{'Required sample size:':<35} {self.required_n}", + f"{'Power at required N:':<35} {self.power_at_n:.1%}", + f"{'Bisection steps:':<35} {self.n_steps}", + "=" * 65, + ] + return "\n".join(lines) + + def to_dict(self) -> Dict[str, Any]: + """Convert results to a dictionary.""" + return { + "required_n": self.required_n, + "power_at_n": self.power_at_n, + "target_power": self.target_power, + "alpha": self.alpha, + "effect_size": self.effect_size, + "n_simulations_per_step": self.n_simulations_per_step, + "n_steps": self.n_steps, + "estimator_name": self.estimator_name, + } + + def to_dataframe(self) -> pd.DataFrame: + """Convert results to a single-row DataFrame.""" + return 
pd.DataFrame([self.to_dict()]) + + +def simulate_mde( + estimator: Any, + n_units: int = 100, + n_periods: int = 4, + treatment_fraction: float = 0.5, + treatment_period: int = 2, + sigma: float = 1.0, + n_simulations: int = 200, + power: float = 0.80, + alpha: float = 0.05, + effect_range: Optional[Tuple[float, float]] = None, + tol: float = 0.02, + max_steps: int = 15, + seed: Optional[int] = None, + data_generator: Optional[Callable] = None, + data_generator_kwargs: Optional[Dict[str, Any]] = None, + estimator_kwargs: Optional[Dict[str, Any]] = None, + progress: bool = True, +) -> SimulationMDEResults: + """ + Find the minimum detectable effect via simulation-based bisection search. + + Searches over effect sizes to find the smallest effect that achieves the + target power, using ``simulate_power()`` at each step. + + Parameters + ---------- + estimator : estimator object + DiD estimator to use. + n_units : int, default=100 + Number of units per simulation. + n_periods : int, default=4 + Number of time periods. + treatment_fraction : float, default=0.5 + Fraction of units that are treated. + treatment_period : int, default=2 + First post-treatment period (0-indexed). + sigma : float, default=1.0 + Residual standard deviation. + n_simulations : int, default=200 + Simulations per bisection step. + power : float, default=0.80 + Target power. + alpha : float, default=0.05 + Significance level. + effect_range : tuple of (float, float), optional + ``(lo, hi)`` bracket for the search. If None, auto-brackets. + tol : float, default=0.02 + Convergence tolerance on power. + max_steps : int, default=15 + Maximum bisection steps. + seed : int, optional + Random seed for reproducibility. + data_generator : callable, optional + Custom data generation function. + data_generator_kwargs : dict, optional + Additional keyword arguments for data generator. + estimator_kwargs : dict, optional + Additional keyword arguments for estimator.fit(). 
+ progress : bool, default=True + Whether to print progress updates. + + Returns + ------- + SimulationMDEResults + Results including the MDE and search diagnostics. + + Examples + -------- + >>> from diff_diff import simulate_mde, DifferenceInDifferences + >>> result = simulate_mde(DifferenceInDifferences(), n_simulations=100, seed=42) + >>> print(f"MDE: {result.mde:.3f}") + """ + master_rng = np.random.default_rng(seed) + estimator_name = type(estimator).__name__ + search_path: List[Dict[str, float]] = [] + + common_kwargs: Dict[str, Any] = dict( + estimator=estimator, + n_units=n_units, + n_periods=n_periods, + treatment_fraction=treatment_fraction, + treatment_period=treatment_period, + sigma=sigma, + n_simulations=n_simulations, + alpha=alpha, + data_generator=data_generator, + data_generator_kwargs=data_generator_kwargs, + estimator_kwargs=estimator_kwargs, + progress=False, + ) + + def _power_at(effect: float) -> float: + step_seed = int(master_rng.integers(0, 2**31)) + res = simulate_power(treatment_effect=effect, seed=step_seed, **common_kwargs) + pwr = float(res.power) + search_path.append({"effect_size": effect, "power": pwr}) + if progress: + print(f" MDE search: effect={effect:.4f}, power={pwr:.3f}") + return pwr + + # --- Bracket --- + if effect_range is not None: + lo, hi = effect_range + else: + lo = 0.0 + # Check that power at zero is below target (no inflated Type I error) + power_at_zero = _power_at(0.0) + if power_at_zero >= power: + warnings.warn( + f"Power at effect=0 is {power_at_zero:.2f} >= target {power}. " + f"This suggests inflated Type I error. 
Returning MDE=0.", + UserWarning, + ) + return SimulationMDEResults( + mde=0.0, + power_at_mde=power_at_zero, + target_power=power, + alpha=alpha, + n_units=n_units, + n_simulations_per_step=n_simulations, + n_steps=len(search_path), + search_path=search_path, + estimator_name=estimator_name, + ) + + hi = sigma + for _ in range(10): + if _power_at(hi) >= power: + break + hi *= 2 + else: + warnings.warn( + f"Could not bracket MDE (power at effect={hi} still below " + f"{power}). Returning best upper bound.", + UserWarning, + ) + + # --- Bisect --- + best_effect = hi + best_power = search_path[-1]["power"] if search_path else 0.0 + + for _ in range(max_steps): + mid = (lo + hi) / 2 + pwr = _power_at(mid) + + if pwr >= power: + hi = mid + best_effect = mid + best_power = pwr + else: + lo = mid + + # Convergence: effect range is tight or power is close enough + if hi - lo < max(tol * hi, 1e-6) or abs(pwr - power) < tol: + break + + return SimulationMDEResults( + mde=best_effect, + power_at_mde=best_power, + target_power=power, + alpha=alpha, + n_units=n_units, + n_simulations_per_step=n_simulations, + n_steps=len(search_path), + search_path=search_path, + estimator_name=estimator_name, + ) + + +def simulate_sample_size( + estimator: Any, + treatment_effect: float = 5.0, + n_periods: int = 4, + treatment_fraction: float = 0.5, + treatment_period: int = 2, + sigma: float = 1.0, + n_simulations: int = 200, + power: float = 0.80, + alpha: float = 0.05, + n_range: Optional[Tuple[int, int]] = None, + max_steps: int = 15, + seed: Optional[int] = None, + data_generator: Optional[Callable] = None, + data_generator_kwargs: Optional[Dict[str, Any]] = None, + estimator_kwargs: Optional[Dict[str, Any]] = None, + progress: bool = True, +) -> SimulationSampleSizeResults: + """ + Find the required sample size via simulation-based bisection search. + + Searches over ``n_units`` to find the smallest N that achieves the + target power, using ``simulate_power()`` at each step. 
+ + Parameters + ---------- + estimator : estimator object + DiD estimator to use. + treatment_effect : float, default=5.0 + True treatment effect to simulate. + n_periods : int, default=4 + Number of time periods. + treatment_fraction : float, default=0.5 + Fraction of units that are treated. + treatment_period : int, default=2 + First post-treatment period (0-indexed). + sigma : float, default=1.0 + Residual standard deviation. + n_simulations : int, default=200 + Simulations per bisection step. + power : float, default=0.80 + Target power. + alpha : float, default=0.05 + Significance level. + n_range : tuple of (int, int), optional + ``(lo, hi)`` bracket for sample size. If None, auto-brackets. + max_steps : int, default=15 + Maximum bisection steps. + seed : int, optional + Random seed for reproducibility. + data_generator : callable, optional + Custom data generation function. + data_generator_kwargs : dict, optional + Additional keyword arguments for data generator. + estimator_kwargs : dict, optional + Additional keyword arguments for estimator.fit(). + progress : bool, default=True + Whether to print progress updates. + + Returns + ------- + SimulationSampleSizeResults + Results including the required N and search diagnostics. + + Examples + -------- + >>> from diff_diff import simulate_sample_size, DifferenceInDifferences + >>> result = simulate_sample_size( + ... DifferenceInDifferences(), treatment_effect=5.0, n_simulations=100, seed=42 + ... 
) + >>> print(f"Required N: {result.required_n}") + """ + master_rng = np.random.default_rng(seed) + estimator_name = type(estimator).__name__ + search_path: List[Dict[str, float]] = [] + + # Determine min_n from registry + registry = _get_registry() + profile = registry.get(estimator_name) + min_n = profile.min_n if profile is not None else 20 + + common_kwargs: Dict[str, Any] = dict( + estimator=estimator, + n_periods=n_periods, + treatment_effect=treatment_effect, + treatment_fraction=treatment_fraction, + treatment_period=treatment_period, + sigma=sigma, + n_simulations=n_simulations, + alpha=alpha, + data_generator=data_generator, + data_generator_kwargs=data_generator_kwargs, + estimator_kwargs=estimator_kwargs, + progress=False, + ) + + def _power_at_n(n: int) -> float: + step_seed = int(master_rng.integers(0, 2**31)) + res = simulate_power(n_units=n, seed=step_seed, **common_kwargs) + pwr = float(res.power) + search_path.append({"n_units": float(n), "power": pwr}) + if progress: + print(f" Sample size search: n={n}, power={pwr:.3f}") + return pwr + + # --- Bracket --- + if n_range is not None: + lo, hi = n_range + else: + lo = min_n + hi = max(100, 2 * min_n) + for _ in range(10): + if _power_at_n(hi) >= power: + break + hi *= 2 + else: + warnings.warn( + f"Could not bracket required N (power at n={hi} still below " + f"{power}). 
Returning best upper bound.", + UserWarning, + ) + + # --- Bisect on integer n_units --- + best_n = hi + best_power = search_path[-1]["power"] if search_path else 0.0 + + for _ in range(max_steps): + if hi - lo <= 2: + break + mid = (lo + hi) // 2 + pwr = _power_at_n(mid) + + if pwr >= power: + hi = mid + best_n = mid + best_power = pwr + else: + lo = mid + + # Final answer is hi (conservative ceiling) — skip if already evaluated + if best_n != hi: + final_pwr = _power_at_n(hi) + if final_pwr >= power: + best_n = hi + best_power = final_pwr + + return SimulationSampleSizeResults( + required_n=best_n, + power_at_n=best_power, + target_power=power, + alpha=alpha, + effect_size=treatment_effect, + n_simulations_per_step=n_simulations, + n_steps=len(search_path), + search_path=search_path, + estimator_name=estimator_name, + ) + + def compute_mde( n_treated: int, n_control: int, diff --git a/tests/test_power.py b/tests/test_power.py index d012cda5..13a4db53 100644 --- a/tests/test_power.py +++ b/tests/test_power.py @@ -5,16 +5,50 @@ import pytest from diff_diff import ( + TROP, + CallawaySantAnna, + ContinuousDiD, DifferenceInDifferences, + EfficientDiD, + ImputationDiD, + MultiPeriodDiD, PowerAnalysis, PowerResults, + SimulationMDEResults, SimulationPowerResults, + SimulationSampleSizeResults, + StackedDiD, + SunAbraham, + SyntheticDiD, + TripleDifference, + TwoStageDiD, + TwoWayFixedEffects, compute_mde, compute_power, compute_sample_size, + simulate_mde, simulate_power, + simulate_sample_size, +) +from diff_diff.power import ( + MAX_SAMPLE_SIZE, + _basic_dgp_kwargs, + _basic_fit_kwargs, + _continuous_dgp_kwargs, + _continuous_fit_kwargs, + _ddd_dgp_kwargs, + _ddd_fit_kwargs, + _extract_continuous, + _extract_multiperiod, + _extract_simple, + _extract_staggered, + _factor_dgp_kwargs, + _get_registry, + _staggered_dgp_kwargs, + _staggered_fit_kwargs, + _trop_fit_kwargs, + _twfe_fit_kwargs, ) -from diff_diff.power import MAX_SAMPLE_SIZE class TestPowerAnalysis: @@ 
-80,12 +114,7 @@ def test_mde_decreases_with_sample_size(self): def test_power_calculation(self): """Test power calculation.""" pa = PowerAnalysis(alpha=0.05) - result = pa.power( - effect_size=0.5, - n_treated=50, - n_control=50, - sigma=1.0 - ) + result = pa.power(effect_size=0.5, n_treated=50, n_control=50, sigma=1.0) assert isinstance(result, PowerResults) assert 0 < result.power < 1 @@ -95,12 +124,8 @@ def test_power_increases_with_effect_size(self): """Test that power increases with effect size.""" pa = PowerAnalysis() - result_small = pa.power( - effect_size=0.2, n_treated=50, n_control=50, sigma=1.0 - ) - result_large = pa.power( - effect_size=0.8, n_treated=50, n_control=50, sigma=1.0 - ) + result_small = pa.power(effect_size=0.2, n_treated=50, n_control=50, sigma=1.0) + result_large = pa.power(effect_size=0.8, n_treated=50, n_control=50, sigma=1.0) assert result_large.power > result_small.power @@ -108,12 +133,8 @@ def test_power_increases_with_sample_size(self): """Test that power increases with sample size.""" pa = PowerAnalysis() - result_small = pa.power( - effect_size=0.5, n_treated=25, n_control=25, sigma=1.0 - ) - result_large = pa.power( - effect_size=0.5, n_treated=100, n_control=100, sigma=1.0 - ) + result_small = pa.power(effect_size=0.5, n_treated=25, n_control=25, sigma=1.0) + result_large = pa.power(effect_size=0.5, n_treated=100, n_control=100, sigma=1.0) assert result_large.power > result_small.power @@ -140,14 +161,8 @@ def test_panel_design(self): pa = PowerAnalysis(power=0.80) # Panel with multiple periods should have smaller MDE - result_2period = pa.mde( - n_treated=50, n_control=50, sigma=1.0, - n_pre=1, n_post=1 - ) - result_6period = pa.mde( - n_treated=50, n_control=50, sigma=1.0, - n_pre=3, n_post=3 - ) + result_2period = pa.mde(n_treated=50, n_control=50, sigma=1.0, n_pre=1, n_post=1) + result_6period = pa.mde(n_treated=50, n_control=50, sigma=1.0, n_pre=3, n_post=3) # More periods should reduce MDE (more data) assert 
result_6period.mde < result_2period.mde @@ -157,14 +172,8 @@ def test_icc_effect(self): """Test that intra-cluster correlation affects power.""" pa = PowerAnalysis(power=0.80) - result_no_icc = pa.mde( - n_treated=50, n_control=50, sigma=1.0, - n_pre=3, n_post=3, rho=0.0 - ) - result_with_icc = pa.mde( - n_treated=50, n_control=50, sigma=1.0, - n_pre=3, n_post=3, rho=0.5 - ) + result_no_icc = pa.mde(n_treated=50, n_control=50, sigma=1.0, n_pre=3, n_post=3, rho=0.0) + result_with_icc = pa.mde(n_treated=50, n_control=50, sigma=1.0, n_pre=3, n_post=3, rho=0.5) # Higher ICC should increase MDE (less independent information) assert result_with_icc.mde > result_no_icc.mde @@ -173,8 +182,7 @@ def test_power_curve(self): """Test power curve generation.""" pa = PowerAnalysis() curve = pa.power_curve( - n_treated=50, n_control=50, sigma=1.0, - effect_sizes=[0.1, 0.2, 0.3, 0.5, 0.7, 1.0] + n_treated=50, n_control=50, sigma=1.0, effect_sizes=[0.1, 0.2, 0.3, 0.5, 0.7, 1.0] ) assert isinstance(curve, pd.DataFrame) @@ -196,8 +204,7 @@ def test_sample_size_curve(self): """Test sample size curve generation.""" pa = PowerAnalysis() curve = pa.sample_size_curve( - effect_size=0.5, sigma=1.0, - sample_sizes=[20, 50, 100, 150, 200] + effect_size=0.5, sigma=1.0, sample_sizes=[20, 50, 100, 150, 200] ) assert isinstance(curve, pd.DataFrame) @@ -258,22 +265,14 @@ def test_one_sided_power_calculation(self): pa_two = PowerAnalysis(alternative="two-sided") # For positive effect, 'greater' should have higher power than two-sided - result_greater = pa_greater.power( - effect_size=0.5, n_treated=50, n_control=50, sigma=1.0 - ) - result_two = pa_two.power( - effect_size=0.5, n_treated=50, n_control=50, sigma=1.0 - ) + result_greater = pa_greater.power(effect_size=0.5, n_treated=50, n_control=50, sigma=1.0) + result_two = pa_two.power(effect_size=0.5, n_treated=50, n_control=50, sigma=1.0) assert result_greater.power > result_two.power # For negative effect, 'less' should have higher power - 
result_less = pa_less.power( - effect_size=-0.5, n_treated=50, n_control=50, sigma=1.0 - ) - result_two_neg = pa_two.power( - effect_size=-0.5, n_treated=50, n_control=50, sigma=1.0 - ) + result_less = pa_less.power(effect_size=-0.5, n_treated=50, n_control=50, sigma=1.0) + result_two_neg = pa_two.power(effect_size=-0.5, n_treated=50, n_control=50, sigma=1.0) assert result_less.power > result_two_neg.power @@ -282,12 +281,8 @@ def test_negative_effect_size(self): pa = PowerAnalysis() # Power should work the same for negative effects (symmetric) - result_pos = pa.power( - effect_size=0.5, n_treated=50, n_control=50, sigma=1.0 - ) - result_neg = pa.power( - effect_size=-0.5, n_treated=50, n_control=50, sigma=1.0 - ) + result_pos = pa.power(effect_size=0.5, n_treated=50, n_control=50, sigma=1.0) + result_neg = pa.power(effect_size=-0.5, n_treated=50, n_control=50, sigma=1.0) # Two-sided test should have same power for positive and negative effects assert abs(result_pos.power - result_neg.power) < 0.01 @@ -297,20 +292,14 @@ def test_extreme_icc(self): pa = PowerAnalysis(power=0.80) # Test with very high ICC (0.99) - result_extreme = pa.mde( - n_treated=50, n_control=50, sigma=1.0, - n_pre=5, n_post=5, rho=0.99 - ) + result_extreme = pa.mde(n_treated=50, n_control=50, sigma=1.0, n_pre=5, n_post=5, rho=0.99) - result_moderate = pa.mde( - n_treated=50, n_control=50, sigma=1.0, - n_pre=5, n_post=5, rho=0.5 - ) + result_moderate = pa.mde(n_treated=50, n_control=50, sigma=1.0, n_pre=5, n_post=5, rho=0.5) # Extreme ICC should have higher MDE (less independent info) assert result_extreme.mde > result_moderate.mde # MDE should still be finite and reasonable - assert result_extreme.mde < float('inf') + assert result_extreme.mde < float("inf") assert result_extreme.mde > 0 @@ -326,12 +315,7 @@ def test_compute_mde(self): def test_compute_power(self): """Test compute_power convenience function.""" - power = compute_power( - effect_size=0.5, - n_treated=50, - n_control=50, - 
sigma=1.0 - ) + power = compute_power(effect_size=0.5, n_treated=50, n_control=50, sigma=1.0) assert isinstance(power, float) assert 0 < power < 1 @@ -353,12 +337,8 @@ def test_convenience_functions_consistency(self): assert mde_class == mde_func # Power - power_class = pa.power( - effect_size=0.5, n_treated=50, n_control=50, sigma=1.0 - ).power - power_func = compute_power( - effect_size=0.5, n_treated=50, n_control=50, sigma=1.0 - ) + power_class = pa.power(effect_size=0.5, n_treated=50, n_control=50, sigma=1.0).power + power_func = compute_power(effect_size=0.5, n_treated=50, n_control=50, sigma=1.0) assert power_class == power_func # Sample size @@ -563,13 +543,16 @@ class Result: return Result() # Test with low failure rate (should not warn) + from diff_diff.prep import generate_did_data + estimator = FailingEstimator(fail_rate=0.0) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - results = simulate_power( + simulate_power( estimator=estimator, n_simulations=10, progress=False, + data_generator=generate_did_data, ) # Should have completed successfully without warning assert len([x for x in w if "simulations" in str(x.message)]) == 0 @@ -583,10 +566,12 @@ def test_plot_power_curve_dataframe(self): pytest.importorskip("matplotlib") from diff_diff.visualization import plot_power_curve - df = pd.DataFrame({ - "effect_size": [0.1, 0.2, 0.3, 0.5, 0.7, 1.0], - "power": [0.1, 0.2, 0.4, 0.7, 0.9, 0.99] - }) + df = pd.DataFrame( + { + "effect_size": [0.1, 0.2, 0.3, 0.5, 0.7, 1.0], + "power": [0.1, 0.2, 0.4, 0.7, 0.9, 0.99], + } + ) ax = plot_power_curve(df, show=False) assert ax is not None @@ -597,10 +582,7 @@ def test_plot_power_curve_manual_data(self): from diff_diff.visualization import plot_power_curve ax = plot_power_curve( - effect_sizes=[0.1, 0.2, 0.3, 0.5], - powers=[0.1, 0.3, 0.6, 0.9], - mde=0.25, - show=False + effect_sizes=[0.1, 0.2, 0.3, 0.5], powers=[0.1, 0.3, 0.6, 0.9], mde=0.25, show=False ) assert ax is not None @@ 
-609,10 +591,9 @@ def test_plot_power_curve_sample_size(self): pytest.importorskip("matplotlib") from diff_diff.visualization import plot_power_curve - df = pd.DataFrame({ - "sample_size": [20, 50, 100, 150, 200], - "power": [0.2, 0.5, 0.8, 0.9, 0.95] - }) + df = pd.DataFrame( + {"sample_size": [20, 50, 100, 150, 200], "power": [0.2, 0.5, 0.8, 0.9, 0.95]} + ) ax = plot_power_curve(df, show=False) assert ax is not None @@ -626,10 +607,7 @@ def test_plot_validates_input(self): plot_power_curve(show=False) # No data provided with pytest.raises(ValueError): - plot_power_curve( - effect_sizes=[1, 2, 3], - show=False # Missing powers - ) + plot_power_curve(effect_sizes=[1, 2, 3], show=False) # Missing powers class TestEdgeCases: @@ -648,15 +626,11 @@ def test_extreme_power_values(self): pa = PowerAnalysis() # Zero effect should give ~alpha power - result_zero = pa.power( - effect_size=0.0, n_treated=50, n_control=50, sigma=1.0 - ) + result_zero = pa.power(effect_size=0.0, n_treated=50, n_control=50, sigma=1.0) assert result_zero.power < 0.10 # Huge effect should give ~1.0 power - result_huge = pa.power( - effect_size=100.0, n_treated=50, n_control=50, sigma=1.0 - ) + result_huge = pa.power(effect_size=100.0, n_treated=50, n_control=50, sigma=1.0) assert result_huge.power > 0.99 def test_unbalanced_design(self): @@ -689,3 +663,529 @@ def test_max_sample_size_constant(self): # Verify constant is the expected value assert MAX_SAMPLE_SIZE == 2**31 - 1 + + +# --------------------------------------------------------------------------- +# Registry tests +# --------------------------------------------------------------------------- + + +class TestEstimatorRegistry: + """Tests for the estimator registry.""" + + EXPECTED_ESTIMATORS = [ + "DifferenceInDifferences", + "TwoWayFixedEffects", + "MultiPeriodDiD", + "CallawaySantAnna", + "SunAbraham", + "ImputationDiD", + "TwoStageDiD", + "StackedDiD", + "EfficientDiD", + "TROP", + "SyntheticDiD", + "TripleDifference", + "ContinuousDiD", 
+ ] + + def test_all_estimators_registered(self): + """Every supported estimator has a registry entry.""" + registry = _get_registry() + for name in self.EXPECTED_ESTIMATORS: + assert name in registry, f"{name} missing from registry" + + def test_bacon_excluded(self): + """BaconDecomposition is diagnostic-only and should not be in registry.""" + registry = _get_registry() + assert "BaconDecomposition" not in registry + + def test_dgp_kwargs_builders_return_dicts(self): + """Each DGP kwargs builder returns a non-empty dict.""" + params = dict( + n_units=50, + n_periods=4, + treatment_effect=5.0, + treatment_fraction=0.5, + treatment_period=2, + sigma=1.0, + ) + for builder in [ + _basic_dgp_kwargs, + _staggered_dgp_kwargs, + _factor_dgp_kwargs, + _ddd_dgp_kwargs, + _continuous_dgp_kwargs, + ]: + result = builder(**params) + assert isinstance(result, dict) + assert len(result) > 0 + + def test_fit_kwargs_builders_return_dicts(self): + """Each fit kwargs builder returns a dict with 'outcome'.""" + dummy_df = pd.DataFrame({"period": [0, 1, 2, 3]}) + for builder in [ + _basic_fit_kwargs, + _twfe_fit_kwargs, + _staggered_fit_kwargs, + _ddd_fit_kwargs, + _trop_fit_kwargs, + _continuous_fit_kwargs, + ]: + result = builder(dummy_df, 50, 4, 2) + assert isinstance(result, dict) + assert "outcome" in result + + def test_extract_simple(self): + """_extract_simple extracts from .att/.se/.p_value/.conf_int.""" + + class MockResult: + att = 3.0 + se = 0.5 + p_value = 0.01 + conf_int = (2.0, 4.0) + + att, se, p, ci = _extract_simple(MockResult()) + assert att == 3.0 + assert se == 0.5 + assert p == 0.01 + assert ci == (2.0, 4.0) + + def test_extract_multiperiod(self): + """_extract_multiperiod extracts from avg_* attributes.""" + + class MockResult: + avg_att = 4.0 + avg_se = 0.6 + avg_p_value = 0.001 + avg_conf_int = (2.8, 5.2) + + att, se, p, ci = _extract_multiperiod(MockResult()) + assert att == 4.0 + assert se == 0.6 + assert p == 0.001 + assert ci == (2.8, 5.2) + + def 
test_extract_staggered_analytical(self): + """_extract_staggered handles analytical result objects.""" + + class MockResult: + overall_att = 2.0 + overall_se = 0.3 + overall_p_value = 0.02 + overall_conf_int = (1.4, 2.6) + + att, se, p, ci = _extract_staggered(MockResult()) + assert att == 2.0 + assert se == 0.3 + assert p == 0.02 + assert ci == (1.4, 2.6) + + def test_extract_staggered_bootstrap_fallback(self): + """_extract_staggered falls back to bootstrap attribute names.""" + + class MockBootstrapResult: + overall_att = 2.0 + overall_att_se = 0.4 + overall_att_p_value = 0.03 + overall_att_ci = (1.2, 2.8) + + att, se, p, ci = _extract_staggered(MockBootstrapResult()) + assert att == 2.0 + assert se == 0.4 + assert p == 0.03 + assert ci == (1.2, 2.8) + + def test_extract_continuous(self): + """_extract_continuous extracts from overall_att_* attributes.""" + + class MockResult: + overall_att = 1.5 + overall_att_se = 0.2 + overall_att_p_value = 0.005 + overall_att_conf_int = (1.1, 1.9) + + att, se, p, ci = _extract_continuous(MockResult()) + assert att == 1.5 + assert se == 0.2 + assert p == 0.005 + assert ci == (1.1, 1.9) + + def test_unknown_estimator_raises_without_data_generator(self): + """Unknown estimator without data_generator raises ValueError.""" + + class UnknownEstimator: + pass + + with pytest.raises(ValueError, match="not in registry"): + simulate_power( + UnknownEstimator(), + n_simulations=5, + progress=False, + ) + + +# --------------------------------------------------------------------------- +# Estimator coverage tests for simulate_power +# --------------------------------------------------------------------------- + + +class TestEstimatorCoverage: + """Verify simulate_power works for each registered estimator.""" + + def _assert_valid_result(self, result, expected_name): + assert 0 <= result.power <= 1 + assert result.estimator_name == expected_name + assert np.isfinite(result.mean_estimate) + assert result.n_simulations > 0 + assert 
result.coverage >= 0 + + def test_did(self): + result = simulate_power( + DifferenceInDifferences(), + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "DifferenceInDifferences") + + def test_twfe(self): + result = simulate_power( + TwoWayFixedEffects(), + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "TwoWayFixedEffects") + + def test_multiperiod(self): + result = simulate_power( + MultiPeriodDiD(), + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "MultiPeriodDiD") + + def test_callaway_santanna(self): + result = simulate_power( + CallawaySantAnna(), + n_units=60, + n_periods=6, + treatment_period=3, + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "CallawaySantAnna") + + def test_sun_abraham(self): + result = simulate_power( + SunAbraham(), + n_units=60, + n_periods=6, + treatment_period=3, + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "SunAbraham") + + def test_imputation_did(self): + result = simulate_power( + ImputationDiD(), + n_units=60, + n_periods=6, + treatment_period=3, + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "ImputationDiD") + + def test_two_stage_did(self): + result = simulate_power( + TwoStageDiD(), + n_units=60, + n_periods=6, + treatment_period=3, + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "TwoStageDiD") + + def test_stacked_did(self): + result = simulate_power( + StackedDiD(), + n_units=60, + n_periods=6, + treatment_period=3, + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "StackedDiD") + + def test_efficient_did(self): + result = simulate_power( + EfficientDiD(), + n_units=60, + n_periods=6, + treatment_period=3, + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, 
"EfficientDiD") + + def test_triple_difference(self): + result = simulate_power( + TripleDifference(), + n_units=80, + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "TripleDifference") + + @pytest.mark.slow + def test_trop(self): + result = simulate_power( + TROP(), + n_units=50, + n_periods=6, + treatment_period=3, + treatment_fraction=0.3, + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "TROP") + + @pytest.mark.slow + def test_synthetic_did(self): + result = simulate_power( + SyntheticDiD(), + n_units=50, + n_periods=6, + treatment_period=3, + treatment_fraction=0.3, + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "SyntheticDiD") + + def test_continuous_did(self): + result = simulate_power( + ContinuousDiD(), + n_units=100, + n_periods=6, + treatment_period=3, + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "ContinuousDiD") + + +# --------------------------------------------------------------------------- +# simulate_mde tests +# --------------------------------------------------------------------------- + + +class TestSimulateMDE: + """Tests for simulate_mde function.""" + + def test_basic_mde(self): + """MDE found for DiD, power at MDE close to target.""" + result = simulate_mde( + DifferenceInDifferences(), + n_units=100, + sigma=1.0, + n_simulations=50, + seed=42, + progress=False, + ) + assert isinstance(result, SimulationMDEResults) + assert result.mde > 0 + assert result.power_at_mde >= result.target_power - 0.10 + + def test_result_methods(self): + """summary(), to_dict(), to_dataframe() work.""" + result = simulate_mde( + DifferenceInDifferences(), + n_simulations=30, + seed=42, + progress=False, + ) + summary = result.summary() + assert "MDE" in summary or "Minimum" in summary + + d = result.to_dict() + assert "mde" in d + assert "estimator_name" in d + + df = result.to_dataframe() + assert 
isinstance(df, pd.DataFrame) + assert len(df) == 1 + + def test_monotonicity_in_search_path(self): + """The search path records plausible effect_size / power pairs.""" + result = simulate_mde( + DifferenceInDifferences(), + n_simulations=50, + seed=42, + progress=False, + ) + assert len(result.search_path) > 0 + for step in result.search_path: + assert "effect_size" in step + assert "power" in step + assert 0 <= step["power"] <= 1 + + def test_convergence_within_max_steps(self): + """Search terminates within max_steps.""" + result = simulate_mde( + DifferenceInDifferences(), + n_simulations=30, + max_steps=10, + seed=42, + progress=False, + ) + # n_steps includes bracketing steps + bisection + assert result.n_steps <= 25 # generous bound + + def test_custom_data_generator(self): + """Works with user-provided DGP.""" + from diff_diff.prep import generate_did_data + + result = simulate_mde( + DifferenceInDifferences(), + n_simulations=30, + seed=42, + progress=False, + data_generator=generate_did_data, + ) + assert result.mde > 0 + + def test_small_sigma_gives_small_mde(self): + """Small noise → small MDE.""" + result = simulate_mde( + DifferenceInDifferences(), + n_units=100, + sigma=0.1, + n_simulations=50, + seed=42, + progress=False, + ) + assert result.mde < 1.0 + + def test_large_sigma_gives_large_mde(self): + """Large noise → large MDE.""" + result = simulate_mde( + DifferenceInDifferences(), + n_units=50, + sigma=10.0, + n_simulations=50, + seed=42, + progress=False, + ) + assert result.mde > 1.0 + + +# --------------------------------------------------------------------------- +# simulate_sample_size tests +# --------------------------------------------------------------------------- + + +class TestSimulateSampleSize: + """Tests for simulate_sample_size function.""" + + def test_basic_sample_size(self): + """Required N found for DiD, power at N close to target.""" + result = simulate_sample_size( + DifferenceInDifferences(), + treatment_effect=5.0, + 
sigma=1.0, + n_simulations=50, + seed=42, + progress=False, + ) + assert isinstance(result, SimulationSampleSizeResults) + assert result.required_n > 0 + assert result.power_at_n >= result.target_power - 0.10 + + def test_result_methods(self): + """summary(), to_dict(), to_dataframe() work.""" + result = simulate_sample_size( + DifferenceInDifferences(), + treatment_effect=5.0, + n_simulations=30, + seed=42, + progress=False, + ) + summary = result.summary() + assert "Sample Size" in summary or "Required" in summary + + d = result.to_dict() + assert "required_n" in d + assert "estimator_name" in d + + df = result.to_dataframe() + assert isinstance(df, pd.DataFrame) + assert len(df) == 1 + + def test_monotonicity_in_search_path(self): + """The search path records plausible n_units / power pairs.""" + result = simulate_sample_size( + DifferenceInDifferences(), + treatment_effect=5.0, + n_simulations=50, + seed=42, + progress=False, + ) + assert len(result.search_path) > 0 + for step in result.search_path: + assert "n_units" in step + assert "power" in step + assert 0 <= step["power"] <= 1 + + def test_custom_data_generator(self): + """Works with user-provided DGP.""" + from diff_diff.prep import generate_did_data + + result = simulate_sample_size( + DifferenceInDifferences(), + treatment_effect=5.0, + n_simulations=30, + seed=42, + progress=False, + data_generator=generate_did_data, + ) + assert result.required_n > 0 + + def test_large_effect_gives_small_n(self): + """Large effect → small N.""" + result = simulate_sample_size( + DifferenceInDifferences(), + treatment_effect=20.0, + sigma=1.0, + n_simulations=50, + seed=42, + progress=False, + ) + assert result.required_n <= 100 + + def test_small_effect_gives_large_n(self): + """Small effect → large N.""" + result = simulate_sample_size( + DifferenceInDifferences(), + treatment_effect=0.5, + sigma=5.0, + n_simulations=50, + seed=42, + progress=False, + ) + assert result.required_n >= 50 From 
c412e4ae0fcfac5f53432f6e4d5ff64000a1c795 Mon Sep 17 00:00:00 2001 From: igerber Date: Wed, 18 Mar 2026 09:18:15 -0400 Subject: [PATCH 02/14] Address PR #208 review: remove ContinuousDiD from registry, validate bisection brackets MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit P0: Remove ContinuousDiD from power analysis registry — the DGP simulated att_slope as treatment_effect but extracted overall_att (a dose-weighted average), so the simulated truth and evaluated estimand didn't match. P1: Add endpoint evaluation for user-supplied effect_range/n_range in simulate_mde() and simulate_sample_size(). Previously these paths skipped evaluation entirely, leaving search_path empty and best_power at 0.0. Now warns when brackets don't contain the target power. P2: Add regression tests for all three fixes. Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/power.py | 86 +++++++++++++++------------------------ tests/test_power.py | 98 +++++++++++++++++++++++++++++++-------------- 2 files changed, 100 insertions(+), 84 deletions(-) diff --git a/diff_diff/power.py b/diff_diff/power.py index 2b1a3a74..c9187c18 100644 --- a/diff_diff/power.py +++ b/diff_diff/power.py @@ -124,24 +124,6 @@ def _ddd_dgp_kwargs( ) -def _continuous_dgp_kwargs( - n_units: int, - n_periods: int, - treatment_effect: float, - treatment_fraction: float, - treatment_period: int, - sigma: float, -) -> Dict[str, Any]: - return dict( - n_units=n_units, - n_periods=n_periods, - att_slope=treatment_effect, - never_treated_frac=1 - treatment_fraction, - cohort_periods=[treatment_period], - noise_sd=sigma, - ) - - # -- Fit kwargs builders ------------------------------------------------------ @@ -221,21 +203,6 @@ def _sdid_fit_kwargs( ) -def _continuous_fit_kwargs( - data: pd.DataFrame, - n_units: int, - n_periods: int, - treatment_period: int, -) -> Dict[str, Any]: - return dict( - outcome="outcome", - unit="unit", - time="period", - first_treat="first_treat", - 
dose="dose", - ) - - # -- Result extractors -------------------------------------------------------- @@ -270,17 +237,6 @@ def _first(r: Any, *attrs: str, default: Any = _nan) -> Any: ) -def _extract_continuous( - result: Any, -) -> Tuple[float, float, float, Tuple[float, float]]: - return ( - result.overall_att, - result.overall_att_se, - result.overall_att_p_value, - result.overall_att_conf_int, - ) - - # -- Registry construction (deferred to avoid import-time cost) --------------- _ESTIMATOR_REGISTRY: Optional[Dict[str, _EstimatorProfile]] = None @@ -293,7 +249,6 @@ def _get_registry() -> Dict[str, _EstimatorProfile]: return _ESTIMATOR_REGISTRY from diff_diff.prep import ( - generate_continuous_did_data, generate_ddd_data, generate_did_data, generate_factor_data, @@ -389,14 +344,6 @@ def _get_registry() -> Dict[str, _EstimatorProfile]: result_extractor=_extract_simple, min_n=64, ), - # --- Continuous DiD --- - "ContinuousDiD": _EstimatorProfile( - default_dgp=generate_continuous_did_data, - dgp_kwargs_builder=_continuous_dgp_kwargs, - fit_kwargs_builder=_continuous_fit_kwargs, - result_extractor=_extract_continuous, - min_n=40, - ), } return _ESTIMATOR_REGISTRY @@ -1841,6 +1788,31 @@ def _power_at(effect: float) -> float: # --- Bracket --- if effect_range is not None: lo, hi = effect_range + power_lo = _power_at(lo) + power_hi = _power_at(hi) + if power_lo >= power: + warnings.warn( + f"Power at effect={lo} is {power_lo:.2f} >= target {power}. " + f"Lower bound already exceeds target power. Returning lo as MDE.", + UserWarning, + ) + return SimulationMDEResults( + mde=lo, + power_at_mde=power_lo, + target_power=power, + alpha=alpha, + n_units=n_units, + n_simulations_per_step=n_simulations, + n_steps=len(search_path), + search_path=search_path, + estimator_name=estimator_name, + ) + if power_hi < power: + warnings.warn( + f"Target power {power} not bracketed: power at effect={hi} " + f"is {power_hi:.2f}. 
Upper bound may be too low.", + UserWarning, + ) else: lo = 0.0 # Check that power at zero is below target (no inflated Type I error) @@ -2015,6 +1987,14 @@ def _power_at_n(n: int) -> float: # --- Bracket --- if n_range is not None: lo, hi = n_range + _power_at_n(lo) # evaluate lo to populate search_path + power_hi = _power_at_n(hi) + if power_hi < power: + warnings.warn( + f"Target power {power} not bracketed: power at n={hi} " + f"is {power_hi:.2f}. Upper bound may be too low.", + UserWarning, + ) else: lo = min_n hi = max(100, 2 * min_n) diff --git a/tests/test_power.py b/tests/test_power.py index 13a4db53..a58b713c 100644 --- a/tests/test_power.py +++ b/tests/test_power.py @@ -7,7 +7,6 @@ from diff_diff import ( TROP, CallawaySantAnna, - ContinuousDiD, DifferenceInDifferences, EfficientDiD, ImputationDiD, @@ -34,11 +33,8 @@ MAX_SAMPLE_SIZE, _basic_dgp_kwargs, _basic_fit_kwargs, - _continuous_dgp_kwargs, - _continuous_fit_kwargs, _ddd_dgp_kwargs, _ddd_fit_kwargs, - _extract_continuous, _extract_multiperiod, _extract_simple, _extract_staggered, @@ -686,7 +682,6 @@ class TestEstimatorRegistry: "TROP", "SyntheticDiD", "TripleDifference", - "ContinuousDiD", ] def test_all_estimators_registered(self): @@ -715,7 +710,6 @@ def test_dgp_kwargs_builders_return_dicts(self): _staggered_dgp_kwargs, _factor_dgp_kwargs, _ddd_dgp_kwargs, - _continuous_dgp_kwargs, ]: result = builder(**params) assert isinstance(result, dict) @@ -730,7 +724,6 @@ def test_fit_kwargs_builders_return_dicts(self): _staggered_fit_kwargs, _ddd_fit_kwargs, _trop_fit_kwargs, - _continuous_fit_kwargs, ]: result = builder(dummy_df, 50, 4, 2) assert isinstance(result, dict) @@ -796,20 +789,19 @@ class MockBootstrapResult: assert p == 0.03 assert ci == (1.2, 2.8) - def test_extract_continuous(self): - """_extract_continuous extracts from overall_att_* attributes.""" + def test_continuous_did_not_in_registry(self): + """ContinuousDiD is not in registry and raises without custom data_generator.""" + from 
diff_diff import ContinuousDiD - class MockResult: - overall_att = 1.5 - overall_att_se = 0.2 - overall_att_p_value = 0.005 - overall_att_conf_int = (1.1, 1.9) + registry = _get_registry() + assert "ContinuousDiD" not in registry - att, se, p, ci = _extract_continuous(MockResult()) - assert att == 1.5 - assert se == 0.2 - assert p == 0.005 - assert ci == (1.1, 1.9) + with pytest.raises(ValueError, match="not in registry"): + simulate_power( + ContinuousDiD(), + n_simulations=5, + progress=False, + ) def test_unknown_estimator_raises_without_data_generator(self): """Unknown estimator without data_generator raises ValueError.""" @@ -977,18 +969,6 @@ def test_synthetic_did(self): ) self._assert_valid_result(result, "SyntheticDiD") - def test_continuous_did(self): - result = simulate_power( - ContinuousDiD(), - n_units=100, - n_periods=6, - treatment_period=3, - n_simulations=10, - seed=42, - progress=False, - ) - self._assert_valid_result(result, "ContinuousDiD") - # --------------------------------------------------------------------------- # simulate_mde tests @@ -1094,6 +1074,34 @@ def test_large_sigma_gives_large_mde(self): ) assert result.mde > 1.0 + def test_explicit_effect_range(self): + """Explicit effect_range evaluates endpoints and populates search_path.""" + result = simulate_mde( + DifferenceInDifferences(), + n_units=100, + sigma=1.0, + n_simulations=30, + effect_range=(0.5, 5.0), + seed=42, + progress=False, + ) + assert result.mde > 0 + assert result.power_at_mde > 0 + assert len(result.search_path) > 0 + + def test_unbracketed_effect_range_warns(self): + """Tiny effect_range that cannot bracket target power warns.""" + with pytest.warns(UserWarning, match="not bracketed"): + simulate_mde( + DifferenceInDifferences(), + n_units=50, + sigma=10.0, + n_simulations=30, + effect_range=(0.0, 0.001), + seed=42, + progress=False, + ) + # --------------------------------------------------------------------------- # simulate_sample_size tests @@ -1189,3 +1197,31 
@@ def test_small_effect_gives_large_n(self): progress=False, ) assert result.required_n >= 50 + + def test_explicit_n_range(self): + """Explicit n_range evaluates endpoints and populates search_path.""" + result = simulate_sample_size( + DifferenceInDifferences(), + treatment_effect=5.0, + sigma=1.0, + n_simulations=30, + n_range=(20, 200), + seed=42, + progress=False, + ) + assert result.required_n > 0 + assert result.power_at_n > 0 + assert len(result.search_path) > 0 + + def test_unbracketed_n_range_warns(self): + """Tiny n_range that cannot bracket target power warns.""" + with pytest.warns(UserWarning, match="not bracketed"): + simulate_sample_size( + DifferenceInDifferences(), + treatment_effect=0.01, + sigma=10.0, + n_simulations=30, + n_range=(20, 22), + seed=42, + progress=False, + ) From 2a2fd9536dda9ddcd02034f5b65dfadc0a7c4701 Mon Sep 17 00:00:00 2001 From: igerber Date: Wed, 18 Mar 2026 09:37:55 -0400 Subject: [PATCH 03/14] Address PR #208 review: remove TWFE from registry, add lo-sufficient short-circuit - Remove TwoWayFixedEffects from power analysis registry (time="period" produces treated*period_number, not standard ATT) - Add early return in simulate_sample_size() when lower bound already achieves target power (both explicit n_range and auto-bracket paths) - Narrow docstring from "All" to "Most" built-in estimators - Add regression tests for TWFE exclusion and lo-sufficient scenarios Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/power.py | 50 +++++++++++++++++++++++++++--------------- tests/test_power.py | 53 +++++++++++++++++++++++++++++++++++---------- 2 files changed, 73 insertions(+), 30 deletions(-) diff --git a/diff_diff/power.py b/diff_diff/power.py index c9187c18..52754d40 100644 --- a/diff_diff/power.py +++ b/diff_diff/power.py @@ -136,15 +136,6 @@ def _basic_fit_kwargs( return dict(outcome="outcome", treatment="treated", time="post") -def _twfe_fit_kwargs( - data: pd.DataFrame, - n_units: int, - n_periods: int, - 
treatment_period: int, -) -> Dict[str, Any]: - return dict(outcome="outcome", treatment="treated", time="period", unit="unit") - - def _multiperiod_fit_kwargs( data: pd.DataFrame, n_units: int, @@ -264,13 +255,6 @@ def _get_registry() -> Dict[str, _EstimatorProfile]: result_extractor=_extract_simple, min_n=20, ), - "TwoWayFixedEffects": _EstimatorProfile( - default_dgp=generate_did_data, - dgp_kwargs_builder=_basic_dgp_kwargs, - fit_kwargs_builder=_twfe_fit_kwargs, - result_extractor=_extract_simple, - min_n=20, - ), "MultiPeriodDiD": _EstimatorProfile( default_dgp=generate_did_data, dgp_kwargs_builder=_basic_dgp_kwargs, @@ -1221,7 +1205,7 @@ def simulate_power( This function simulates datasets with known treatment effects and estimates power as the fraction of simulations where the null hypothesis is rejected. - All built-in estimators are supported via an internal registry that selects + Most built-in estimators are supported via an internal registry that selects the appropriate data-generating process and fit signature automatically. Parameters @@ -1987,7 +1971,24 @@ def _power_at_n(n: int) -> float: # --- Bracket --- if n_range is not None: lo, hi = n_range - _power_at_n(lo) # evaluate lo to populate search_path + power_lo = _power_at_n(lo) + if power_lo >= power: + warnings.warn( + f"Power at n={lo} is {power_lo:.2f} >= target {power}. " + f"Lower bound already achieves target power. 
Returning lo.", + UserWarning, + ) + return SimulationSampleSizeResults( + required_n=lo, + power_at_n=power_lo, + target_power=power, + alpha=alpha, + effect_size=treatment_effect, + n_simulations_per_step=n_simulations, + n_steps=len(search_path), + search_path=search_path, + estimator_name=estimator_name, + ) power_hi = _power_at_n(hi) if power_hi < power: warnings.warn( @@ -1997,6 +1998,19 @@ def _power_at_n(n: int) -> float: ) else: lo = min_n + power_lo = _power_at_n(lo) + if power_lo >= power: + return SimulationSampleSizeResults( + required_n=lo, + power_at_n=power_lo, + target_power=power, + alpha=alpha, + effect_size=treatment_effect, + n_simulations_per_step=n_simulations, + n_steps=len(search_path), + search_path=search_path, + estimator_name=estimator_name, + ) hi = max(100, 2 * min_n) for _ in range(10): if _power_at_n(hi) >= power: diff --git a/tests/test_power.py b/tests/test_power.py index a58b713c..ae002ac5 100644 --- a/tests/test_power.py +++ b/tests/test_power.py @@ -43,7 +43,6 @@ _staggered_dgp_kwargs, _staggered_fit_kwargs, _trop_fit_kwargs, - _twfe_fit_kwargs, ) @@ -671,7 +670,6 @@ class TestEstimatorRegistry: EXPECTED_ESTIMATORS = [ "DifferenceInDifferences", - "TwoWayFixedEffects", "MultiPeriodDiD", "CallawaySantAnna", "SunAbraham", @@ -720,7 +718,6 @@ def test_fit_kwargs_builders_return_dicts(self): dummy_df = pd.DataFrame({"period": [0, 1, 2, 3]}) for builder in [ _basic_fit_kwargs, - _twfe_fit_kwargs, _staggered_fit_kwargs, _ddd_fit_kwargs, _trop_fit_kwargs, @@ -803,6 +800,18 @@ def test_continuous_did_not_in_registry(self): progress=False, ) + def test_twfe_not_in_registry(self): + """TwoWayFixedEffects is not in registry and raises without custom data_generator.""" + registry = _get_registry() + assert "TwoWayFixedEffects" not in registry + + with pytest.raises(ValueError, match="not in registry"): + simulate_power( + TwoWayFixedEffects(), + n_simulations=5, + progress=False, + ) + def 
test_unknown_estimator_raises_without_data_generator(self): """Unknown estimator without data_generator raises ValueError.""" @@ -841,15 +850,6 @@ def test_did(self): ) self._assert_valid_result(result, "DifferenceInDifferences") - def test_twfe(self): - result = simulate_power( - TwoWayFixedEffects(), - n_simulations=10, - seed=42, - progress=False, - ) - self._assert_valid_result(result, "TwoWayFixedEffects") - def test_multiperiod(self): result = simulate_power( MultiPeriodDiD(), @@ -1225,3 +1225,32 @@ def test_unbracketed_n_range_warns(self): seed=42, progress=False, ) + + def test_lo_already_sufficient_explicit(self): + """When lo already meets power, return lo immediately with warning.""" + with pytest.warns(UserWarning, match="Lower bound already achieves"): + result = simulate_sample_size( + DifferenceInDifferences(), + treatment_effect=50.0, + sigma=0.1, + n_simulations=50, + n_range=(20, 200), + seed=42, + progress=False, + ) + assert result.required_n == 20 + assert result.power_at_n >= 0.80 + + def test_lo_already_sufficient_auto(self): + """Auto-bracket returns min_n when effect overwhelmingly large.""" + result = simulate_sample_size( + DifferenceInDifferences(), + treatment_effect=50.0, + sigma=0.1, + n_simulations=50, + seed=42, + progress=False, + ) + # min_n for DifferenceInDifferences is 20 + assert result.required_n == 20 + assert result.power_at_n >= 0.80 From 9b29d7ebfb1059fa33fe80e89c617efc3210394c Mon Sep 17 00:00:00 2001 From: igerber Date: Wed, 18 Mar 2026 10:12:40 -0400 Subject: [PATCH 04/14] Address PR #208 review: cap SyntheticDiD treatment_fraction to fix zero-power bug Add _sdid_dgp_kwargs that caps treatment_fraction at 0.4 so the placebo variance method has enough pseudo-controls (n_control > n_treated). Add regression test for the default SyntheticDiD path. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/power.py | 23 ++++++++++++++++++++++- tests/test_power.py | 17 +++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/diff_diff/power.py b/diff_diff/power.py index 52754d40..98a1264a 100644 --- a/diff_diff/power.py +++ b/diff_diff/power.py @@ -109,6 +109,27 @@ def _factor_dgp_kwargs( ) +def _sdid_dgp_kwargs( + n_units: int, + n_periods: int, + treatment_effect: float, + treatment_fraction: float, + treatment_period: int, + sigma: float, +) -> Dict[str, Any]: + # SyntheticDiD placebo variance requires n_control > n_treated; + # cap at 40% treated to ensure adequate pseudo-controls. + safe_fraction = min(treatment_fraction, 0.4) + return _factor_dgp_kwargs( + n_units=n_units, + n_periods=n_periods, + treatment_effect=treatment_effect, + treatment_fraction=safe_fraction, + treatment_period=treatment_period, + sigma=sigma, + ) + + def _ddd_dgp_kwargs( n_units: int, n_periods: int, @@ -315,7 +336,7 @@ def _get_registry() -> Dict[str, _EstimatorProfile]: ), "SyntheticDiD": _EstimatorProfile( default_dgp=generate_factor_data, - dgp_kwargs_builder=_factor_dgp_kwargs, + dgp_kwargs_builder=_sdid_dgp_kwargs, fit_kwargs_builder=_sdid_fit_kwargs, result_extractor=_extract_simple, min_n=30, diff --git a/tests/test_power.py b/tests/test_power.py index ae002ac5..26d0d4b2 100644 --- a/tests/test_power.py +++ b/tests/test_power.py @@ -40,6 +40,7 @@ _extract_staggered, _factor_dgp_kwargs, _get_registry, + _sdid_dgp_kwargs, _staggered_dgp_kwargs, _staggered_fit_kwargs, _trop_fit_kwargs, @@ -707,6 +708,7 @@ def test_dgp_kwargs_builders_return_dicts(self): _basic_dgp_kwargs, _staggered_dgp_kwargs, _factor_dgp_kwargs, + _sdid_dgp_kwargs, _ddd_dgp_kwargs, ]: result = builder(**params) @@ -969,6 +971,21 @@ def test_synthetic_did(self): ) self._assert_valid_result(result, "SyntheticDiD") + @pytest.mark.slow + def test_synthetic_did_default_fraction(self): + """Default treatment_fraction=0.5 must not produce 
zero power.""" + result = simulate_power( + SyntheticDiD(), + n_units=50, + n_periods=6, + treatment_period=3, + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "SyntheticDiD") + assert result.power > 0, "Default SyntheticDiD path gave zero power" + # --------------------------------------------------------------------------- # simulate_mde tests From 35df6249c89f875cd1b8a4616b6778c67413935a Mon Sep 17 00:00:00 2001 From: igerber Date: Wed, 18 Mar 2026 10:37:10 -0400 Subject: [PATCH 05/14] Address PR #208 review: replace silent SyntheticDiD treatment_fraction cap with fail-fast validation Remove _sdid_dgp_kwargs that silently capped treatment_fraction at 0.4. Add upfront ValueError in simulate_power() when SyntheticDiD placebo variance has n_control <= n_treated, with actionable error message suggesting lowering treatment_fraction or using bootstrap variance. Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/power.py | 38 +++++++-------- docs/methodology/REGISTRY.md | 1 + tests/test_power.py | 92 +++++++++++++++++++++++++++++++++--- 3 files changed, 103 insertions(+), 28 deletions(-) diff --git a/diff_diff/power.py b/diff_diff/power.py index 98a1264a..ba1be6cd 100644 --- a/diff_diff/power.py +++ b/diff_diff/power.py @@ -109,27 +109,6 @@ def _factor_dgp_kwargs( ) -def _sdid_dgp_kwargs( - n_units: int, - n_periods: int, - treatment_effect: float, - treatment_fraction: float, - treatment_period: int, - sigma: float, -) -> Dict[str, Any]: - # SyntheticDiD placebo variance requires n_control > n_treated; - # cap at 40% treated to ensure adequate pseudo-controls. 
- safe_fraction = min(treatment_fraction, 0.4) - return _factor_dgp_kwargs( - n_units=n_units, - n_periods=n_periods, - treatment_effect=treatment_effect, - treatment_fraction=safe_fraction, - treatment_period=treatment_period, - sigma=sigma, - ) - - def _ddd_dgp_kwargs( n_units: int, n_periods: int, @@ -336,7 +315,7 @@ def _get_registry() -> Dict[str, _EstimatorProfile]: ), "SyntheticDiD": _EstimatorProfile( default_dgp=generate_factor_data, - dgp_kwargs_builder=_sdid_dgp_kwargs, + dgp_kwargs_builder=_factor_dgp_kwargs, fit_kwargs_builder=_sdid_fit_kwargs, result_extractor=_extract_simple, min_n=30, @@ -1330,6 +1309,21 @@ def simulate_power( # When a custom data_generator is provided, bypass registry DGP use_custom_dgp = data_generator is not None + # SyntheticDiD placebo variance requires n_control > n_treated + if estimator_name == "SyntheticDiD" and not use_custom_dgp: + vm = getattr(estimator, "variance_method", "placebo") + n_treated = max(1, int(n_units * treatment_fraction)) + n_control = n_units - n_treated + if vm == "placebo" and n_control <= n_treated: + raise ValueError( + f"SyntheticDiD placebo variance requires more control than " + f"treated units (got n_control={n_control}, " + f"n_treated={n_treated} from treatment_fraction=" + f"{treatment_fraction}). Either lower treatment_fraction " + f"so that n_control > n_treated, or use " + f"SyntheticDiD(variance_method='bootstrap')." + ) + data_gen_kwargs = data_generator_kwargs or {} est_kwargs = estimator_kwargs or {} diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 54bfbac7..8eee4fac 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -1111,6 +1111,7 @@ Convergence criterion: stop when objective decrease < min_decrease² (default mi - **Single pre-period**: `compute_time_weights` returns `[1.0]` when `n_pre <= 1` (Frank-Wolfe on a 1-element simplex is trivial). - **Bootstrap with 0 control or 0 treated in resample**: Skip iteration (`continue`). 
If ALL bootstrap iterations fail, raises `ValueError`. If only 1 succeeds, warns and returns SE=0.0. If >5% failure rate, warns about reliability. - **Placebo with n_control <= n_treated**: Warns that not enough control units for placebo variance estimation, returns SE=0.0 and empty placebo effects array. The check is `n_control - n_treated < 1`. +- **Note:** Power analysis functions (`simulate_power`, `simulate_mde`, `simulate_sample_size`) raise `ValueError` for placebo variance when `n_control <= n_treated` (fail-fast before simulation). - **Negative weights attempted**: Frank-Wolfe operates on the simplex (non-negative, sum-to-1), so weights are always feasible by construction. The step size is clipped to [0, 1] and the move is toward a simplex vertex. - **Perfect pre-treatment fit**: Regularization (ζ² ||ω||²) prevents overfitting by penalizing weight concentration. - **Single treated unit**: Valid; placebo variance uses jackknife-style permutations of controls. diff --git a/tests/test_power.py b/tests/test_power.py index 26d0d4b2..a10fad05 100644 --- a/tests/test_power.py +++ b/tests/test_power.py @@ -40,7 +40,6 @@ _extract_staggered, _factor_dgp_kwargs, _get_registry, - _sdid_dgp_kwargs, _staggered_dgp_kwargs, _staggered_fit_kwargs, _trop_fit_kwargs, @@ -708,7 +707,6 @@ def test_dgp_kwargs_builders_return_dicts(self): _basic_dgp_kwargs, _staggered_dgp_kwargs, _factor_dgp_kwargs, - _sdid_dgp_kwargs, _ddd_dgp_kwargs, ]: result = builder(**params) @@ -971,20 +969,102 @@ def test_synthetic_did(self): ) self._assert_valid_result(result, "SyntheticDiD") + def test_sdid_placebo_rejects_high_fraction(self): + """SyntheticDiD placebo variance raises when n_control <= n_treated.""" + with pytest.raises(ValueError, match="placebo variance requires more control"): + simulate_power( + SyntheticDiD(), + treatment_fraction=0.5, + n_simulations=5, + seed=42, + progress=False, + ) + @pytest.mark.slow - def test_synthetic_did_default_fraction(self): - """Default 
treatment_fraction=0.5 must not produce zero power.""" + def test_sdid_placebo_boundary_fraction(self): + """treatment_fraction=0.49 with 50 units gives n_control=26 > n_treated=24.""" result = simulate_power( SyntheticDiD(), + treatment_fraction=0.49, n_units=50, n_periods=6, treatment_period=3, - n_simulations=10, + n_simulations=5, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "SyntheticDiD") + + @pytest.mark.slow + def test_sdid_bootstrap_allows_high_fraction(self): + """Bootstrap variance method bypasses the placebo constraint.""" + result = simulate_power( + SyntheticDiD(variance_method="bootstrap"), + treatment_fraction=0.5, + n_units=50, + n_periods=6, + treatment_period=3, + n_simulations=5, seed=42, progress=False, ) self._assert_valid_result(result, "SyntheticDiD") - assert result.power > 0, "Default SyntheticDiD path gave zero power" + assert result.power >= 0 + + def test_sdid_mde_rejects_high_fraction(self): + """simulate_mde raises for SyntheticDiD placebo with high treatment_fraction.""" + with pytest.raises(ValueError, match="placebo variance requires more control"): + simulate_mde( + SyntheticDiD(), + treatment_fraction=0.5, + n_simulations=5, + seed=42, + progress=False, + ) + + def test_sdid_sample_size_rejects_high_fraction(self): + """simulate_sample_size raises for SyntheticDiD placebo with high fraction.""" + with pytest.raises(ValueError, match="placebo variance requires more control"): + simulate_sample_size( + SyntheticDiD(), + treatment_fraction=0.5, + n_simulations=5, + seed=42, + progress=False, + ) + + @pytest.mark.slow + def test_sdid_mde(self): + """simulate_mde works for SyntheticDiD with valid treatment_fraction.""" + result = simulate_mde( + SyntheticDiD(), + treatment_fraction=0.3, + n_units=50, + n_periods=6, + treatment_period=3, + n_simulations=5, + effect_range=(0.5, 3.0), + seed=42, + progress=False, + ) + assert isinstance(result, SimulationMDEResults) + assert result.mde > 0 + + @pytest.mark.slow + 
def test_sdid_sample_size(self): + """simulate_sample_size works for SyntheticDiD with valid fraction.""" + result = simulate_sample_size( + SyntheticDiD(), + treatment_fraction=0.3, + n_periods=6, + treatment_period=3, + n_simulations=5, + n_range=(30, 80), + seed=42, + progress=False, + ) + assert isinstance(result, SimulationSampleSizeResults) + assert result.required_n > 0 # --------------------------------------------------------------------------- From bbbacc0a91f0fb6ce445c6ac00c4872ecb696399 Mon Sep 17 00:00:00 2001 From: igerber Date: Wed, 18 Mar 2026 14:20:48 -0400 Subject: [PATCH 06/14] Address PR #208 review: add TWFE to registry, fix unregistered-estimator fallback Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/power.py | 23 ++++++++++-- tests/test_power.py | 91 ++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 101 insertions(+), 13 deletions(-) diff --git a/diff_diff/power.py b/diff_diff/power.py index ba1be6cd..efe33e5c 100644 --- a/diff_diff/power.py +++ b/diff_diff/power.py @@ -136,6 +136,15 @@ def _basic_fit_kwargs( return dict(outcome="outcome", treatment="treated", time="post") +def _twfe_fit_kwargs( + data: pd.DataFrame, + n_units: int, + n_periods: int, + treatment_period: int, +) -> Dict[str, Any]: + return dict(outcome="outcome", treatment="treated", time="post", unit="unit") + + def _multiperiod_fit_kwargs( data: pd.DataFrame, n_units: int, @@ -255,6 +264,13 @@ def _get_registry() -> Dict[str, _EstimatorProfile]: result_extractor=_extract_simple, min_n=20, ), + "TwoWayFixedEffects": _EstimatorProfile( + default_dgp=generate_did_data, + dgp_kwargs_builder=_basic_dgp_kwargs, + fit_kwargs_builder=_twfe_fit_kwargs, + result_extractor=_extract_simple, + min_n=20, + ), "MultiPeriodDiD": _EstimatorProfile( default_dgp=generate_did_data, dgp_kwargs_builder=_basic_dgp_kwargs, @@ -1303,7 +1319,9 @@ def simulate_power( if profile is None and data_generator is None: raise ValueError( f"Estimator '{estimator_name}' not in 
registry. " - f"Provide a custom data_generator and estimator_kwargs." + f"Provide a custom data_generator and estimator_kwargs " + f"(the full dict of keyword arguments for estimator.fit(), " + f"e.g. dict(outcome='y', treatment='treat', time='period'))." ) # When a custom data_generator is provided, bypass registry DGP @@ -1414,8 +1432,7 @@ def simulate_power( ) fit_kwargs.update(est_kwargs) else: - fit_kwargs = dict(outcome="outcome", treatment="treated", time="post") - fit_kwargs.update(est_kwargs) + fit_kwargs = dict(est_kwargs) result = estimator.fit(data, **fit_kwargs) diff --git a/tests/test_power.py b/tests/test_power.py index a10fad05..e5e3e1b6 100644 --- a/tests/test_power.py +++ b/tests/test_power.py @@ -44,6 +44,7 @@ _staggered_fit_kwargs, _trop_fit_kwargs, ) +from diff_diff.prep import generate_did_data class TestPowerAnalysis: @@ -800,17 +801,10 @@ def test_continuous_did_not_in_registry(self): progress=False, ) - def test_twfe_not_in_registry(self): - """TwoWayFixedEffects is not in registry and raises without custom data_generator.""" + def test_twfe_in_registry(self): + """TwoWayFixedEffects is in the registry.""" registry = _get_registry() - assert "TwoWayFixedEffects" not in registry - - with pytest.raises(ValueError, match="not in registry"): - simulate_power( - TwoWayFixedEffects(), - n_simulations=5, - progress=False, - ) + assert "TwoWayFixedEffects" in registry def test_unknown_estimator_raises_without_data_generator(self): """Unknown estimator without data_generator raises ValueError.""" @@ -1066,6 +1060,83 @@ def test_sdid_sample_size(self): assert isinstance(result, SimulationSampleSizeResults) assert result.required_n > 0 + @pytest.mark.slow + def test_twfe(self): + result = simulate_power( + TwoWayFixedEffects(), + n_simulations=5, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "TwoWayFixedEffects") + + @pytest.mark.slow + def test_twfe_mde(self): + result = simulate_mde( + TwoWayFixedEffects(), + n_simulations=5, 
+ effect_range=(0.5, 5.0), + seed=42, + progress=False, + ) + assert isinstance(result, SimulationMDEResults) + assert result.mde > 0 + + @pytest.mark.slow + def test_twfe_sample_size(self): + result = simulate_sample_size( + TwoWayFixedEffects(), + n_simulations=5, + n_range=(20, 100), + seed=42, + progress=False, + ) + assert isinstance(result, SimulationSampleSizeResults) + assert result.required_n > 0 + + @pytest.mark.slow + def test_custom_fallback_unregistered_estimator(self): + """Unregistered estimator works with custom data_generator and estimator_kwargs.""" + + class _UnregisteredEstimator: + """Unregistered wrapper for testing custom fallback.""" + + def __init__(self): + self._inner = DifferenceInDifferences() + + def fit(self, data, **kwargs): + return self._inner.fit(data, **kwargs) + + result = simulate_power( + _UnregisteredEstimator(), + data_generator=generate_did_data, + estimator_kwargs=dict(outcome="outcome", treatment="treated", time="post"), + n_simulations=5, + seed=42, + progress=False, + ) + assert 0 <= result.power <= 1 + assert result.n_simulations > 0 + + def test_custom_fallback_missing_kwargs_raises(self): + """Unregistered estimator with no estimator_kwargs fails on fit.""" + + class _UnregisteredEstimator: + def __init__(self): + self._inner = DifferenceInDifferences() + + def fit(self, data, **kwargs): + return self._inner.fit(data, **kwargs) + + with pytest.raises((ValueError, TypeError, RuntimeError)): + simulate_power( + _UnregisteredEstimator(), + data_generator=generate_did_data, + n_simulations=5, + seed=42, + progress=False, + ) + # --------------------------------------------------------------------------- # simulate_mde tests From d1f2316d13d8258a26b879fc1c80ffa6ff589e67 Mon Sep 17 00:00:00 2001 From: igerber Date: Wed, 18 Mar 2026 14:54:15 -0400 Subject: [PATCH 07/14] Add result_extractor param, fix stale power.rst docs, and add missing API symbols - Add `result_extractor` parameter to simulate_power, simulate_mde, and 
simulate_sample_size for unregistered estimators with non-standard schemas - Fix power.rst: correct PowerAnalysis method names, example code, and add SimulationMDEResults/SimulationSampleSizeResults/simulate_mde/simulate_sample_size - Add 4 missing symbols to docs/api/index.rst autosummary - Add api/power.rst to doc snippet smoke tests - Add tests for custom result_extractor and MDE forwarding Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/power.py | 18 +++++++ docs/api/index.rst | 4 ++ docs/api/power.rst | 91 +++++++++++++++++++++----------- tests/test_doc_snippets.py | 103 +++++++++++++++++++------------------ tests/test_power.py | 53 +++++++++++++++++++ 5 files changed, 189 insertions(+), 80 deletions(-) diff --git a/diff_diff/power.py b/diff_diff/power.py index efe33e5c..5aa8c93f 100644 --- a/diff_diff/power.py +++ b/diff_diff/power.py @@ -1214,6 +1214,7 @@ def simulate_power( data_generator: Optional[Callable] = None, data_generator_kwargs: Optional[Dict[str, Any]] = None, estimator_kwargs: Optional[Dict[str, Any]] = None, + result_extractor: Optional[Callable] = None, progress: bool = True, ) -> SimulationPowerResults: """ @@ -1257,6 +1258,11 @@ def simulate_power( Additional keyword arguments for data generator. estimator_kwargs : dict, optional Additional keyword arguments for estimator.fit(). + result_extractor : callable, optional + Custom function to extract results from the estimator output. + Takes the estimator result object and returns a tuple of + ``(att, se, p_value, conf_int)``. Useful for unregistered + estimators with non-standard result schemas. progress : bool, default=True Whether to print progress updates. 
@@ -1439,6 +1445,8 @@ def simulate_power( # --- Extract results --- if profile is not None: att, se, p_val, ci = profile.result_extractor(result) + elif result_extractor is not None: + att, se, p_val, ci = result_extractor(result) else: att = result.att if hasattr(result, "att") else result.avg_att se = result.se if hasattr(result, "se") else result.avg_se @@ -1717,6 +1725,7 @@ def simulate_mde( data_generator: Optional[Callable] = None, data_generator_kwargs: Optional[Dict[str, Any]] = None, estimator_kwargs: Optional[Dict[str, Any]] = None, + result_extractor: Optional[Callable] = None, progress: bool = True, ) -> SimulationMDEResults: """ @@ -1759,6 +1768,9 @@ def simulate_mde( Additional keyword arguments for data generator. estimator_kwargs : dict, optional Additional keyword arguments for estimator.fit(). + result_extractor : callable, optional + Custom function to extract results from the estimator output. + Forwarded to ``simulate_power()``. progress : bool, default=True Whether to print progress updates. @@ -1789,6 +1801,7 @@ def simulate_mde( data_generator=data_generator, data_generator_kwargs=data_generator_kwargs, estimator_kwargs=estimator_kwargs, + result_extractor=result_extractor, progress=False, ) @@ -1911,6 +1924,7 @@ def simulate_sample_size( data_generator: Optional[Callable] = None, data_generator_kwargs: Optional[Dict[str, Any]] = None, estimator_kwargs: Optional[Dict[str, Any]] = None, + result_extractor: Optional[Callable] = None, progress: bool = True, ) -> SimulationSampleSizeResults: """ @@ -1951,6 +1965,9 @@ def simulate_sample_size( Additional keyword arguments for data generator. estimator_kwargs : dict, optional Additional keyword arguments for estimator.fit(). + result_extractor : callable, optional + Custom function to extract results from the estimator output. + Forwarded to ``simulate_power()``. progress : bool, default=True Whether to print progress updates. 
@@ -1988,6 +2005,7 @@ def simulate_sample_size( data_generator=data_generator, data_generator_kwargs=data_generator_kwargs, estimator_kwargs=estimator_kwargs, + result_extractor=result_extractor, progress=False, ) diff --git a/docs/api/index.rst b/docs/api/index.rst index d139b768..f1532991 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -148,10 +148,14 @@ Power analysis for study design: diff_diff.PowerAnalysis diff_diff.PowerResults diff_diff.SimulationPowerResults + diff_diff.SimulationMDEResults + diff_diff.SimulationSampleSizeResults diff_diff.compute_power diff_diff.compute_mde diff_diff.compute_sample_size diff_diff.simulate_power + diff_diff.simulate_mde + diff_diff.simulate_sample_size Pre-Trends Power Analysis ------------------------- diff --git a/docs/api/power.rst b/docs/api/power.rst index 0e17b75c..52d57c39 100644 --- a/docs/api/power.rst +++ b/docs/api/power.rst @@ -30,9 +30,11 @@ Main class for analytical power calculations. .. autosummary:: - ~PowerAnalysis.compute_power - ~PowerAnalysis.compute_mde - ~PowerAnalysis.compute_sample_size + ~PowerAnalysis.power + ~PowerAnalysis.mde + ~PowerAnalysis.sample_size + ~PowerAnalysis.power_curve + ~PowerAnalysis.sample_size_curve Example ~~~~~~~ @@ -41,29 +43,19 @@ Example from diff_diff import PowerAnalysis - # Create power analysis object - pa = PowerAnalysis( - effect_size=0.5, - n_treated=100, - n_control=100, - n_pre=4, - n_post=4, - sigma=1.0, - rho=0.5, # Within-unit correlation - alpha=0.05 - ) + pa = PowerAnalysis(alpha=0.05, power=0.80) # Compute power - power = pa.compute_power() - print(f"Power: {power:.2%}") + result = pa.power(effect_size=0.5, n_treated=100, n_control=100, sigma=1.0) + print(f"Power: {result.power:.2%}") # Compute MDE at 80% power - mde = pa.compute_mde(power=0.80) - print(f"MDE: {mde:.3f}") + result = pa.mde(n_treated=100, n_control=100, sigma=1.0) + print(f"MDE: {result.mde:.3f}") # Required sample size - n = pa.compute_sample_size(power=0.80) - print(f"Required 
N per group: {n}") + result = pa.sample_size(effect_size=0.5, sigma=1.0) + print(f"Required N: {result.required_n}") PowerResults ------------ @@ -85,6 +77,26 @@ Results from simulation-based power analysis. :undoc-members: :show-inheritance: +SimulationMDEResults +-------------------- + +Results from simulation-based MDE search. + +.. autoclass:: diff_diff.SimulationMDEResults + :members: + :undoc-members: + :show-inheritance: + +SimulationSampleSizeResults +--------------------------- + +Results from simulation-based sample size search. + +.. autoclass:: diff_diff.SimulationSampleSizeResults + :members: + :undoc-members: + :show-inheritance: + Convenience Functions --------------------- @@ -116,6 +128,20 @@ Simulation-based power for any DiD estimator. .. autofunction:: diff_diff.simulate_power +simulate_mde +~~~~~~~~~~~~~ + +Simulation-based MDE for any DiD estimator. + +.. autofunction:: diff_diff.simulate_mde + +simulate_sample_size +~~~~~~~~~~~~~~~~~~~~ + +Simulation-based sample size for any DiD estimator. + +.. 
autofunction:: diff_diff.simulate_sample_size + Complete Example ---------------- @@ -125,8 +151,8 @@ Complete Example PowerAnalysis, compute_mde, simulate_power, + simulate_mde, DifferenceInDifferences, - plot_power_curve, ) # Quick MDE calculation @@ -145,20 +171,23 @@ Complete Example # Simulation-based power for DiD estimator sim_results = simulate_power( estimator=DifferenceInDifferences(), - effect_size=0.5, - n_treated=100, - n_control=100, - n_periods=8, - treatment_start=4, + treatment_effect=5.0, + n_units=100, + n_periods=4, + treatment_period=2, sigma=1.0, - n_simulations=1000 + n_simulations=20, ) print(f"Simulated power: {sim_results.power:.2%}") - # Power curve - pa = PowerAnalysis(n_treated=100, n_control=100, n_pre=4, n_post=4, sigma=1.0) - ax = plot_power_curve(pa, effect_range=(0, 1), n_points=50) - ax.figure.savefig('power_curve.png') + # Simulation-based MDE + mde_results = simulate_mde( + estimator=DifferenceInDifferences(), + n_units=100, + n_simulations=10, + max_steps=5, + ) + print(f"Simulated MDE: {mde_results.mde:.3f}") See Also -------- diff --git a/tests/test_doc_snippets.py b/tests/test_doc_snippets.py index 68b9ab42..65f62c50 100644 --- a/tests/test_doc_snippets.py +++ b/tests/test_doc_snippets.py @@ -36,6 +36,7 @@ "api/visualization.rst", "api/honest_did.rst", "api/pretrends.rst", + "api/power.rst", "python_comparison.rst", "r_comparison.rst", ] @@ -62,12 +63,8 @@ ) # Heuristic: skip ``::`` blocks that look like shell or prose, not Python. 
-_SHELL_HINTS_RE = re.compile( - r"^\s*(\$\s|#!|pip\s+install|maturin\s)", re.MULTILINE -) -_PROSE_HINT_RE = re.compile( - r"^[A-Z][a-z]+ [a-z]+ [a-z]+", re.MULTILINE # English prose sentence -) +_SHELL_HINTS_RE = re.compile(r"^\s*(\$\s|#!|pip\s+install|maturin\s)", re.MULTILINE) +_PROSE_HINT_RE = re.compile(r"^[A-Z][a-z]+ [a-z]+ [a-z]+", re.MULTILINE) # English prose sentence def _extract_snippets(rst_path: Path) -> List[Tuple[int, str]]: @@ -140,6 +137,7 @@ def _collect_cases() -> List[Tuple[str, str, Optional[str]]]: _CASES = _collect_cases() + # --------------------------------------------------------------------------- # Shared namespace builder # --------------------------------------------------------------------------- @@ -167,9 +165,7 @@ def _build_namespace() -> dict: # Synthetic datasets that doc snippets commonly reference rng = np.random.default_rng(42) - staggered = diff_diff.generate_staggered_data( - n_units=60, n_periods=10, seed=42 - ) + staggered = diff_diff.generate_staggered_data(n_units=60, n_periods=10, seed=42) # Add alias columns that doc snippets expect # Use a simple time split (not unit-specific) so basic 2x2 DID works mid = staggered["period"].median() @@ -219,16 +215,18 @@ def _build_namespace() -> dict: # ------------------------------------------------------------------ def _mock_load_card_krueger(**kwargs): n = 40 - return pd.DataFrame({ - "store_id": range(n), - "state": ["NJ"] * (n // 2) + ["PA"] * (n // 2), - "chain": (["bk", "kfc", "roys", "wendys"] * 10)[:n], - "emp_pre": rng.normal(20, 5, n), - "emp_post": rng.normal(21, 5, n), - "wage_pre": rng.normal(4.5, 0.3, n), - "wage_post": rng.normal(5.0, 0.3, n), - "treated": [1] * (n // 2) + [0] * (n // 2), - }) + return pd.DataFrame( + { + "store_id": range(n), + "state": ["NJ"] * (n // 2) + ["PA"] * (n // 2), + "chain": (["bk", "kfc", "roys", "wendys"] * 10)[:n], + "emp_pre": rng.normal(20, 5, n), + "emp_post": rng.normal(21, 5, n), + "wage_pre": rng.normal(4.5, 0.3, n), + 
"wage_post": rng.normal(5.0, 0.3, n), + "treated": [1] * (n // 2) + [0] * (n // 2), + } + ) def _mock_load_castle_doctrine(**kwargs): states = [f"S{i:02d}" for i in range(10)] @@ -236,17 +234,18 @@ def _mock_load_castle_doctrine(**kwargs): rows = [(s, y) for s in states for y in years] n = len(rows) ft = [0] * 55 + [2005] * 22 + [2007] * 22 + [2009] * 11 - return pd.DataFrame({ - "state": [r[0] for r in rows], - "year": [r[1] for r in rows], - "first_treat": ft[:n], - "homicide_rate": rng.normal(5, 1, n), - "population": rng.integers(500000, 5000000, n), - "income": rng.normal(30000, 5000, n), - "treated": [1 if ft[i] and r[1] >= ft[i] else 0 - for i, r in enumerate(rows)][:n], - "cohort": ft[:n], - }) + return pd.DataFrame( + { + "state": [r[0] for r in rows], + "year": [r[1] for r in rows], + "first_treat": ft[:n], + "homicide_rate": rng.normal(5, 1, n), + "population": rng.integers(500000, 5000000, n), + "income": rng.normal(30000, 5000, n), + "treated": [1 if ft[i] and r[1] >= ft[i] else 0 for i, r in enumerate(rows)][:n], + "cohort": ft[:n], + } + ) def _mock_load_divorce_laws(**kwargs): states = [f"S{i:02d}" for i in range(10)] @@ -254,17 +253,18 @@ def _mock_load_divorce_laws(**kwargs): rows = [(s, y) for s in states for y in years] n = len(rows) ft = [0] * 125 + [1970] * 50 + [1975] * 50 + [1980] * 25 - return pd.DataFrame({ - "state": [r[0] for r in rows], - "year": [r[1] for r in rows], - "first_treat": ft[:n], - "divorce_rate": rng.normal(4, 1, n), - "female_lfp": rng.normal(50, 5, n), - "suicide_rate": rng.normal(5, 2, n), - "treated": [1 if ft[i] and r[1] >= ft[i] else 0 - for i, r in enumerate(rows)][:n], - "cohort": ft[:n], - }) + return pd.DataFrame( + { + "state": [r[0] for r in rows], + "year": [r[1] for r in rows], + "first_treat": ft[:n], + "divorce_rate": rng.normal(4, 1, n), + "female_lfp": rng.normal(50, 5, n), + "suicide_rate": rng.normal(5, 2, n), + "treated": [1 if ft[i] and r[1] >= ft[i] else 0 for i, r in enumerate(rows)][:n], + 
"cohort": ft[:n], + } + ) def _mock_load_mpdta(**kwargs): counties = list(range(1, 21)) @@ -272,14 +272,16 @@ def _mock_load_mpdta(**kwargs): rows = [(c, y) for c in counties for y in years] n = len(rows) ft = ([0] * 25 + [2004] * 25 + [2006] * 25 + [2007] * 25)[:n] - return pd.DataFrame({ - "countyreal": [r[0] for r in rows], - "year": [r[1] for r in rows], - "lpop": rng.normal(10, 1, n), - "lemp": rng.normal(8, 0.5, n), - "first_treat": ft, - "treat": [1 if f != 0 else 0 for f in ft], - }) + return pd.DataFrame( + { + "countyreal": [r[0] for r in rows], + "year": [r[1] for r in rows], + "lpop": rng.normal(10, 1, n), + "lemp": rng.normal(8, 0.5, n), + "first_treat": ft, + "treat": [1 if f != 0 else 0 for f in ft], + } + ) _dataset_dispatch = { "card_krueger": _mock_load_card_krueger, @@ -303,6 +305,7 @@ def _mock_list_datasets(): # Inject mocks into namespace so `from diff_diff.datasets import ...` works import types + mock_datasets_mod = types.ModuleType("diff_diff.datasets") mock_datasets_mod.load_card_krueger = _mock_load_card_krueger mock_datasets_mod.load_castle_doctrine = _mock_load_castle_doctrine @@ -311,6 +314,7 @@ def _mock_list_datasets(): mock_datasets_mod.load_dataset = _mock_load_dataset mock_datasets_mod.list_datasets = _mock_list_datasets import sys + sys.modules["diff_diff.datasets"] = mock_datasets_mod diff_diff.datasets = mock_datasets_mod @@ -333,6 +337,7 @@ def _restore_datasets_module(): """Restore diff_diff.datasets after each test to prevent mock leaking.""" import sys as _sys import diff_diff as _dd + orig_mod = _sys.modules.get("diff_diff.datasets") orig_attr = getattr(_dd, "datasets", None) yield diff --git a/tests/test_power.py b/tests/test_power.py index e5e3e1b6..4c20a44c 100644 --- a/tests/test_power.py +++ b/tests/test_power.py @@ -1137,6 +1137,59 @@ def fit(self, data, **kwargs): progress=False, ) + @pytest.mark.slow + def test_custom_result_extractor(self): + """Custom result_extractor works for unregistered estimator.""" + + 
class _UnregisteredEstimator: + def __init__(self): + self._inner = DifferenceInDifferences() + + def fit(self, data, **kwargs): + return self._inner.fit(data, **kwargs) + + def _custom_extractor(result): + return (result.att, result.se, result.p_value, result.conf_int) + + result = simulate_power( + _UnregisteredEstimator(), + data_generator=generate_did_data, + estimator_kwargs=dict(outcome="outcome", treatment="treated", time="post"), + result_extractor=_custom_extractor, + n_simulations=5, + seed=42, + progress=False, + ) + assert 0 <= result.power <= 1 + assert result.n_simulations > 0 + + @pytest.mark.slow + def test_custom_result_extractor_mde_forwarding(self): + """result_extractor forwards correctly through simulate_mde.""" + + class _UnregisteredEstimator: + def __init__(self): + self._inner = DifferenceInDifferences() + + def fit(self, data, **kwargs): + return self._inner.fit(data, **kwargs) + + def _custom_extractor(result): + return (result.att, result.se, result.p_value, result.conf_int) + + result = simulate_mde( + _UnregisteredEstimator(), + data_generator=generate_did_data, + estimator_kwargs=dict(outcome="outcome", treatment="treated", time="post"), + result_extractor=_custom_extractor, + n_simulations=5, + effect_range=(0.5, 5.0), + seed=42, + progress=False, + ) + assert isinstance(result, SimulationMDEResults) + assert result.mde > 0 + # --------------------------------------------------------------------------- # simulate_mde tests From 800d37f224f76b2218d164098a3fdec8d3797e6d Mon Sep 17 00:00:00 2001 From: igerber Date: Wed, 18 Mar 2026 15:59:42 -0400 Subject: [PATCH 08/14] Warn when staggered estimator settings don't match default power DGP Add _check_staggered_dgp_compat() that detects control_group="not_yet_treated", clean_control="strict", and anticipation>0 on staggered estimators using the automatic single-cohort DGP, and emits UserWarning with specific guidance on supplying data_generator_kwargs or a custom data_generator. 
- 5 warning/no-warning unit tests + 2 slow regression tests with matching DGP - Document limitation in REGISTRY.md PowerAnalysis section Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/power.py | 77 +++++++++++++++++++++++ docs/methodology/REGISTRY.md | 1 + tests/test_power.py | 115 +++++++++++++++++++++++++++++++++++ 3 files changed, 193 insertions(+) diff --git a/diff_diff/power.py b/diff_diff/power.py index 5aa8c93f..1cecc276 100644 --- a/diff_diff/power.py +++ b/diff_diff/power.py @@ -237,6 +237,79 @@ def _first(r: Any, *attrs: str, default: Any = _nan) -> Any: ) +# -- Staggered DGP compatibility check ---------------------------------------- + +_STAGGERED_ESTIMATORS = frozenset( + { + "CallawaySantAnna", + "SunAbraham", + "ImputationDiD", + "TwoStageDiD", + "StackedDiD", + "EfficientDiD", + } +) + + +def _check_staggered_dgp_compat( + estimator: Any, + data_generator_kwargs: Optional[Dict[str, Any]], +) -> None: + """Warn if a staggered estimator's settings don't match the default DGP.""" + name = type(estimator).__name__ + if name not in _STAGGERED_ESTIMATORS: + return + + dgp_overrides = data_generator_kwargs or {} + issues: List[str] = [] + + # Check control_group="not_yet_treated" (CS, SA) + cg = getattr(estimator, "control_group", "never_treated") + if cg == "not_yet_treated" and "cohort_periods" not in dgp_overrides: + issues.append( + f' - {name} has control_group="not_yet_treated" but the default ' + f"DGP generates a single treatment cohort with never-treated " + f"controls. Power may not reflect the intended not-yet-treated " + f"design.\n" + f" Fix: pass data_generator_kwargs=" + f'{{"cohort_periods": [2, 4], "never_treated_frac": 0.0}} ' + f"(or a custom data_generator)." + ) + + # Check anticipation > 0 (all staggered) + antic = getattr(estimator, "anticipation", 0) + if antic > 0: + issues.append( + f" - {name} has anticipation={antic} but the default DGP does " + f"not model anticipatory effects. 
The estimator will look for " + f"treatment effects {antic} period(s) before the DGP generates " + f"them, biasing power estimates.\n" + f" Fix: supply a custom data_generator that shifts the " + f"effect onset." + ) + + # Check clean_control on StackedDiD + if name == "StackedDiD": + cc = getattr(estimator, "clean_control", "not_yet_treated") + if cc == "strict" and "cohort_periods" not in dgp_overrides: + issues.append( + ' - StackedDiD has clean_control="strict" but the default ' + "single-cohort DGP makes strict controls equivalent to " + "never-treated controls.\n" + " Fix: pass data_generator_kwargs=" + '{"cohort_periods": [2, 4]} ' + "to test true strict clean-control behavior." + ) + + if issues: + msg = ( + f"Staggered power DGP mismatch for {name}. The default " + f"single-cohort DGP may not match the estimator " + f"configuration:\n" + "\n".join(issues) + ) + warnings.warn(msg, UserWarning, stacklevel=2) + + # -- Registry construction (deferred to avoid import-time cost) --------------- _ESTIMATOR_REGISTRY: Optional[Dict[str, _EstimatorProfile]] = None @@ -1351,6 +1424,10 @@ def simulate_power( data_gen_kwargs = data_generator_kwargs or {} est_kwargs = estimator_kwargs or {} + # Warn if staggered estimator settings don't match auto DGP + if profile is not None and not use_custom_dgp: + _check_staggered_dgp_compat(estimator, data_generator_kwargs) + # Determine effect sizes to test if effect_sizes is None: effect_sizes = [treatment_effect] diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 8eee4fac..fc7339cf 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -1712,6 +1712,7 @@ n = 2(t_{α/2} + t_{1-κ})² σ² / MDE² - Very small effects: may require infeasibly large samples - High ICC: dramatically reduces effective sample size - Unequal allocation: optimal is often 50-50 but depends on costs +- **Note:** The simulation-based power registry (`simulate_power`, `simulate_mde`, `simulate_sample_size`) uses 
a single-cohort staggered DGP by default. Estimators configured with `control_group="not_yet_treated"`, `clean_control="strict"`, or `anticipation>0` will receive a `UserWarning` because the default DGP does not match their identification strategy. Users must supply `data_generator_kwargs` (e.g., `cohort_periods=[2, 4]`, `never_treated_frac=0.0`) or a custom `data_generator` to match the estimator design. **Reference implementation(s):** - R: `pwr` package (general), `DeclareDesign` (simulation-based) diff --git a/tests/test_power.py b/tests/test_power.py index 4c20a44c..e48291eb 100644 --- a/tests/test_power.py +++ b/tests/test_power.py @@ -1,5 +1,7 @@ """Tests for power analysis module.""" +import warnings + import numpy as np import pandas as pd import pytest @@ -1190,6 +1192,119 @@ def _custom_extractor(result): assert isinstance(result, SimulationMDEResults) assert result.mde > 0 + # -- Staggered DGP compatibility warnings -- + + def test_staggered_dgp_warns_not_yet_treated(self): + """Auto DGP warns when CS has control_group='not_yet_treated'.""" + with pytest.warns(UserWarning, match="not_yet_treated"): + simulate_power( + CallawaySantAnna(control_group="not_yet_treated"), + n_simulations=3, + seed=42, + progress=False, + ) + + def test_staggered_dgp_warns_anticipation(self): + """Auto DGP warns when staggered estimator has anticipation > 0.""" + with pytest.warns(UserWarning, match="anticipation=1"): + simulate_power( + CallawaySantAnna(anticipation=1), + n_simulations=3, + seed=42, + progress=False, + ) + + def test_staggered_dgp_warns_strict_clean_control(self): + """Auto DGP warns when StackedDiD has clean_control='strict'.""" + with pytest.warns(UserWarning, match="strict"): + simulate_power( + StackedDiD(clean_control="strict"), + n_simulations=3, + seed=42, + progress=False, + ) + + def test_staggered_dgp_no_warn_custom_dgp_bypasses_check(self): + """Custom data_generator bypasses DGP compat check entirely.""" + from diff_diff.prep import 
generate_staggered_data + + def _custom_staggered(**kwargs): + # Adapt simulate_power's standard kwargs to generate_staggered_data + return generate_staggered_data( + n_units=kwargs["n_units"], + n_periods=kwargs["n_periods"], + treatment_effect=kwargs["treatment_effect"], + cohort_periods=[2, 4], + never_treated_frac=0.0, + noise_sd=kwargs["noise_sd"], + seed=kwargs["seed"], + ) + + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + simulate_power( + CallawaySantAnna(control_group="not_yet_treated"), + data_generator=_custom_staggered, + n_periods=6, + treatment_period=3, + estimator_kwargs=dict( + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + ), + n_simulations=3, + seed=42, + progress=False, + ) + + def test_staggered_dgp_no_warn_with_dgp_kwargs_override(self): + """data_generator_kwargs with cohort_periods suppresses warning.""" + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + result = simulate_power( + CallawaySantAnna(control_group="not_yet_treated"), + n_periods=6, + treatment_period=3, + data_generator_kwargs=dict(cohort_periods=[2, 4], never_treated_frac=0.0), + n_simulations=3, + seed=42, + progress=False, + ) + assert 0 <= result.power <= 1 + + @pytest.mark.slow + def test_cs_not_yet_treated_with_matching_dgp(self): + """CS with control_group='not_yet_treated' and multi-cohort DGP.""" + result = simulate_power( + CallawaySantAnna(control_group="not_yet_treated"), + n_units=60, + n_periods=6, + treatment_period=3, + data_generator_kwargs=dict(cohort_periods=[2, 4], never_treated_frac=0.0), + n_simulations=10, + seed=42, + progress=False, + ) + assert 0 <= result.power <= 1 + assert result.n_simulations > 0 + + @pytest.mark.slow + def test_stacked_did_strict_with_matching_dgp(self): + """StackedDiD with clean_control='strict' and multi-cohort DGP.""" + result = simulate_power( + StackedDiD(clean_control="strict", kappa_pre=1, kappa_post=1), + n_units=80, + 
n_periods=8, + treatment_period=4, + data_generator_kwargs=dict(cohort_periods=[3, 5]), + n_simulations=10, + seed=42, + progress=False, + ) + assert 0 <= result.power <= 1 + assert result.n_simulations > 0 + # --------------------------------------------------------------------------- # simulate_mde tests From 9e18a3ea05f874da0a9bc63f1f66fc0eb5f2f343 Mon Sep 17 00:00:00 2001 From: igerber Date: Wed, 18 Mar 2026 16:17:28 -0400 Subject: [PATCH 09/14] Fix SDID placebo check to account for data_generator_kwargs n_treated override Move the SyntheticDiD placebo-variance fail-fast after kwargs merge so that data_generator_kwargs={"n_treated": N} overrides are caught. Add tests for the override path across simulate_power, simulate_mde, and simulate_sample_size. Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/power.py | 23 +++++++++++++---------- tests/test_power.py | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 10 deletions(-) diff --git a/diff_diff/power.py b/diff_diff/power.py index 1cecc276..38d5cb40 100644 --- a/diff_diff/power.py +++ b/diff_diff/power.py @@ -1406,24 +1406,27 @@ def simulate_power( # When a custom data_generator is provided, bypass registry DGP use_custom_dgp = data_generator is not None - # SyntheticDiD placebo variance requires n_control > n_treated + data_gen_kwargs = data_generator_kwargs or {} + est_kwargs = estimator_kwargs or {} + + # SyntheticDiD placebo variance requires n_control > n_treated. + # Check after merging data_generator_kwargs so overrides of n_treated + # are accounted for. 
if estimator_name == "SyntheticDiD" and not use_custom_dgp: vm = getattr(estimator, "variance_method", "placebo") - n_treated = max(1, int(n_units * treatment_fraction)) - n_control = n_units - n_treated - if vm == "placebo" and n_control <= n_treated: + effective_n_treated = data_gen_kwargs.get( + "n_treated", max(1, int(n_units * treatment_fraction)) + ) + n_control = n_units - effective_n_treated + if vm == "placebo" and n_control <= effective_n_treated: raise ValueError( f"SyntheticDiD placebo variance requires more control than " f"treated units (got n_control={n_control}, " - f"n_treated={n_treated} from treatment_fraction=" - f"{treatment_fraction}). Either lower treatment_fraction " - f"so that n_control > n_treated, or use " + f"n_treated={effective_n_treated}). Either lower " + f"treatment_fraction so that n_control > n_treated, or use " f"SyntheticDiD(variance_method='bootstrap')." ) - data_gen_kwargs = data_generator_kwargs or {} - est_kwargs = estimator_kwargs or {} - # Warn if staggered estimator settings don't match auto DGP if profile is not None and not use_custom_dgp: _check_staggered_dgp_compat(estimator, data_generator_kwargs) diff --git a/tests/test_power.py b/tests/test_power.py index e48291eb..c22c33c9 100644 --- a/tests/test_power.py +++ b/tests/test_power.py @@ -1029,6 +1029,45 @@ def test_sdid_sample_size_rejects_high_fraction(self): progress=False, ) + def test_sdid_placebo_rejects_n_treated_override(self): + """SDID placebo raises when data_generator_kwargs overrides n_treated.""" + with pytest.raises(ValueError, match="placebo variance requires more control"): + simulate_power( + SyntheticDiD(), + n_units=50, + treatment_fraction=0.3, + data_generator_kwargs=dict(n_treated=30), + n_simulations=5, + seed=42, + progress=False, + ) + + def test_sdid_mde_rejects_n_treated_override(self): + """simulate_mde raises when kwargs override makes n_control <= n_treated.""" + with pytest.raises(ValueError, match="placebo variance requires more 
control"): + simulate_mde( + SyntheticDiD(), + n_units=50, + treatment_fraction=0.3, + data_generator_kwargs=dict(n_treated=30), + n_simulations=5, + seed=42, + progress=False, + ) + + def test_sdid_sample_size_rejects_n_treated_override(self): + """simulate_sample_size raises when kwargs override is infeasible.""" + with pytest.raises(ValueError, match="placebo variance requires more control"): + simulate_sample_size( + SyntheticDiD(), + treatment_fraction=0.3, + data_generator_kwargs=dict(n_treated=30), + n_range=(50, 100), + n_simulations=5, + seed=42, + progress=False, + ) + @pytest.mark.slow def test_sdid_mde(self): """simulate_mde works for SyntheticDiD with valid treatment_fraction.""" From 0e912bbed172b3958c4e87088e35665c9a0b17c5 Mon Sep 17 00:00:00 2001 From: igerber Date: Wed, 18 Mar 2026 16:53:49 -0400 Subject: [PATCH 10/14] =?UTF-8?q?Warn=20when=20TripleDifference=20power=20?= =?UTF-8?q?params=20don't=20match=20fixed=202=C3=972=C3=972=20DGP?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.6 (1M context) --- TODO.md | 1 + diff_diff/power.py | 61 +++++++++++++++++ docs/methodology/REGISTRY.md | 1 + tests/test_power.py | 128 +++++++++++++++++++++++++++++++++++ 4 files changed, 191 insertions(+) diff --git a/TODO.md b/TODO.md index fe779df5..01f6a33f 100644 --- a/TODO.md +++ b/TODO.md @@ -47,6 +47,7 @@ Deferred items from PR reviews that were not addressed before merge. 
| Bootstrap NaN-gating gap: manual SE/CI/p-value without non-finite filtering or SE<=0 guard | `imputation_bootstrap.py`, `two_stage_bootstrap.py` | #177 | Medium — migrate to `compute_effect_bootstrap_stats` from `bootstrap_utils.py` | | EfficientDiD: warn when cohort share is very small (< 2 units or < 1% of sample) — inverted in Omega*/EIF | `efficient_did_weights.py` | #192 | Low | | EfficientDiD: API docs / tutorial page for new public estimator | `docs/` | #192 | Medium | +| TripleDifference power: `generate_ddd_data` is a fixed 2×2×2 cross-sectional DGP — no multi-period or unbalanced-group support. Add a `generate_ddd_panel_data` for panel DDD power analysis. | `prep_dgp.py`, `power.py` | #208 | Low | #### Performance diff --git a/diff_diff/power.py b/diff_diff/power.py index 38d5cb40..e798e120 100644 --- a/diff_diff/power.py +++ b/diff_diff/power.py @@ -310,6 +310,57 @@ def _check_staggered_dgp_compat( warnings.warn(msg, UserWarning, stacklevel=2) +def _check_ddd_dgp_compat( + n_units: int, + n_periods: int, + treatment_fraction: float, + treatment_period: int, + data_generator_kwargs: Optional[Dict[str, Any]], +) -> None: + """Warn when simulation inputs don't match DDD's fixed 2×2×2 design.""" + overrides = data_generator_kwargs or {} + issues: List[str] = [] + + # DDD is a fixed 2-period factorial; n_periods and treatment_period are ignored + if n_periods != 2 and "n_per_cell" not in overrides: + issues.append( + f"n_periods={n_periods} is ignored (DDD uses a fixed " f"2-period design: pre/post)" + ) + if treatment_period != 1 and "n_per_cell" not in overrides: + issues.append( + f"treatment_period={treatment_period} is ignored (DDD " + f"always treats in the second period)" + ) + + # DDD's 2×2×2 factorial has inherent 50% treatment fraction + if treatment_fraction != 0.5 and "n_per_cell" not in overrides: + issues.append( + f"treatment_fraction={treatment_fraction} is ignored " + f"(DDD uses a balanced 2×2×2 factorial where 50% of " + f"groups are 
treated)" + ) + + # n_units rounding: n_per_cell = max(2, n_units // 8) + effective_n_per_cell = overrides.get("n_per_cell", max(2, n_units // 8)) + effective_n = effective_n_per_cell * 8 + if effective_n != n_units: + issues.append( + f"effective sample size is {effective_n} " + f"(n_per_cell={effective_n_per_cell} × 8 cells), " + f"not the requested n_units={n_units}" + ) + + if issues: + warnings.warn( + "TripleDifference uses a fixed 2×2×2 factorial DGP " + "(group × partition × time). " + + "; ".join(issues) + + ". Pass a custom data_generator for non-standard DDD designs.", + UserWarning, + stacklevel=2, + ) + + # -- Registry construction (deferred to avoid import-time cost) --------------- _ESTIMATOR_REGISTRY: Optional[Dict[str, _EstimatorProfile]] = None @@ -1431,6 +1482,16 @@ def simulate_power( if profile is not None and not use_custom_dgp: _check_staggered_dgp_compat(estimator, data_generator_kwargs) + # Warn if DDD design inputs are silently ignored + if estimator_name == "TripleDifference" and not use_custom_dgp: + _check_ddd_dgp_compat( + n_units, + n_periods, + treatment_fraction, + treatment_period, + data_generator_kwargs, + ) + # Determine effect sizes to test if effect_sizes is None: effect_sizes = [treatment_effect] diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index fc7339cf..6345731f 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -1713,6 +1713,7 @@ n = 2(t_{α/2} + t_{1-κ})² σ² / MDE² - High ICC: dramatically reduces effective sample size - Unequal allocation: optimal is often 50-50 but depends on costs - **Note:** The simulation-based power registry (`simulate_power`, `simulate_mde`, `simulate_sample_size`) uses a single-cohort staggered DGP by default. Estimators configured with `control_group="not_yet_treated"`, `clean_control="strict"`, or `anticipation>0` will receive a `UserWarning` because the default DGP does not match their identification strategy. 
Users must supply `data_generator_kwargs` (e.g., `cohort_periods=[2, 4]`, `never_treated_frac=0.0`) or a custom `data_generator` to match the estimator design. +- **Note:** The `TripleDifference` registry adapter uses `generate_ddd_data`, a fixed 2×2×2 factorial DGP (group × partition × time). The `n_periods`, `treatment_period`, and `treatment_fraction` parameters are ignored — DDD always simulates 2 periods with balanced groups. `n_units` is mapped to `n_per_cell = max(2, n_units // 8)` (effective total N = `n_per_cell × 8`), so non-multiples of 8 are rounded down and values below 16 are clamped to 16. A `UserWarning` is emitted when simulation inputs differ from the effective DDD design. **Reference implementation(s):** - R: `pwr` package (general), `DeclareDesign` (simulation-based) diff --git a/tests/test_power.py b/tests/test_power.py index c22c33c9..a91e78da 100644 --- a/tests/test_power.py +++ b/tests/test_power.py @@ -931,12 +931,140 @@ def test_triple_difference(self): result = simulate_power( TripleDifference(), n_units=80, + n_periods=2, + treatment_period=1, n_simulations=10, seed=42, progress=False, ) self._assert_valid_result(result, "TripleDifference") + def test_ddd_warns_ignored_params(self): + """TripleDifference warns when simulation params don't match DDD design.""" + with pytest.warns(UserWarning, match="n_periods=6 is ignored"): + simulate_power( + TripleDifference(), + n_units=80, + n_periods=6, + treatment_period=3, + treatment_fraction=0.3, + n_simulations=2, + seed=42, + progress=False, + ) + + def test_ddd_warns_nonaligned_n_units(self): + """TripleDifference warns when n_units doesn't map cleanly to 8 cells.""" + with pytest.warns(UserWarning, match="effective sample size is 64"): + simulate_power( + TripleDifference(), + n_units=65, + n_periods=2, + treatment_period=1, + n_simulations=2, + seed=42, + progress=False, + ) + + def test_ddd_small_n_units_warns(self): + """TripleDifference warns when n_units < 16 (clamped to 16).""" + with 
pytest.warns(UserWarning, match="effective sample size is 16"): + simulate_power( + TripleDifference(), + n_units=10, + n_periods=2, + treatment_period=1, + n_simulations=2, + seed=42, + progress=False, + ) + + def test_ddd_no_warn_aligned(self): + """No warning when n_units is a multiple of 8 and defaults match DDD.""" + with warnings.catch_warnings(): + warnings.simplefilter("error") + simulate_power( + TripleDifference(), + n_units=80, + n_periods=2, + treatment_period=1, + treatment_fraction=0.5, + n_simulations=2, + seed=42, + progress=False, + ) + + def test_ddd_no_warn_custom_dgp(self): + """Custom data_generator bypasses the DDD compat check.""" + + def custom_dgp(**kwargs): + from diff_diff.prep_dgp import generate_ddd_data + + return generate_ddd_data(n_per_cell=10, seed=kwargs.get("seed")) + + with warnings.catch_warnings(): + warnings.simplefilter("error") + simulate_power( + TripleDifference(), + n_units=65, + n_periods=6, + data_generator=custom_dgp, + estimator_kwargs=dict( + outcome="outcome", + group="group", + partition="partition", + time="time", + ), + n_simulations=2, + seed=42, + progress=False, + ) + + def test_ddd_no_warn_n_per_cell_override(self): + """data_generator_kwargs with n_per_cell suppresses DDD param warnings.""" + with warnings.catch_warnings(): + warnings.simplefilter("error") + simulate_power( + TripleDifference(), + n_units=80, + n_periods=6, + data_generator_kwargs=dict(n_per_cell=10), + n_simulations=2, + seed=42, + progress=False, + ) + + @pytest.mark.slow + def test_ddd_mde(self): + """simulate_mde works for TripleDifference.""" + result = simulate_mde( + TripleDifference(), + n_units=80, + n_periods=2, + treatment_period=1, + n_simulations=5, + effect_range=(0.5, 5.0), + seed=42, + progress=False, + ) + assert isinstance(result, SimulationMDEResults) + assert result.mde > 0 + + @pytest.mark.slow + def test_ddd_sample_size(self): + """simulate_sample_size works for TripleDifference.""" + result = simulate_sample_size( + 
TripleDifference(), + n_periods=2, + treatment_period=1, + n_simulations=5, + n_range=(64, 200), + seed=42, + progress=False, + ) + assert isinstance(result, SimulationSampleSizeResults) + assert result.required_n > 0 + @pytest.mark.slow def test_trop(self): result = simulate_power( From fa93beae35118378807aacbcecad8c3522477b21 Mon Sep 17 00:00:00 2001 From: igerber Date: Wed, 18 Mar 2026 17:23:30 -0400 Subject: [PATCH 11/14] Fix DDD n_per_cell suppression scope and add sample-size auto-bracket warning Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/power.py | 12 +++++++++--- docs/methodology/REGISTRY.md | 2 +- tests/test_power.py | 38 +++++++++++++++++++++++++----------- 3 files changed, 37 insertions(+), 15 deletions(-) diff --git a/diff_diff/power.py b/diff_diff/power.py index e798e120..a9871cf3 100644 --- a/diff_diff/power.py +++ b/diff_diff/power.py @@ -322,18 +322,18 @@ def _check_ddd_dgp_compat( issues: List[str] = [] # DDD is a fixed 2-period factorial; n_periods and treatment_period are ignored - if n_periods != 2 and "n_per_cell" not in overrides: + if n_periods != 2: issues.append( f"n_periods={n_periods} is ignored (DDD uses a fixed " f"2-period design: pre/post)" ) - if treatment_period != 1 and "n_per_cell" not in overrides: + if treatment_period != 1: issues.append( f"treatment_period={treatment_period} is ignored (DDD " f"always treats in the second period)" ) # DDD's 2×2×2 factorial has inherent 50% treatment fraction - if treatment_fraction != 0.5 and "n_per_cell" not in overrides: + if treatment_fraction != 0.5: issues.append( f"treatment_fraction={treatment_fraction} is ignored " f"(DDD uses a balanced 2×2×2 factorial where 50% of " @@ -2191,6 +2191,12 @@ def _power_at_n(n: int) -> float: lo = min_n power_lo = _power_at_n(lo) if power_lo >= power: + warnings.warn( + f"Power at registry floor n={lo} is {power_lo:.2f} >= " + f"target {power}. No smaller sample sizes were evaluated. 
" + f"Pass n_range=(lo, hi) to search below this floor.", + UserWarning, + ) return SimulationSampleSizeResults( required_n=lo, power_at_n=power_lo, diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 6345731f..76ceed89 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -1713,7 +1713,7 @@ n = 2(t_{α/2} + t_{1-κ})² σ² / MDE² - High ICC: dramatically reduces effective sample size - Unequal allocation: optimal is often 50-50 but depends on costs - **Note:** The simulation-based power registry (`simulate_power`, `simulate_mde`, `simulate_sample_size`) uses a single-cohort staggered DGP by default. Estimators configured with `control_group="not_yet_treated"`, `clean_control="strict"`, or `anticipation>0` will receive a `UserWarning` because the default DGP does not match their identification strategy. Users must supply `data_generator_kwargs` (e.g., `cohort_periods=[2, 4]`, `never_treated_frac=0.0`) or a custom `data_generator` to match the estimator design. -- **Note:** The `TripleDifference` registry adapter uses `generate_ddd_data`, a fixed 2×2×2 factorial DGP (group × partition × time). The `n_periods`, `treatment_period`, and `treatment_fraction` parameters are ignored — DDD always simulates 2 periods with balanced groups. `n_units` is mapped to `n_per_cell = max(2, n_units // 8)` (effective total N = `n_per_cell × 8`), so non-multiples of 8 are rounded down and values below 16 are clamped to 16. A `UserWarning` is emitted when simulation inputs differ from the effective DDD design. +- **Note:** The `TripleDifference` registry adapter uses `generate_ddd_data`, a fixed 2×2×2 factorial DGP (group × partition × time). The `n_periods`, `treatment_period`, and `treatment_fraction` parameters are ignored — DDD always simulates 2 periods with balanced groups. 
`n_units` is mapped to `n_per_cell = max(2, n_units // 8)` (effective total N = `n_per_cell × 8`), so non-multiples of 8 are rounded down and values below 16 are clamped to 16. A `UserWarning` is emitted when simulation inputs differ from the effective DDD design. Passing `n_per_cell` in `data_generator_kwargs` suppresses the effective-N rounding warning but not warnings for ignored parameters (`n_periods`, `treatment_period`, `treatment_fraction`). **Reference implementation(s):** - R: `pwr` package (general), `DeclareDesign` (simulation-based) diff --git a/tests/test_power.py b/tests/test_power.py index a91e78da..5008109f 100644 --- a/tests/test_power.py +++ b/tests/test_power.py @@ -1021,13 +1021,28 @@ def custom_dgp(**kwargs): ) def test_ddd_no_warn_n_per_cell_override(self): - """data_generator_kwargs with n_per_cell suppresses DDD param warnings.""" + """n_per_cell override suppresses rounding warning but not ignored-param warnings.""" + with pytest.warns(UserWarning, match="n_periods=6 is ignored"): + simulate_power( + TripleDifference(), + n_units=80, + n_periods=6, + treatment_period=1, + data_generator_kwargs=dict(n_per_cell=10), + n_simulations=2, + seed=42, + progress=False, + ) + + def test_ddd_n_per_cell_suppresses_rounding(self): + """n_per_cell override suppresses effective-N rounding warning.""" with warnings.catch_warnings(): warnings.simplefilter("error") simulate_power( TripleDifference(), n_units=80, - n_periods=6, + n_periods=2, + treatment_period=1, data_generator_kwargs=dict(n_per_cell=10), n_simulations=2, seed=42, @@ -1745,15 +1760,16 @@ def test_lo_already_sufficient_explicit(self): assert result.power_at_n >= 0.80 def test_lo_already_sufficient_auto(self): - """Auto-bracket returns min_n when effect overwhelmingly large.""" - result = simulate_sample_size( - DifferenceInDifferences(), - treatment_effect=50.0, - sigma=0.1, - n_simulations=50, - seed=42, - progress=False, - ) + """Auto-bracket warns and returns min_n when effect 
overwhelmingly large.""" + with pytest.warns(UserWarning, match="registry floor"): + result = simulate_sample_size( + DifferenceInDifferences(), + treatment_effect=50.0, + sigma=0.1, + n_simulations=50, + seed=42, + progress=False, + ) # min_n for DifferenceInDifferences is 20 assert result.required_n == 20 assert result.power_at_n >= 0.80 From 5d234c519972e4d8f2516ca38b4a1b177fd01a11 Mon Sep 17 00:00:00 2001 From: igerber Date: Fri, 20 Mar 2026 08:31:29 -0400 Subject: [PATCH 12/14] Expose effective_n_units in DDD power results and snap sample-size search to grid Addresses AI review P1/P2: TripleDifference results now report the effective sample size when DDD grid rounding occurs, and simulate_sample_size() bisects on multiples of 8 so required_n is always a realizable DDD design. Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/power.py | 95 +++++++++++++++++++++++++++++++----- docs/methodology/REGISTRY.md | 2 +- tests/test_power.py | 65 ++++++++++++++++++++++++ 3 files changed, 149 insertions(+), 13 deletions(-) diff --git a/diff_diff/power.py b/diff_diff/power.py index a9871cf3..bffc73b8 100644 --- a/diff_diff/power.py +++ b/diff_diff/power.py @@ -310,6 +310,18 @@ def _check_staggered_dgp_compat( warnings.warn(msg, UserWarning, stacklevel=2) +def _ddd_effective_n( + n_units: int, data_generator_kwargs: Optional[Dict[str, Any]] +) -> Optional[int]: + """Return effective DDD sample size, or None if no rounding occurred.""" + overrides = data_generator_kwargs or {} + if "n_per_cell" in overrides: + eff = overrides["n_per_cell"] * 8 + else: + eff = max(2, n_units // 8) * 8 + return eff if eff != n_units else None + + def _check_ddd_dgp_compat( n_units: int, n_periods: int, @@ -318,7 +330,6 @@ def _check_ddd_dgp_compat( data_generator_kwargs: Optional[Dict[str, Any]], ) -> None: """Warn when simulation inputs don't match DDD's fixed 2×2×2 design.""" - overrides = data_generator_kwargs or {} issues: List[str] = [] # DDD is a fixed 2-period factorial; 
n_periods and treatment_period are ignored @@ -341,12 +352,12 @@ def _check_ddd_dgp_compat( ) # n_units rounding: n_per_cell = max(2, n_units // 8) - effective_n_per_cell = overrides.get("n_per_cell", max(2, n_units // 8)) - effective_n = effective_n_per_cell * 8 - if effective_n != n_units: + eff_n = _ddd_effective_n(n_units, data_generator_kwargs) + if eff_n is not None: + eff_n_per_cell = eff_n // 8 issues.append( - f"effective sample size is {effective_n} " - f"(n_per_cell={effective_n_per_cell} × 8 cells), " + f"effective sample size is {eff_n} " + f"(n_per_cell={eff_n_per_cell} × 8 cells), " f"not the requested n_units={n_units}" ) @@ -648,6 +659,9 @@ class SimulationPowerResults: Significance level. estimator_name : str Name of the estimator used. + effective_n_units : int or None + Effective sample size when it differs from the requested ``n_units`` + (e.g., due to DDD grid rounding). ``None`` when no rounding occurred. """ power: float @@ -667,6 +681,7 @@ class SimulationPowerResults: bias: float = field(init=False) rmse: float = field(init=False) simulation_results: Optional[List[Dict[str, Any]]] = field(default=None, repr=False) + effective_n_units: Optional[int] = None def __post_init__(self): """Compute derived statistics.""" @@ -716,8 +731,12 @@ def summary(self) -> str: f"{'RMSE:':<35} {self.rmse:.4f}", f"{'Mean standard error:':<35} {self.mean_se:.4f}", f"{'Coverage (CI contains true):':<35} {self.coverage:.1%}", - "=" * 65, ] + if self.effective_n_units is not None: + lines.append( + f"{'Effective sample size:':<35} {self.effective_n_units}" f" (DDD grid-rounded)" + ) + lines.append("=" * 65) return "\n".join(lines) def print_summary(self) -> None: @@ -733,7 +752,7 @@ def to_dict(self) -> Dict[str, Any]: Dict[str, Any] Dictionary containing simulation power results. 
""" - return { + d: Dict[str, Any] = { "power": self.power, "power_se": self.power_se, "power_ci_lower": self.power_ci[0], @@ -749,7 +768,9 @@ def to_dict(self) -> Dict[str, Any]: "true_effect": self.true_effect, "alpha": self.alpha, "estimator_name": self.estimator_name, + "effective_n_units": self.effective_n_units, } + return d def to_dataframe(self) -> pd.DataFrame: """ @@ -1491,6 +1512,9 @@ def simulate_power( treatment_period, data_generator_kwargs, ) + effective_n_units = _ddd_effective_n(n_units, data_generator_kwargs) + else: + effective_n_units = None # Determine effect sizes to test if effect_sizes is None: @@ -1671,6 +1695,7 @@ def simulate_power( primary_rejections, ) ], + effective_n_units=effective_n_units, ) @@ -1704,6 +1729,9 @@ class SimulationMDEResults: Diagnostic trace of ``{effect_size, power}`` at each step. estimator_name : str Name of the estimator used. + effective_n_units : int or None + Effective sample size when it differs from the requested ``n_units`` + (e.g., due to DDD grid rounding). ``None`` when no rounding occurred. 
""" mde: float @@ -1715,6 +1743,7 @@ class SimulationMDEResults: n_steps: int search_path: List[Dict[str, float]] estimator_name: str + effective_n_units: Optional[int] = None def __repr__(self) -> str: return ( @@ -1734,6 +1763,12 @@ def summary(self) -> str: f"{'Significance level (alpha):':<35} {self.alpha:.3f}", f"{'Target power:':<35} {self.target_power:.1%}", f"{'Sample size (n_units):':<35} {self.n_units}", + ] + if self.effective_n_units is not None: + lines.append( + f"{'Effective sample size:':<35} {self.effective_n_units}" f" (DDD grid-rounded)" + ) + lines += [ f"{'Simulations per step:':<35} {self.n_simulations_per_step}", "", "-" * 65, @@ -1754,6 +1789,7 @@ def to_dict(self) -> Dict[str, Any]: "target_power": self.target_power, "alpha": self.alpha, "n_units": self.n_units, + "effective_n_units": self.effective_n_units, "n_simulations_per_step": self.n_simulations_per_step, "n_steps": self.n_steps, "estimator_name": self.estimator_name, @@ -1789,6 +1825,10 @@ class SimulationSampleSizeResults: Diagnostic trace of ``{n_units, power}`` at each step. estimator_name : str Name of the estimator used. + effective_n_units : int or None + Effective sample size when it differs from ``required_n`` + (e.g., due to DDD grid rounding). ``None`` when no rounding occurred + or when the search already snapped to the estimator's grid. 
""" required_n: int @@ -1800,6 +1840,7 @@ class SimulationSampleSizeResults: n_steps: int search_path: List[Dict[str, float]] estimator_name: str + effective_n_units: Optional[int] = None def __repr__(self) -> str: return ( @@ -1827,8 +1868,12 @@ def summary(self) -> str: f"{'Required sample size:':<35} {self.required_n}", f"{'Power at required N:':<35} {self.power_at_n:.1%}", f"{'Bisection steps:':<35} {self.n_steps}", - "=" * 65, ] + if self.effective_n_units is not None: + lines.append( + f"{'Effective sample size:':<35} {self.effective_n_units}" f" (DDD grid-rounded)" + ) + lines.append("=" * 65) return "\n".join(lines) def to_dict(self) -> Dict[str, Any]: @@ -1842,6 +1887,7 @@ def to_dict(self) -> Dict[str, Any]: "n_simulations_per_step": self.n_simulations_per_step, "n_steps": self.n_steps, "estimator_name": self.estimator_name, + "effective_n_units": self.effective_n_units, } def to_dataframe(self) -> pd.DataFrame: @@ -1930,6 +1976,12 @@ def simulate_mde( estimator_name = type(estimator).__name__ search_path: List[Dict[str, float]] = [] + # Compute effective N for DDD (N is fixed throughout MDE search) + if estimator_name == "TripleDifference" and data_generator is None: + effective_n_units = _ddd_effective_n(n_units, data_generator_kwargs) + else: + effective_n_units = None + common_kwargs: Dict[str, Any] = dict( estimator=estimator, n_units=n_units, @@ -1976,6 +2028,7 @@ def _power_at(effect: float) -> float: n_steps=len(search_path), search_path=search_path, estimator_name=estimator_name, + effective_n_units=effective_n_units, ) if power_hi < power: warnings.warn( @@ -2003,6 +2056,7 @@ def _power_at(effect: float) -> float: n_steps=len(search_path), search_path=search_path, estimator_name=estimator_name, + effective_n_units=effective_n_units, ) hi = sigma @@ -2046,6 +2100,7 @@ def _power_at(effect: float) -> float: n_steps=len(search_path), search_path=search_path, estimator_name=estimator_name, + effective_n_units=effective_n_units, ) @@ -2134,6 +2189,18 
@@ def simulate_sample_size( profile = registry.get(estimator_name) min_n = profile.min_n if profile is not None else 20 + # DDD grid snapping: bisection candidates must be multiples of 8 + is_ddd_grid = estimator_name == "TripleDifference" and data_generator is None + grid_step = 8 if is_ddd_grid else 1 + convergence_threshold = grid_step + 1 # 9 for DDD, 2 for others + + def _snap_n(n: int, direction: str = "down") -> int: + if grid_step == 1: + return n + if direction == "up": + return max(min_n, ((n + grid_step - 1) // grid_step) * grid_step) + return max(min_n, (n // grid_step) * grid_step) + common_kwargs: Dict[str, Any] = dict( estimator=estimator, n_periods=n_periods, @@ -2161,7 +2228,9 @@ def _power_at_n(n: int) -> float: # --- Bracket --- if n_range is not None: - lo, hi = n_range + lo, hi = _snap_n(n_range[0], "up"), _snap_n(n_range[1], "down") + if lo > hi: + lo = hi # collapsed bracket — evaluate single point power_lo = _power_at_n(lo) if power_lo >= power: warnings.warn( @@ -2225,9 +2294,11 @@ def _power_at_n(n: int) -> float: best_power = search_path[-1]["power"] if search_path else 0.0 for _ in range(max_steps): - if hi - lo <= 2: + if hi - lo <= convergence_threshold: + break + mid = _snap_n((lo + hi) // 2) + if mid <= lo or mid >= hi: break - mid = (lo + hi) // 2 pwr = _power_at_n(mid) if pwr >= power: diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 237cf390..1f0c3685 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -1718,7 +1718,7 @@ n = 2(t_{α/2} + t_{1-κ})² σ² / MDE² - High ICC: dramatically reduces effective sample size - Unequal allocation: optimal is often 50-50 but depends on costs - **Note:** The simulation-based power registry (`simulate_power`, `simulate_mde`, `simulate_sample_size`) uses a single-cohort staggered DGP by default. 
Estimators configured with `control_group="not_yet_treated"`, `clean_control="strict"`, or `anticipation>0` will receive a `UserWarning` because the default DGP does not match their identification strategy. Users must supply `data_generator_kwargs` (e.g., `cohort_periods=[2, 4]`, `never_treated_frac=0.0`) or a custom `data_generator` to match the estimator design. -- **Note:** The `TripleDifference` registry adapter uses `generate_ddd_data`, a fixed 2×2×2 factorial DGP (group × partition × time). The `n_periods`, `treatment_period`, and `treatment_fraction` parameters are ignored — DDD always simulates 2 periods with balanced groups. `n_units` is mapped to `n_per_cell = max(2, n_units // 8)` (effective total N = `n_per_cell × 8`), so non-multiples of 8 are rounded down and values below 16 are clamped to 16. A `UserWarning` is emitted when simulation inputs differ from the effective DDD design. Passing `n_per_cell` in `data_generator_kwargs` suppresses the effective-N rounding warning but not warnings for ignored parameters (`n_periods`, `treatment_period`, `treatment_fraction`). +- **Note:** The `TripleDifference` registry adapter uses `generate_ddd_data`, a fixed 2×2×2 factorial DGP (group × partition × time). The `n_periods`, `treatment_period`, and `treatment_fraction` parameters are ignored — DDD always simulates 2 periods with balanced groups. `n_units` is mapped to `n_per_cell = max(2, n_units // 8)` (effective total N = `n_per_cell × 8`), so non-multiples of 8 are rounded down and values below 16 are clamped to 16. A `UserWarning` is emitted when simulation inputs differ from the effective DDD design. When rounding occurs, all result objects (`SimulationPowerResults`, `SimulationMDEResults`, `SimulationSampleSizeResults`) set `effective_n_units` to the actual sample size used; it is `None` when no rounding occurred. `simulate_sample_size()` snaps bisection candidates to multiples of 8 so that `required_n` is always a realizable DDD sample size. 
Passing `n_per_cell` in `data_generator_kwargs` suppresses the effective-N rounding warning but not warnings for ignored parameters (`n_periods`, `treatment_period`, `treatment_fraction`). **Reference implementation(s):** - R: `pwr` package (general), `DeclareDesign` (simulation-based) diff --git a/tests/test_power.py b/tests/test_power.py index 5008109f..c2eaa4ef 100644 --- a/tests/test_power.py +++ b/tests/test_power.py @@ -1049,6 +1049,37 @@ def test_ddd_n_per_cell_suppresses_rounding(self): progress=False, ) + def test_ddd_power_effective_n_nonaligned(self): + """simulate_power reports effective_n_units when n_units isn't grid-aligned.""" + with pytest.warns(UserWarning, match="effective sample size is 64"): + result = simulate_power( + TripleDifference(), + n_units=65, + n_periods=2, + treatment_period=1, + n_simulations=2, + seed=42, + progress=False, + ) + assert result.effective_n_units == 64 + assert result.to_dict()["effective_n_units"] == 64 + assert "Effective sample size" in result.summary() + + def test_ddd_power_effective_n_aligned(self): + """simulate_power sets effective_n_units=None when n_units is grid-aligned.""" + result = simulate_power( + TripleDifference(), + n_units=80, + n_periods=2, + treatment_period=1, + n_simulations=2, + seed=42, + progress=False, + ) + assert result.effective_n_units is None + assert result.to_dict()["effective_n_units"] is None + assert "Effective sample size" not in result.summary() + @pytest.mark.slow def test_ddd_mde(self): """simulate_mde works for TripleDifference.""" @@ -1064,6 +1095,24 @@ def test_ddd_mde(self): ) assert isinstance(result, SimulationMDEResults) assert result.mde > 0 + assert result.effective_n_units is None + + @pytest.mark.slow + def test_ddd_mde_effective_n(self): + """simulate_mde reports effective_n_units for non-aligned n_units.""" + with pytest.warns(UserWarning, match="effective sample size is 64"): + result = simulate_mde( + TripleDifference(), + n_units=65, + n_periods=2, + 
treatment_period=1, + n_simulations=5, + effect_range=(0.5, 5.0), + seed=42, + progress=False, + ) + assert result.effective_n_units == 64 + assert result.to_dict()["effective_n_units"] == 64 @pytest.mark.slow def test_ddd_sample_size(self): @@ -1080,6 +1129,22 @@ def test_ddd_sample_size(self): assert isinstance(result, SimulationSampleSizeResults) assert result.required_n > 0 + @pytest.mark.slow + def test_ddd_sample_size_grid_aligned(self): + """simulate_sample_size returns grid-aligned required_n for DDD.""" + result = simulate_sample_size( + TripleDifference(), + n_periods=2, + treatment_period=1, + n_simulations=5, + n_range=(64, 200), + seed=42, + progress=False, + ) + assert ( + result.required_n % 8 == 0 + ), f"DDD required_n={result.required_n} is not a multiple of 8" + @pytest.mark.slow def test_trop(self): result = simulate_power( From 94fa3d4be34eba430a80d914d002e1a5b24c1fe5 Mon Sep 17 00:00:00 2001 From: igerber Date: Fri, 20 Mar 2026 09:19:03 -0400 Subject: [PATCH 13/14] Block DGP key collisions, search below registry floor, tighten staggered compat Address AI review P0/P1/P2 findings: - P0: Reject data_generator_kwargs that override registry-managed keys (treatment_effect, noise_sd, n_units, etc.) 
to prevent silent desync - P1: simulate_sample_size() now searches below the registry floor when the floor already achieves target power, finding the true minimum N - P1: _check_staggered_dgp_compat() checks len(set(cohort_periods)) >= 2 instead of key existence, so single-cohort overrides like [2] still warn - P2: Add regression tests for all three fixes - Fix best_power initialization in bisection after downward search Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/power.py | 133 +++++++++++++++++++++++++--------- docs/methodology/REGISTRY.md | 1 + tests/test_power.py | 137 ++++++++++++++++++++++++++++++++++- 3 files changed, 234 insertions(+), 37 deletions(-) diff --git a/diff_diff/power.py b/diff_diff/power.py index bffc73b8..5da4b3b6 100644 --- a/diff_diff/power.py +++ b/diff_diff/power.py @@ -237,6 +237,20 @@ def _first(r: Any, *attrs: str, default: Any = _nan) -> Any: ) +# Keys derived from simulate_power() public params — overriding these +# via data_generator_kwargs would desync the DGP from the result object. 
+_PROTECTED_DGP_KEYS = frozenset( + { + "treatment_effect", # → true_effect in results / MDE search variable + "noise_sd", # → sigma param + "n_units", # → sample-size search variable + "n_periods", # → n_periods param + "treatment_fraction", # → treatment_fraction param + "treatment_period", # → treatment_period param + } +) + + # -- Staggered DGP compatibility check ---------------------------------------- _STAGGERED_ESTIMATORS = frozenset( @@ -261,11 +275,13 @@ def _check_staggered_dgp_compat( return dgp_overrides = data_generator_kwargs or {} + cohort_periods = dgp_overrides.get("cohort_periods") + has_multi_cohort = cohort_periods is not None and len(set(cohort_periods)) >= 2 issues: List[str] = [] # Check control_group="not_yet_treated" (CS, SA) cg = getattr(estimator, "control_group", "never_treated") - if cg == "not_yet_treated" and "cohort_periods" not in dgp_overrides: + if cg == "not_yet_treated" and not has_multi_cohort: issues.append( f' - {name} has control_group="not_yet_treated" but the default ' f"DGP generates a single treatment cohort with never-treated " @@ -291,7 +307,7 @@ def _check_staggered_dgp_compat( # Check clean_control on StackedDiD if name == "StackedDiD": cc = getattr(estimator, "clean_control", "not_yet_treated") - if cc == "strict" and "cohort_periods" not in dgp_overrides: + if cc == "strict" and not has_multi_cohort: issues.append( ' - StackedDiD has clean_control="strict" but the default ' "single-cohort DGP makes strict controls equivalent to " @@ -1503,6 +1519,28 @@ def simulate_power( if profile is not None and not use_custom_dgp: _check_staggered_dgp_compat(estimator, data_generator_kwargs) + # Block registry-path collisions on search-critical keys + if profile is not None and not use_custom_dgp and data_gen_kwargs: + sample_dgp_keys = set( + profile.dgp_kwargs_builder( + n_units=n_units, + n_periods=n_periods, + treatment_effect=treatment_effect, + treatment_fraction=treatment_fraction, + treatment_period=treatment_period, + 
sigma=sigma, + ).keys() + ) + collisions = _PROTECTED_DGP_KEYS & set(data_gen_kwargs) & sample_dgp_keys + if collisions: + raise ValueError( + f"data_generator_kwargs contains keys that conflict with " + f"registry-managed simulation inputs: {sorted(collisions)}. " + f"These are controlled by simulate_power() parameters directly. " + f"Use the corresponding function parameters instead, or pass a " + f"custom data_generator to override the DGP entirely." + ) + # Warn if DDD design inputs are silently ignored if estimator_name == "TripleDifference" and not use_custom_dgp: _check_ddd_dgp_compat( @@ -2194,12 +2232,13 @@ def simulate_sample_size( grid_step = 8 if is_ddd_grid else 1 convergence_threshold = grid_step + 1 # 9 for DDD, 2 for others - def _snap_n(n: int, direction: str = "down") -> int: + def _snap_n(n: int, direction: str = "down", floor: Optional[int] = None) -> int: if grid_step == 1: return n + actual_floor = floor if floor is not None else min_n if direction == "up": - return max(min_n, ((n + grid_step - 1) // grid_step) * grid_step) - return max(min_n, (n // grid_step) * grid_step) + return max(actual_floor, ((n + grid_step - 1) // grid_step) * grid_step) + return max(actual_floor, (n // grid_step) * grid_step) common_kwargs: Dict[str, Any] = dict( estimator=estimator, @@ -2260,38 +2299,66 @@ def _power_at_n(n: int) -> float: lo = min_n power_lo = _power_at_n(lo) if power_lo >= power: - warnings.warn( - f"Power at registry floor n={lo} is {power_lo:.2f} >= " - f"target {power}. No smaller sample sizes were evaluated. 
" - f"Pass n_range=(lo, hi) to search below this floor.", - UserWarning, - ) - return SimulationSampleSizeResults( - required_n=lo, - power_at_n=power_lo, - target_power=power, - alpha=alpha, - effect_size=treatment_effect, - n_simulations_per_step=n_simulations, - n_steps=len(search_path), - search_path=search_path, - estimator_name=estimator_name, - ) - hi = max(100, 2 * min_n) - for _ in range(10): - if _power_at_n(hi) >= power: - break - hi *= 2 + # Floor achieves target — search downward for true minimum + hi = lo + abs_min = 16 if is_ddd_grid else 4 + found_lower = False + probe = _snap_n(max(abs_min, lo // 2), floor=abs_min) + for _ in range(8): + if probe >= hi or probe < abs_min: + break + pwr = _power_at_n(probe) + if pwr < power: + lo = probe + found_lower = True + break + hi = probe + probe = _snap_n(max(abs_min, probe // 2), floor=abs_min) + if not found_lower: + # Even smallest viable N achieves target — return best found + best = min( + (s for s in search_path if s["power"] >= power), + key=lambda s: s["n_units"], + ) + warnings.warn( + f"Power at n={int(best['n_units'])} is " + f"{best['power']:.2f} >= target {power}. Could not " + f"find a smaller N below target power. Pass " + f"n_range=(lo, hi) to refine.", + UserWarning, + ) + return SimulationSampleSizeResults( + required_n=int(best["n_units"]), + power_at_n=best["power"], + target_power=power, + alpha=alpha, + effect_size=treatment_effect, + n_simulations_per_step=n_simulations, + n_steps=len(search_path), + search_path=search_path, + estimator_name=estimator_name, + ) + # Fall through to bisection with lo..hi bracket else: - warnings.warn( - f"Could not bracket required N (power at n={hi} still below " - f"{power}). Returning best upper bound.", - UserWarning, - ) + hi = max(100, 2 * min_n) + for _ in range(10): + if _power_at_n(hi) >= power: + break + hi *= 2 + else: + warnings.warn( + f"Could not bracket required N (power at n={hi} still " + f"below {power}). 
Returning best upper bound.", + UserWarning, + ) # --- Bisect on integer n_units --- best_n = hi - best_power = search_path[-1]["power"] if search_path else 0.0 + # Look up power at hi (search_path[-1] may not be hi after downward search) + best_power = next( + (s["power"] for s in reversed(search_path) if int(s["n_units"]) == hi), + search_path[-1]["power"] if search_path else 0.0, + ) for _ in range(max_steps): if hi - lo <= convergence_threshold: diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 1f0c3685..d309d595 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -1717,6 +1717,7 @@ n = 2(t_{α/2} + t_{1-κ})² σ² / MDE² - Very small effects: may require infeasibly large samples - High ICC: dramatically reduces effective sample size - Unequal allocation: optimal is often 50-50 but depends on costs +- **Note:** `data_generator_kwargs` keys that overlap with registry-managed simulation inputs (`treatment_effect`, `noise_sd`, `n_units`, `n_periods`, `treatment_fraction`, `treatment_period`) are rejected with `ValueError` to prevent silent desync between the DGP and result metadata. Use the corresponding `simulate_power()` parameters directly, or pass a custom `data_generator` to override the DGP entirely. - **Note:** The simulation-based power registry (`simulate_power`, `simulate_mde`, `simulate_sample_size`) uses a single-cohort staggered DGP by default. Estimators configured with `control_group="not_yet_treated"`, `clean_control="strict"`, or `anticipation>0` will receive a `UserWarning` because the default DGP does not match their identification strategy. Users must supply `data_generator_kwargs` (e.g., `cohort_periods=[2, 4]`, `never_treated_frac=0.0`) or a custom `data_generator` to match the estimator design. - **Note:** The `TripleDifference` registry adapter uses `generate_ddd_data`, a fixed 2×2×2 factorial DGP (group × partition × time). 
The `n_periods`, `treatment_period`, and `treatment_fraction` parameters are ignored — DDD always simulates 2 periods with balanced groups. `n_units` is mapped to `n_per_cell = max(2, n_units // 8)` (effective total N = `n_per_cell × 8`), so non-multiples of 8 are rounded down and values below 16 are clamped to 16. A `UserWarning` is emitted when simulation inputs differ from the effective DDD design. When rounding occurs, all result objects (`SimulationPowerResults`, `SimulationMDEResults`, `SimulationSampleSizeResults`) set `effective_n_units` to the actual sample size used; it is `None` when no rounding occurred. `simulate_sample_size()` snaps bisection candidates to multiples of 8 so that `required_n` is always a realizable DDD sample size. Passing `n_per_cell` in `data_generator_kwargs` suppresses the effective-N rounding warning but not warnings for ignored parameters (`n_periods`, `treatment_period`, `treatment_fraction`). diff --git a/tests/test_power.py b/tests/test_power.py index c2eaa4ef..d1cb4ea5 100644 --- a/tests/test_power.py +++ b/tests/test_power.py @@ -1825,8 +1825,8 @@ def test_lo_already_sufficient_explicit(self): assert result.power_at_n >= 0.80 def test_lo_already_sufficient_auto(self): - """Auto-bracket warns and returns min_n when effect overwhelmingly large.""" - with pytest.warns(UserWarning, match="registry floor"): + """Auto-bracket searches downward when floor already achieves power.""" + with pytest.warns(UserWarning, match="Could not find a smaller N"): result = simulate_sample_size( DifferenceInDifferences(), treatment_effect=50.0, @@ -1835,6 +1835,135 @@ def test_lo_already_sufficient_auto(self): seed=42, progress=False, ) - # min_n for DifferenceInDifferences is 20 - assert result.required_n == 20 + # Effect is so large even abs_min=4 achieves power + assert result.required_n <= 20 assert result.power_at_n >= 0.80 + + @pytest.mark.slow + def test_sample_size_searches_below_floor(self): + """Large effect → downward search finds 
required_n below registry floor.""" + result = simulate_sample_size( + DifferenceInDifferences(), + treatment_effect=50.0, + sigma=1.0, + n_simulations=5, + seed=42, + progress=False, + ) + # min_n for DiD is 20; huge effect should find smaller N + assert result.required_n < 20 + + +class TestDGPKeyCollisions: + """Verify registry-path DGP key collision detection.""" + + def test_reject_treatment_effect_collision(self): + """treatment_effect in data_generator_kwargs raises ValueError.""" + with pytest.raises(ValueError, match="conflict"): + simulate_power( + DifferenceInDifferences(), + n_simulations=2, + seed=42, + progress=False, + data_generator_kwargs={"treatment_effect": 99}, + ) + + def test_reject_noise_sd_collision(self): + """noise_sd in data_generator_kwargs raises ValueError.""" + with pytest.raises(ValueError, match="conflict"): + simulate_power( + DifferenceInDifferences(), + n_simulations=2, + seed=42, + progress=False, + data_generator_kwargs={"noise_sd": 5.0}, + ) + + def test_allow_cohort_periods_override(self): + """cohort_periods is not a protected key — no collision.""" + # Should not raise + simulate_power( + CallawaySantAnna(), + n_units=60, + n_periods=6, + treatment_period=3, + n_simulations=2, + seed=42, + progress=False, + data_generator_kwargs={"cohort_periods": [2, 4]}, + ) + + def test_allow_n_per_cell_override(self): + """n_per_cell is not a protected key — no collision for DDD.""" + # Should not raise (n_per_cell is in DDD builder output but not + # in _PROTECTED_DGP_KEYS, so 3-way intersection is empty) + simulate_power( + TripleDifference(), + n_simulations=2, + seed=42, + progress=False, + data_generator_kwargs={"n_per_cell": 15}, + ) + + def test_collision_skipped_for_custom_dgp(self): + """Custom data_generator bypasses collision check entirely.""" + # unit_fe_sd is accepted by generate_did_data; collision check is + # skipped because a custom data_generator is provided. 
+ simulate_power( + DifferenceInDifferences(), + n_simulations=2, + seed=42, + progress=False, + data_generator=generate_did_data, + data_generator_kwargs={"unit_fe_sd": 3.0}, + ) + + +class TestStaggeredSingleCohort: + """Verify staggered DGP compat check handles single-cohort overrides.""" + + def test_staggered_single_cohort_still_warns(self): + """CS with cohort_periods=[2] still warns — single cohort.""" + with pytest.warns(UserWarning, match="DGP mismatch"): + simulate_power( + CallawaySantAnna(control_group="not_yet_treated"), + n_units=60, + n_periods=6, + treatment_period=3, + n_simulations=2, + seed=42, + progress=False, + data_generator_kwargs={"cohort_periods": [2]}, + ) + + def test_staggered_multi_cohort_no_warn(self): + """CS with cohort_periods=[2, 4] does NOT warn.""" + with warnings.catch_warnings(): + warnings.simplefilter("error") + simulate_power( + CallawaySantAnna(control_group="not_yet_treated"), + n_units=60, + n_periods=6, + treatment_period=3, + n_simulations=2, + seed=42, + progress=False, + data_generator_kwargs={ + "cohort_periods": [2, 4], + "never_treated_frac": 0.0, + }, + ) + + def test_stacked_strict_single_cohort_warns(self): + """StackedDiD clean_control='strict' with cohort_periods=[2] warns.""" + with pytest.warns(UserWarning, match="DGP mismatch"): + simulate_power( + StackedDiD(clean_control="strict"), + n_units=60, + n_periods=6, + treatment_period=3, + n_simulations=2, + seed=42, + progress=False, + data_generator_kwargs={"cohort_periods": [2]}, + ) From b83d531ee49e0da9f1e0ba2d90ba5ebd965849f9 Mon Sep 17 00:00:00 2001 From: igerber Date: Fri, 20 Mar 2026 10:56:58 -0400 Subject: [PATCH 14/14] Guard n_pre/n_post and DDD n_per_cell derived-key desync in power analysis Add n_pre and n_post to _PROTECTED_DGP_KEYS to prevent SyntheticDiD/TROP DGP desync when overriding timing-derived keys. Reject n_per_cell in simulate_sample_size() for TripleDifference since it freezes effective sample size across bisection. 
Add 4 regression tests and update REGISTRY.md. Co-Authored-By: Claude Opus 4.6 (1M context) --- diff_diff/power.py | 11 ++++++++ docs/methodology/REGISTRY.md | 3 ++- tests/test_power.py | 50 ++++++++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 1 deletion(-) diff --git a/diff_diff/power.py b/diff_diff/power.py index 5da4b3b6..1bd1d7bd 100644 --- a/diff_diff/power.py +++ b/diff_diff/power.py @@ -247,6 +247,8 @@ def _first(r: Any, *attrs: str, default: Any = _nan) -> Any: "n_periods", # → n_periods param "treatment_fraction", # → treatment_fraction param "treatment_period", # → treatment_period param + "n_pre", # → derived from treatment_period in factor-model DGPs + "n_post", # → derived from n_periods - treatment_period in factor-model DGPs } ) @@ -2232,6 +2234,15 @@ def simulate_sample_size( grid_step = 8 if is_ddd_grid else 1 convergence_threshold = grid_step + 1 # 9 for DDD, 2 for others + if is_ddd_grid and data_generator_kwargs and "n_per_cell" in data_generator_kwargs: + raise ValueError( + "data_generator_kwargs contains 'n_per_cell', which conflicts with " + "the sample-size search in simulate_sample_size(). For " + "TripleDifference, n_per_cell is derived from n_units (the search " + "variable). Use simulate_power() with a fixed n_per_cell override " + "instead, or pass a custom data_generator." 
+ ) + def _snap_n(n: int, direction: str = "down", floor: Optional[int] = None) -> int: if grid_step == 1: return n diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index d309d595..5c47173f 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -1717,7 +1717,8 @@ n = 2(t_{α/2} + t_{1-κ})² σ² / MDE² - Very small effects: may require infeasibly large samples - High ICC: dramatically reduces effective sample size - Unequal allocation: optimal is often 50-50 but depends on costs -- **Note:** `data_generator_kwargs` keys that overlap with registry-managed simulation inputs (`treatment_effect`, `noise_sd`, `n_units`, `n_periods`, `treatment_fraction`, `treatment_period`) are rejected with `ValueError` to prevent silent desync between the DGP and result metadata. Use the corresponding `simulate_power()` parameters directly, or pass a custom `data_generator` to override the DGP entirely. +- **Note:** `data_generator_kwargs` keys that overlap with registry-managed simulation inputs (`treatment_effect`, `noise_sd`, `n_units`, `n_periods`, `treatment_fraction`, `treatment_period`, `n_pre`, `n_post`) are rejected with `ValueError` to prevent silent desync between the DGP and result metadata. `n_pre` and `n_post` are derived from `treatment_period` and `n_periods` in factor-model DGPs (SyntheticDiD, TROP); the 3-way intersection check naturally scopes the rejection to those estimators only. Use the corresponding `simulate_power()` parameters directly, or pass a custom `data_generator` to override the DGP entirely. +- **Note:** `simulate_sample_size()` rejects `n_per_cell` in `data_generator_kwargs` for `TripleDifference` because `n_per_cell` is derived from `n_units` (the search variable). A fixed override would freeze the effective sample size across bisection iterations, making the search degenerate. Use `simulate_power()` with a fixed `n_per_cell` override instead, or pass a custom `data_generator`. 
- **Note:** The simulation-based power registry (`simulate_power`, `simulate_mde`, `simulate_sample_size`) uses a single-cohort staggered DGP by default. Estimators configured with `control_group="not_yet_treated"`, `clean_control="strict"`, or `anticipation>0` will receive a `UserWarning` because the default DGP does not match their identification strategy. Users must supply `data_generator_kwargs` (e.g., `cohort_periods=[2, 4]`, `never_treated_frac=0.0`) or a custom `data_generator` to match the estimator design. - **Note:** The `TripleDifference` registry adapter uses `generate_ddd_data`, a fixed 2×2×2 factorial DGP (group × partition × time). The `n_periods`, `treatment_period`, and `treatment_fraction` parameters are ignored — DDD always simulates 2 periods with balanced groups. `n_units` is mapped to `n_per_cell = max(2, n_units // 8)` (effective total N = `n_per_cell × 8`), so non-multiples of 8 are rounded down and values below 16 are clamped to 16. A `UserWarning` is emitted when simulation inputs differ from the effective DDD design. When rounding occurs, all result objects (`SimulationPowerResults`, `SimulationMDEResults`, `SimulationSampleSizeResults`) set `effective_n_units` to the actual sample size used; it is `None` when no rounding occurred. `simulate_sample_size()` snaps bisection candidates to multiples of 8 so that `required_n` is always a realizable DDD sample size. Passing `n_per_cell` in `data_generator_kwargs` suppresses the effective-N rounding warning but not warnings for ignored parameters (`n_periods`, `treatment_period`, `treatment_fraction`). 
diff --git a/tests/test_power.py b/tests/test_power.py index d1cb4ea5..a39137fd 100644 --- a/tests/test_power.py +++ b/tests/test_power.py @@ -1853,6 +1853,18 @@ def test_sample_size_searches_below_floor(self): # min_n for DiD is 20; huge effect should find smaller N assert result.required_n < 20 + def test_reject_n_per_cell_in_ddd_sample_size(self): + """n_per_cell override in simulate_sample_size raises for DDD.""" + with pytest.raises(ValueError, match="n_per_cell"): + simulate_sample_size( + TripleDifference(), + treatment_effect=5.0, + n_simulations=2, + seed=42, + progress=False, + data_generator_kwargs={"n_per_cell": 10}, + ) + class TestDGPKeyCollisions: """Verify registry-path DGP key collision detection.""" @@ -1905,6 +1917,44 @@ def test_allow_n_per_cell_override(self): data_generator_kwargs={"n_per_cell": 15}, ) + def test_reject_n_pre_collision_sdid(self): + """n_pre in data_generator_kwargs raises for SyntheticDiD (factor DGP).""" + with pytest.raises(ValueError, match="conflict"): + simulate_power( + SyntheticDiD(variance_method="bootstrap"), + n_simulations=2, + seed=42, + progress=False, + data_generator_kwargs={"n_pre": 1}, + ) + + def test_reject_n_post_collision_trop(self): + """n_post in data_generator_kwargs raises for TROP (factor DGP).""" + with pytest.raises(ValueError, match="conflict"): + simulate_power( + TROP(), + n_simulations=2, + seed=42, + progress=False, + data_generator_kwargs={"n_post": 5}, + ) + + def test_n_pre_not_rejected_for_basic_did(self): + """n_pre passes collision guard for basic DiD (not a derived key there). + + Basic DGP doesn't return n_pre, so 3-way intersection is empty. + generate_did_data rejects n_pre (not a valid param), proving the + collision guard did NOT fire (would have raised ValueError("conflict")). 
+ """ + with pytest.raises(TypeError, match="n_pre"): + simulate_power( + DifferenceInDifferences(), + n_simulations=2, + seed=42, + progress=False, + data_generator_kwargs={"n_pre": 1}, + ) + def test_collision_skipped_for_custom_dgp(self): """Custom data_generator bypasses collision check entirely.""" # unit_fe_sd is accepted by generate_did_data; collision check is