diff --git a/TODO.md b/TODO.md index fe779df..01f6a33 100644 --- a/TODO.md +++ b/TODO.md @@ -47,6 +47,7 @@ Deferred items from PR reviews that were not addressed before merge. | Bootstrap NaN-gating gap: manual SE/CI/p-value without non-finite filtering or SE<=0 guard | `imputation_bootstrap.py`, `two_stage_bootstrap.py` | #177 | Medium — migrate to `compute_effect_bootstrap_stats` from `bootstrap_utils.py` | | EfficientDiD: warn when cohort share is very small (< 2 units or < 1% of sample) — inverted in Omega*/EIF | `efficient_did_weights.py` | #192 | Low | | EfficientDiD: API docs / tutorial page for new public estimator | `docs/` | #192 | Medium | +| TripleDifference power: `generate_ddd_data` is a fixed 2×2×2 cross-sectional DGP — no multi-period or unbalanced-group support. Add a `generate_ddd_panel_data` for panel DDD power analysis. | `prep_dgp.py`, `power.py` | #208 | Low | #### Performance diff --git a/diff_diff/__init__.py b/diff_diff/__init__.py index 02bdf77..126ad23 100644 --- a/diff_diff/__init__.py +++ b/diff_diff/__init__.py @@ -53,11 +53,15 @@ from diff_diff.power import ( PowerAnalysis, PowerResults, + SimulationMDEResults, SimulationPowerResults, + SimulationSampleSizeResults, compute_mde, compute_power, compute_sample_size, + simulate_mde, simulate_power, + simulate_sample_size, ) from diff_diff.pretrends import ( PreTrendsPower, @@ -291,11 +295,15 @@ # Power analysis "PowerAnalysis", "PowerResults", + "SimulationMDEResults", "SimulationPowerResults", + "SimulationSampleSizeResults", "compute_mde", "compute_power", "compute_sample_size", + "simulate_mde", "simulate_power", + "simulate_sample_size", "plot_power_curve", # Pre-trends power analysis "PreTrendsPower", diff --git a/diff_diff/power.py b/diff_diff/power.py index b8aa10d..a9871cf 100644 --- a/diff_diff/power.py +++ b/diff_diff/power.py @@ -33,6 +33,445 @@ MAX_SAMPLE_SIZE = 2**31 - 1 +# --------------------------------------------------------------------------- +# Estimator registry — maps estimator class names to DGP/fit/extract profiles +# --------------------------------------------------------------------------- + + +@dataclass +class _EstimatorProfile: + """Internal profile describing how to run power simulations for an estimator.""" + + default_dgp: Callable + dgp_kwargs_builder: Callable + fit_kwargs_builder: Callable + result_extractor: Callable + min_n: int = 20 + + +# -- DGP kwargs adapters ----------------------------------------------------- + + +def _basic_dgp_kwargs( + n_units: int, + n_periods: int, + treatment_effect: float, + treatment_fraction: float, + treatment_period: int, + sigma: float, +) -> Dict[str, Any]: + return dict( + n_units=n_units, + n_periods=n_periods, + treatment_effect=treatment_effect, + treatment_fraction=treatment_fraction, + treatment_period=treatment_period, + noise_sd=sigma, + ) + + +def _staggered_dgp_kwargs( + n_units: int, + n_periods: int, + treatment_effect: float, + treatment_fraction: float, + treatment_period: int, + sigma: float, +) -> Dict[str, Any]: + return dict( + n_units=n_units, + n_periods=n_periods, + treatment_effect=treatment_effect, + never_treated_frac=1 - treatment_fraction, + cohort_periods=[treatment_period], + dynamic_effects=False, + noise_sd=sigma, + ) + + +def _factor_dgp_kwargs( + n_units: int, + n_periods: int, + treatment_effect: float, + treatment_fraction: float, + treatment_period: int, + sigma: float, +) -> Dict[str, Any]: + n_pre = treatment_period + n_post = n_periods - treatment_period + return dict( + n_units=n_units, + n_pre=n_pre, + n_post=n_post, + n_treated=max(1, int(n_units * treatment_fraction)), + treatment_effect=treatment_effect, + noise_sd=sigma, + ) + + +def _ddd_dgp_kwargs( + n_units: int, + n_periods: int, + treatment_effect: float, + treatment_fraction: float, + treatment_period: int, + sigma: float, +) -> Dict[str, Any]: + return dict( + n_per_cell=max(2, n_units // 8), + treatment_effect=treatment_effect, + noise_sd=sigma, + ) + + +# -- Fit kwargs builders ------------------------------------------------------ + + +def _basic_fit_kwargs( + data: pd.DataFrame, + n_units: int, + n_periods: int, + treatment_period: int, +) -> Dict[str, Any]: + return dict(outcome="outcome", treatment="treated", time="post") + + +def _twfe_fit_kwargs( + data: pd.DataFrame, + n_units: int, + n_periods: int, + treatment_period: int, +) -> Dict[str, Any]: + return dict(outcome="outcome", treatment="treated", time="post", unit="unit") + + +def _multiperiod_fit_kwargs( + data: pd.DataFrame, + n_units: int, + n_periods: int, + treatment_period: int, +) -> Dict[str, Any]: + return dict( + outcome="outcome", + treatment="treated", + time="period", + post_periods=list(range(treatment_period, n_periods)), + ) + + +def _staggered_fit_kwargs( + data: pd.DataFrame, + n_units: int, + n_periods: int, + treatment_period: int, +) -> Dict[str, Any]: + return dict(outcome="outcome", unit="unit", time="period", first_treat="first_treat") + + +def _ddd_fit_kwargs( + data: pd.DataFrame, + n_units: int, + n_periods: int, + treatment_period: int, +) -> Dict[str, Any]: + return dict(outcome="outcome", group="group", partition="partition", time="time") + + +def _trop_fit_kwargs( + data: pd.DataFrame, + n_units: int, + n_periods: int, + treatment_period: int, +) -> Dict[str, Any]: + return dict(outcome="outcome", treatment="treated", unit="unit", time="period") + + +def _sdid_fit_kwargs( + data: pd.DataFrame, + n_units: int, + n_periods: int, + treatment_period: int, +) -> Dict[str, Any]: + periods = sorted(data["period"].unique()) + post_periods = [p for p in periods if p >= treatment_period] + return dict( + outcome="outcome", + treatment="treat", + unit="unit", + time="period", + post_periods=post_periods, + ) + + +# -- Result extractors -------------------------------------------------------- + + +def _extract_simple(result: Any) -> Tuple[float, float, float, Tuple[float, float]]: + return (result.att, result.se, result.p_value, result.conf_int) + + +def _extract_multiperiod( + result: Any, +) -> Tuple[float, float, float, Tuple[float, float]]: + return (result.avg_att, result.avg_se, result.avg_p_value, result.avg_conf_int) + + +def _extract_staggered( + result: Any, +) -> Tuple[float, float, float, Tuple[float, float]]: + _nan = float("nan") + _nan_ci = (_nan, _nan) + + def _first(r: Any, *attrs: str, default: Any = _nan) -> Any: + for a in attrs: + v = getattr(r, a, None) + if v is not None: + return v + return default + + return ( + result.overall_att, + _first(result, "overall_se", "overall_att_se"), + _first(result, "overall_p_value", "overall_att_p_value"), + _first(result, "overall_conf_int", "overall_att_ci", default=_nan_ci), + ) + + +# -- Staggered DGP compatibility check ---------------------------------------- + +_STAGGERED_ESTIMATORS = frozenset( + { + "CallawaySantAnna", + "SunAbraham", + "ImputationDiD", + "TwoStageDiD", + "StackedDiD", + "EfficientDiD", + } +) + + +def _check_staggered_dgp_compat( + estimator: Any, + data_generator_kwargs: Optional[Dict[str, Any]], +) -> None: + """Warn if a staggered estimator's settings don't match the default DGP.""" + name = type(estimator).__name__ + if name not in _STAGGERED_ESTIMATORS: + return + + dgp_overrides = data_generator_kwargs or {} + issues: List[str] = [] + + # Check control_group="not_yet_treated" (CS, SA) + cg = getattr(estimator, "control_group", "never_treated") + if cg == "not_yet_treated" and "cohort_periods" not in dgp_overrides: + issues.append( + f' - {name} has control_group="not_yet_treated" but the default ' + f"DGP generates a single treatment cohort with never-treated " + f"controls. Power may not reflect the intended not-yet-treated " + f"design.\n" + f" Fix: pass data_generator_kwargs=" + f'{{"cohort_periods": [2, 4], "never_treated_frac": 0.0}} ' + f"(or a custom data_generator)." + ) + + # Check anticipation > 0 (all staggered) + antic = getattr(estimator, "anticipation", 0) + if antic > 0: + issues.append( + f" - {name} has anticipation={antic} but the default DGP does " + f"not model anticipatory effects. The estimator will look for " + f"treatment effects {antic} period(s) before the DGP generates " + f"them, biasing power estimates.\n" + f" Fix: supply a custom data_generator that shifts the " + f"effect onset." + ) + + # Check clean_control on StackedDiD + if name == "StackedDiD": + cc = getattr(estimator, "clean_control", "not_yet_treated") + if cc == "strict" and "cohort_periods" not in dgp_overrides: + issues.append( + ' - StackedDiD has clean_control="strict" but the default ' + "single-cohort DGP makes strict controls equivalent to " + "never-treated controls.\n" + " Fix: pass data_generator_kwargs=" + '{"cohort_periods": [2, 4]} ' + "to test true strict clean-control behavior." + ) + + if issues: + msg = ( + f"Staggered power DGP mismatch for {name}. The default " + f"single-cohort DGP may not match the estimator " + f"configuration:\n" + "\n".join(issues) + ) + warnings.warn(msg, UserWarning, stacklevel=2) + + +def _check_ddd_dgp_compat( + n_units: int, + n_periods: int, + treatment_fraction: float, + treatment_period: int, + data_generator_kwargs: Optional[Dict[str, Any]], +) -> None: + """Warn when simulation inputs don't match DDD's fixed 2×2×2 design.""" + overrides = data_generator_kwargs or {} + issues: List[str] = [] + + # DDD is a fixed 2-period factorial; n_periods and treatment_period are ignored + if n_periods != 2: + issues.append( + f"n_periods={n_periods} is ignored (DDD uses a fixed " f"2-period design: pre/post)" + ) + if treatment_period != 1: + issues.append( + f"treatment_period={treatment_period} is ignored (DDD " + f"always treats in the second period)" + ) + + # DDD's 2×2×2 factorial has inherent 50% treatment fraction + if treatment_fraction != 0.5: + issues.append( + f"treatment_fraction={treatment_fraction} is ignored " + f"(DDD uses a balanced 2×2×2 factorial where 50% of " + f"groups are treated)" + ) + + # n_units rounding: n_per_cell = max(2, n_units // 8) + effective_n_per_cell = overrides.get("n_per_cell", max(2, n_units // 8)) + effective_n = effective_n_per_cell * 8 + if effective_n != n_units: + issues.append( + f"effective sample size is {effective_n} " + f"(n_per_cell={effective_n_per_cell} × 8 cells), " + f"not the requested n_units={n_units}" + ) + + if issues: + warnings.warn( + "TripleDifference uses a fixed 2×2×2 factorial DGP " + "(group × partition × time). " + + "; ".join(issues) + + ". Pass a custom data_generator for non-standard DDD designs.", + UserWarning, + stacklevel=2, + ) + + +# -- Registry construction (deferred to avoid import-time cost) --------------- + +_ESTIMATOR_REGISTRY: Optional[Dict[str, _EstimatorProfile]] = None + + +def _get_registry() -> Dict[str, _EstimatorProfile]: + """Lazily build and return the estimator registry.""" + global _ESTIMATOR_REGISTRY # noqa: PLW0603 + if _ESTIMATOR_REGISTRY is not None: + return _ESTIMATOR_REGISTRY + + from diff_diff.prep import ( + generate_ddd_data, + generate_did_data, + generate_factor_data, + generate_staggered_data, + ) + + _ESTIMATOR_REGISTRY = { + # --- Basic DiD group --- + "DifferenceInDifferences": _EstimatorProfile( + default_dgp=generate_did_data, + dgp_kwargs_builder=_basic_dgp_kwargs, + fit_kwargs_builder=_basic_fit_kwargs, + result_extractor=_extract_simple, + min_n=20, + ), + "TwoWayFixedEffects": _EstimatorProfile( + default_dgp=generate_did_data, + dgp_kwargs_builder=_basic_dgp_kwargs, + fit_kwargs_builder=_twfe_fit_kwargs, + result_extractor=_extract_simple, + min_n=20, + ), + "MultiPeriodDiD": _EstimatorProfile( + default_dgp=generate_did_data, + dgp_kwargs_builder=_basic_dgp_kwargs, + fit_kwargs_builder=_multiperiod_fit_kwargs, + result_extractor=_extract_multiperiod, + min_n=20, + ), + # --- Staggered group --- + "CallawaySantAnna": _EstimatorProfile( + default_dgp=generate_staggered_data, + dgp_kwargs_builder=_staggered_dgp_kwargs, + fit_kwargs_builder=_staggered_fit_kwargs, + result_extractor=_extract_staggered, + min_n=40, + ), + "SunAbraham": _EstimatorProfile( + default_dgp=generate_staggered_data, + dgp_kwargs_builder=_staggered_dgp_kwargs, + fit_kwargs_builder=_staggered_fit_kwargs, + result_extractor=_extract_staggered, + min_n=40, + ), + "ImputationDiD": _EstimatorProfile( + default_dgp=generate_staggered_data, + dgp_kwargs_builder=_staggered_dgp_kwargs, + fit_kwargs_builder=_staggered_fit_kwargs, + result_extractor=_extract_staggered, + min_n=40, + ), + "TwoStageDiD": _EstimatorProfile( + default_dgp=generate_staggered_data, + dgp_kwargs_builder=_staggered_dgp_kwargs, + fit_kwargs_builder=_staggered_fit_kwargs, + result_extractor=_extract_staggered, + min_n=40, + ), + "StackedDiD": _EstimatorProfile( + default_dgp=generate_staggered_data, + dgp_kwargs_builder=_staggered_dgp_kwargs, + fit_kwargs_builder=_staggered_fit_kwargs, + result_extractor=_extract_staggered, + min_n=40, + ), + "EfficientDiD": _EstimatorProfile( + default_dgp=generate_staggered_data, + dgp_kwargs_builder=_staggered_dgp_kwargs, + fit_kwargs_builder=_staggered_fit_kwargs, + result_extractor=_extract_staggered, + min_n=40, + ), + # --- Factor model group --- + "TROP": _EstimatorProfile( + default_dgp=generate_factor_data, + dgp_kwargs_builder=_factor_dgp_kwargs, + fit_kwargs_builder=_trop_fit_kwargs, + result_extractor=_extract_simple, + min_n=30, + ), + "SyntheticDiD": _EstimatorProfile( + default_dgp=generate_factor_data, + dgp_kwargs_builder=_factor_dgp_kwargs, + fit_kwargs_builder=_sdid_fit_kwargs, + result_extractor=_extract_simple, + min_n=30, + ), + # --- Triple difference --- + "TripleDifference": _EstimatorProfile( + default_dgp=generate_ddd_data, + dgp_kwargs_builder=_ddd_dgp_kwargs, + fit_kwargs_builder=_ddd_fit_kwargs, + result_extractor=_extract_simple, + min_n=64, + ), + } + return _ESTIMATOR_REGISTRY + + @dataclass class PowerResults: """ @@ -332,10 +771,7 @@ def power_curve_df(self) -> pd.DataFrame: pd.DataFrame DataFrame with effect_size and power columns. """ - return pd.DataFrame({ - "effect_size": self.effect_sizes, - "power": self.powers - }) + return pd.DataFrame({"effect_size": self.effect_sizes, "power": self.powers}) class PowerAnalysis: @@ -463,9 +899,7 @@ def _compute_variance( n_c_pre = n_control n_c_post = n_control - variance = sigma**2 * ( - 1 / n_t_post + 1 / n_t_pre + 1 / n_c_post + 1 / n_c_pre - ) + variance = sigma**2 * (1 / n_t_post + 1 / n_t_pre + 1 / n_c_post + 1 / n_c_pre) elif design == "panel": # Panel DiD with multiple periods # Account for serial correlation via ICC @@ -528,9 +962,7 @@ def power( T = n_pre + n_post design = "panel" if T > 2 else "basic_did" - variance = self._compute_variance( - n_treated, n_control, n_pre, n_post, sigma, rho, design - ) + variance = self._compute_variance(n_treated, n_control, n_pre, n_post, sigma, rho, design) se = np.sqrt(variance) # Calculate power @@ -538,7 +970,8 @@ def power( z_alpha = stats.norm.ppf(1 - self.alpha / 2) # Power = P(reject | effect) = P(|Z| > z_alpha | effect) power_val = ( - 1 - stats.norm.cdf(z_alpha - effect_size / se) + 1 + - stats.norm.cdf(z_alpha - effect_size / se) + stats.norm.cdf(-z_alpha - effect_size / se) ) elif self.alternative == "greater": @@ -551,8 +984,7 @@ def power( # Also compute MDE and required N for reference mde = self._compute_mde_from_se(se) required_n = self._compute_required_n( - effect_size, sigma, n_pre, n_post, rho, design, - n_treated / (n_treated + n_control) + effect_size, sigma, n_pre, n_post, rho, design, n_treated / (n_treated + n_control) ) return PowerResults( @@ -620,9 +1052,7 @@ def mde( T = n_pre + n_post design = "panel" if T > 2 else "basic_did" - variance = self._compute_variance( - n_treated, n_control, n_pre, n_post, sigma, rho, design - ) + variance = self._compute_variance(n_treated, n_control, n_pre, n_post, sigma, rho, design) se = np.sqrt(variance) mde = self._compute_mde_from_se(se) @@ -674,7 +1104,9 @@ def _compute_required_n( # = 2 * sigma^2 / N * (1/(p*(1-p))) n_total = ( - 2 * sigma**2 * (z_alpha + z_beta)**2 + 2 + * sigma**2 + * (z_alpha + z_beta) ** 2 / (effect_size**2 * treat_frac * (1 - treat_frac)) ) else: # panel @@ -684,7 +1116,10 @@ def _compute_required_n( # For balanced: Var = 2 * sigma^2 / N * design_effect / T n_total = ( - 2 * sigma**2 * (z_alpha + z_beta)**2 * design_effect + 2 + * sigma**2 + * (z_alpha + z_beta) ** 2 + * design_effect / (effect_size**2 * treat_frac * (1 - treat_frac) * T) ) @@ -744,9 +1179,7 @@ def sample_size( n_total = n_treated + n_control # Compute actual power achieved - variance = self._compute_variance( - n_treated, n_control, n_pre, n_post, sigma, rho, design - ) + variance = self._compute_variance(n_treated, n_control, n_pre, n_post, sigma, rho, design) se = np.sqrt(variance) mde = self._compute_mde_from_se(se) @@ -865,9 +1298,7 @@ def sample_size_curve( DataFrame with columns 'sample_size' and 'power'. """ # Get required N to determine default range - required = self.sample_size( - effect_size, sigma, n_pre, n_post, rho, treat_frac - ) + required = self.sample_size(effect_size, sigma, n_pre, n_post, rho, treat_frac) if sample_sizes is None: min_n = max(10, required.required_n // 4) @@ -907,6 +1338,7 @@ def simulate_power( data_generator: Optional[Callable] = None, data_generator_kwargs: Optional[Dict[str, Any]] = None, estimator_kwargs: Optional[Dict[str, Any]] = None, + result_extractor: Optional[Callable] = None, progress: bool = True, ) -> SimulationPowerResults: """ @@ -914,7 +1346,8 @@ def simulate_power( This function simulates datasets with known treatment effects and estimates power as the fraction of simulations where the null hypothesis is rejected. - This is the recommended approach for complex designs like staggered adoption. + Most built-in estimators are supported via an internal registry that selects + the appropriate data-generating process and fit signature automatically. Parameters ---------- @@ -942,12 +1375,18 @@ def simulate_power( seed : int, optional Random seed for reproducibility. data_generator : callable, optional - Custom data generation function. Should accept same signature as - generate_did_data(). If None, uses generate_did_data(). + Custom data generation function. When provided, bypasses the + registry DGP and calls this function with the standard kwargs + (n_units, n_periods, treatment_effect, etc.). data_generator_kwargs : dict, optional Additional keyword arguments for data generator. estimator_kwargs : dict, optional Additional keyword arguments for estimator.fit(). + result_extractor : callable, optional + Custom function to extract results from the estimator output. + Takes the estimator result object and returns a tuple of + ``(att, se, p_value, conf_int)``. Useful for unregistered + estimators with non-standard result schemas. progress : bool, default=True Whether to print progress updates. @@ -982,15 +1421,11 @@ def simulate_power( ... ) >>> print(results.power_curve_df()) - With Callaway-Sant'Anna for staggered designs: + With Callaway-Sant'Anna (auto-detected, no custom DGP needed): >>> from diff_diff import CallawaySantAnna >>> cs = CallawaySantAnna() - >>> # Custom data generator for staggered adoption - >>> def staggered_data(n_units, n_periods, treatment_effect, **kwargs): - ... # Your staggered data generation logic - ... ... - >>> results = simulate_power(cs, data_generator=staggered_data, ...) + >>> results = simulate_power(cs, n_simulations=200, seed=42) Notes ----- @@ -1000,54 +1435,98 @@ def simulate_power( 3. Repeat n_simulations times 4. Power = fraction of simulations where p-value < alpha - For staggered designs, you'll need to provide a custom data_generator - that creates appropriate staggered treatment timing. - References ---------- Burlig, F., Preonas, L., & Woerman, M. (2020). "Panel Data and Experimental Design." """ - from diff_diff.prep import generate_did_data - rng = np.random.default_rng(seed) - # Use default data generator if none provided - if data_generator is None: - data_generator = generate_did_data + estimator_name = type(estimator).__name__ + registry = _get_registry() + profile = registry.get(estimator_name) + + # If no profile and no custom data_generator, raise + if profile is None and data_generator is None: + raise ValueError( + f"Estimator '{estimator_name}' not in registry. " + f"Provide a custom data_generator and estimator_kwargs " + f"(the full dict of keyword arguments for estimator.fit(), " + f"e.g. dict(outcome='y', treatment='treat', time='period'))." + ) + + # When a custom data_generator is provided, bypass registry DGP + use_custom_dgp = data_generator is not None data_gen_kwargs = data_generator_kwargs or {} est_kwargs = estimator_kwargs or {} + # SyntheticDiD placebo variance requires n_control > n_treated. + # Check after merging data_generator_kwargs so overrides of n_treated + # are accounted for. + if estimator_name == "SyntheticDiD" and not use_custom_dgp: + vm = getattr(estimator, "variance_method", "placebo") + effective_n_treated = data_gen_kwargs.get( + "n_treated", max(1, int(n_units * treatment_fraction)) + ) + n_control = n_units - effective_n_treated + if vm == "placebo" and n_control <= effective_n_treated: + raise ValueError( + f"SyntheticDiD placebo variance requires more control than " + f"treated units (got n_control={n_control}, " + f"n_treated={effective_n_treated}). Either lower " + f"treatment_fraction so that n_control > n_treated, or use " + f"SyntheticDiD(variance_method='bootstrap')." + ) + + # Warn if staggered estimator settings don't match auto DGP + if profile is not None and not use_custom_dgp: + _check_staggered_dgp_compat(estimator, data_generator_kwargs) + + # Warn if DDD design inputs are silently ignored + if estimator_name == "TripleDifference" and not use_custom_dgp: + _check_ddd_dgp_compat( + n_units, + n_periods, + treatment_fraction, + treatment_period, + data_generator_kwargs, + ) + # Determine effect sizes to test if effect_sizes is None: effect_sizes = [treatment_effect] all_powers = [] - # For the primary effect (last in list), collect detailed results - # Use index-based comparison to avoid float precision issues + # For the primary effect, collect detailed results if len(effect_sizes) == 1: primary_idx = 0 else: - # Find index of treatment_effect in effect_sizes primary_idx = -1 for i, es in enumerate(effect_sizes): if np.isclose(es, treatment_effect): primary_idx = i break if primary_idx == -1: - primary_idx = len(effect_sizes) - 1 # Default to last + primary_idx = len(effect_sizes) - 1 primary_effect = effect_sizes[primary_idx] + # Initialize so they are always bound + primary_estimates: List[float] = [] + primary_ses: List[float] = [] + primary_p_values: List[float] = [] + primary_rejections: List[bool] = [] + primary_ci_contains: List[bool] = [] + for effect_idx, effect in enumerate(effect_sizes): - is_primary = (effect_idx == primary_idx) + is_primary = effect_idx == primary_idx - estimates = [] - ses = [] - p_values = [] - rejections = [] - ci_contains_true = [] + estimates: List[float] = [] + ses: List[float] = [] + p_values: List[float] = [] + rejections: List[bool] = [] + ci_contains_true: List[bool] = [] n_failures = 0 for sim in range(n_simulations): @@ -1055,90 +1534,76 @@ def simulate_power( pct = (sim + effect_idx * n_simulations) / (len(effect_sizes) * n_simulations) print(f" Simulation progress: {pct:.0%}") - # Generate data sim_seed = rng.integers(0, 2**31) - data = data_generator( - n_units=n_units, - n_periods=n_periods, - treatment_effect=effect, - treatment_fraction=treatment_fraction, - treatment_period=treatment_period, - noise_sd=sigma, - seed=sim_seed, - **data_gen_kwargs - ) + + # --- Generate data --- + if use_custom_dgp: + assert data_generator is not None + data = data_generator( + n_units=n_units, + n_periods=n_periods, + treatment_effect=effect, + treatment_fraction=treatment_fraction, + treatment_period=treatment_period, + noise_sd=sigma, + seed=sim_seed, + **data_gen_kwargs, + ) + else: + assert profile is not None + dgp_kwargs = profile.dgp_kwargs_builder( + n_units=n_units, + n_periods=n_periods, + treatment_effect=effect, + treatment_fraction=treatment_fraction, + treatment_period=treatment_period, + sigma=sigma, + ) + dgp_kwargs.update(data_gen_kwargs) + dgp_kwargs.pop("seed", None) + data = profile.default_dgp(seed=sim_seed, **dgp_kwargs) try: - # Fit estimator - # Try to determine the right arguments based on estimator type - estimator_name = type(estimator).__name__ - - if estimator_name == "DifferenceInDifferences": - result = estimator.fit( - data, - outcome="outcome", - treatment="treated", - time="post", - **est_kwargs - ) - elif estimator_name == "TwoWayFixedEffects": - result = estimator.fit( - data, - outcome="outcome", - treatment="treated", - time="period", - unit="unit", - **est_kwargs - ) - elif estimator_name == "MultiPeriodDiD": - post_periods = list(range(treatment_period, n_periods)) - result = estimator.fit( - data, - outcome="outcome", - treatment="treated", - time="period", - post_periods=post_periods, - **est_kwargs - ) - elif estimator_name == "CallawaySantAnna": - # Need to create first_treat column for staggered - # For standard generate_did_data, convert to first_treat format - data = data.copy() - data["first_treat"] = np.where( - data["treated"] == 1, treatment_period, 0 - ) - result = estimator.fit( - data, - outcome="outcome", - unit="unit", - time="period", - first_treat="first_treat", - **est_kwargs + # --- Fit estimator --- + if profile is not None and not use_custom_dgp: + fit_kwargs = profile.fit_kwargs_builder( + data, n_units, n_periods, treatment_period ) + fit_kwargs.update(est_kwargs) else: - # Generic fallback - try common signature - result = estimator.fit( - data, - outcome="outcome", - treatment="treated", - time="post", - **est_kwargs - ) + # Custom DGP fallback: use registry fit kwargs if available, + # otherwise use basic DiD signature + if profile is not None: + fit_kwargs = profile.fit_kwargs_builder( + data, n_units, n_periods, treatment_period + ) + fit_kwargs.update(est_kwargs) + else: + fit_kwargs = dict(est_kwargs) + + result = estimator.fit(data, **fit_kwargs) + + # --- Extract results --- + if profile is not None: + att, se, p_val, ci = profile.result_extractor(result) + elif result_extractor is not None: + att, se, p_val, ci = result_extractor(result) + else: + att = result.att if hasattr(result, "att") else result.avg_att + se = result.se if hasattr(result, "se") else result.avg_se + p_val = result.p_value if hasattr(result, "p_value") else result.avg_p_value + ci = result.conf_int if hasattr(result, "conf_int") else result.avg_conf_int - # Extract results - att = result.att if hasattr(result, 'att') else result.avg_att - se = result.se if hasattr(result, 'se') else result.avg_se - p_val = result.p_value if hasattr(result, 'p_value') else result.avg_p_value - ci = result.conf_int if hasattr(result, 'conf_int') else result.avg_conf_int + # NaN p-value → treat as non-rejection + rejected = bool(p_val < alpha) if not np.isnan(p_val) else False estimates.append(att) ses.append(se) p_values.append(p_val) - rejections.append(p_val < alpha) + rejections.append(rejected) ci_contains_true.append(ci[0] <= effect <= ci[1]) except Exception as e: - # Track failed simulations n_failures += 1 if progress: print(f" Warning: Simulation {sim} failed: {e}") @@ -1148,21 +1613,18 @@ def simulate_power( failure_rate = n_failures / n_simulations if failure_rate > 0.1: warnings.warn( - f"{n_failures}/{n_simulations} simulations ({failure_rate:.1%}) failed " - f"for effect_size={effect}. Check estimator and data generator.", - UserWarning + f"{n_failures}/{n_simulations} simulations ({failure_rate:.1%}) " + f"failed for effect_size={effect}. " + f"Check estimator and data generator.", + UserWarning, ) if len(estimates) == 0: raise RuntimeError("All simulations failed. Check estimator and data generator.") - # Compute power and SE power_val = np.mean(rejections) - power_se = np.sqrt(power_val * (1 - power_val) / len(rejections)) - all_powers.append(power_val) - # Store detailed results for primary effect if is_primary: primary_estimates = estimates primary_ses = ses @@ -1177,10 +1639,9 @@ def simulate_power( z = stats.norm.ppf(0.975) power_ci = ( max(0.0, power_val - z * power_se), - min(1.0, power_val + z * power_se) + min(1.0, power_val + z * power_se), ) - # Compute summary statistics mean_estimate = np.mean(primary_estimates) std_estimate = np.std(primary_estimates, ddof=1) mean_se = np.mean(primary_ses) @@ -1200,15 +1661,602 @@ def simulate_power( powers=all_powers, true_effect=primary_effect, alpha=alpha, - estimator_name=type(estimator).__name__, + estimator_name=estimator_name, simulation_results=[ {"estimate": e, "se": s, "p_value": p, "rejected": r} - for e, s, p, r in zip(primary_estimates, primary_ses, - primary_p_values, primary_rejections) + for e, s, p, r in zip( + primary_estimates, + primary_ses, + primary_p_values, + primary_rejections, + ) ], ) +# --------------------------------------------------------------------------- +# Simulation-based MDE and sample-size search +# --------------------------------------------------------------------------- + + +@dataclass +class SimulationMDEResults: + """ + Results from simulation-based minimum detectable effect search. + + Attributes + ---------- + mde : float + Minimum detectable effect (smallest effect achieving target power). + power_at_mde : float + Power achieved at the MDE. + target_power : float + Target power used in the search. + alpha : float + Significance level. + n_units : int + Sample size used. + n_simulations_per_step : int + Number of simulations per bisection step. + n_steps : int + Number of bisection steps performed. + search_path : list of dict + Diagnostic trace of ``{effect_size, power}`` at each step. + estimator_name : str + Name of the estimator used. + """ + + mde: float + power_at_mde: float + target_power: float + alpha: float + n_units: int + n_simulations_per_step: int + n_steps: int + search_path: List[Dict[str, float]] + estimator_name: str + + def __repr__(self) -> str: + return ( + f"SimulationMDEResults(mde={self.mde:.4f}, " + f"power_at_mde={self.power_at_mde:.3f}, " + f"n_steps={self.n_steps})" + ) + + def summary(self) -> str: + """Generate a formatted summary.""" + lines = [ + "=" * 65, + "Simulation-Based MDE Results".center(65), + "=" * 65, + "", + f"{'Estimator:':<35} {self.estimator_name}", + f"{'Significance level (alpha):':<35} {self.alpha:.3f}", + f"{'Target power:':<35} {self.target_power:.1%}", + f"{'Sample size (n_units):':<35} {self.n_units}", + f"{'Simulations per step:':<35} {self.n_simulations_per_step}", + "", + "-" * 65, + "Search Results".center(65), + "-" * 65, + f"{'Minimum detectable effect:':<35} {self.mde:.4f}", + f"{'Power at MDE:':<35} {self.power_at_mde:.1%}", + f"{'Bisection steps:':<35} {self.n_steps}", + "=" * 65, + ] + return "\n".join(lines) + + def to_dict(self) -> Dict[str, Any]: + """Convert results to a dictionary.""" + return { + "mde": self.mde, + "power_at_mde": self.power_at_mde, + "target_power": self.target_power, + "alpha": self.alpha, + "n_units": self.n_units, + "n_simulations_per_step": self.n_simulations_per_step, + "n_steps": self.n_steps, + "estimator_name": self.estimator_name, + } + + def to_dataframe(self) -> pd.DataFrame: + """Convert results to a single-row DataFrame.""" + return pd.DataFrame([self.to_dict()]) + + +@dataclass +class SimulationSampleSizeResults: + """ + Results from simulation-based sample size search. + + Attributes + ---------- + required_n : int + Required number of units to achieve target power. + power_at_n : float + Power achieved at the required N. + target_power : float + Target power used in the search. + alpha : float + Significance level. + effect_size : float + Effect size used in the search. + n_simulations_per_step : int + Number of simulations per bisection step. + n_steps : int + Number of bisection steps performed. + search_path : list of dict + Diagnostic trace of ``{n_units, power}`` at each step. + estimator_name : str + Name of the estimator used. + """ + + required_n: int + power_at_n: float + target_power: float + alpha: float + effect_size: float + n_simulations_per_step: int + n_steps: int + search_path: List[Dict[str, float]] + estimator_name: str + + def __repr__(self) -> str: + return ( + f"SimulationSampleSizeResults(required_n={self.required_n}, " + f"power_at_n={self.power_at_n:.3f}, " + f"n_steps={self.n_steps})" + ) + + def summary(self) -> str: + """Generate a formatted summary.""" + lines = [ + "=" * 65, + "Simulation-Based Sample Size Results".center(65), + "=" * 65, + "", + f"{'Estimator:':<35} {self.estimator_name}", + f"{'Significance level (alpha):':<35} {self.alpha:.3f}", + f"{'Target power:':<35} {self.target_power:.1%}", + f"{'Effect size:':<35} {self.effect_size:.4f}", + f"{'Simulations per step:':<35} {self.n_simulations_per_step}", + "", + "-" * 65, + "Search Results".center(65), + "-" * 65, + f"{'Required sample size:':<35} {self.required_n}", + f"{'Power at required N:':<35} {self.power_at_n:.1%}", + f"{'Bisection steps:':<35} {self.n_steps}", + "=" * 65, + ] + return "\n".join(lines) + + def to_dict(self) -> Dict[str, Any]: + """Convert results to a dictionary.""" + return { + "required_n": self.required_n, + "power_at_n": self.power_at_n, + "target_power": self.target_power, + "alpha": self.alpha, + "effect_size": self.effect_size, + "n_simulations_per_step": self.n_simulations_per_step, + "n_steps": self.n_steps, + "estimator_name": self.estimator_name, + } + + def to_dataframe(self) -> pd.DataFrame: + """Convert results to a single-row DataFrame.""" + return pd.DataFrame([self.to_dict()]) + + +def simulate_mde( + estimator: Any, + n_units: int = 100, + n_periods: int = 4, + treatment_fraction: float = 0.5, + treatment_period: int = 2, + sigma: float = 1.0, + n_simulations: int = 200, + power: float = 0.80, + alpha: float = 0.05, + effect_range: Optional[Tuple[float, float]] = None, + tol: float = 0.02, + max_steps: int = 15, + seed: Optional[int] = None, + data_generator: Optional[Callable] = None, + data_generator_kwargs: Optional[Dict[str, Any]] = None, + estimator_kwargs: Optional[Dict[str, Any]] = None, + result_extractor: Optional[Callable] = None, + progress: bool = True, +) -> SimulationMDEResults: + """ + Find the minimum detectable effect via simulation-based bisection search. + + Searches over effect sizes to find the smallest effect that achieves the + target power, using ``simulate_power()`` at each step. + + Parameters + ---------- + estimator : estimator object + DiD estimator to use. + n_units : int, default=100 + Number of units per simulation. + n_periods : int, default=4 + Number of time periods. + treatment_fraction : float, default=0.5 + Fraction of units that are treated. + treatment_period : int, default=2 + First post-treatment period (0-indexed). + sigma : float, default=1.0 + Residual standard deviation. + n_simulations : int, default=200 + Simulations per bisection step. + power : float, default=0.80 + Target power. + alpha : float, default=0.05 + Significance level. + effect_range : tuple of (float, float), optional + ``(lo, hi)`` bracket for the search. If None, auto-brackets. + tol : float, default=0.02 + Convergence tolerance on power. + max_steps : int, default=15 + Maximum bisection steps. + seed : int, optional + Random seed for reproducibility. + data_generator : callable, optional + Custom data generation function. + data_generator_kwargs : dict, optional + Additional keyword arguments for data generator. + estimator_kwargs : dict, optional + Additional keyword arguments for estimator.fit(). + result_extractor : callable, optional + Custom function to extract results from the estimator output. + Forwarded to ``simulate_power()``. + progress : bool, default=True + Whether to print progress updates. + + Returns + ------- + SimulationMDEResults + Results including the MDE and search diagnostics. + + Examples + -------- + >>> from diff_diff import simulate_mde, DifferenceInDifferences + >>> result = simulate_mde(DifferenceInDifferences(), n_simulations=100, seed=42) + >>> print(f"MDE: {result.mde:.3f}") + """ + master_rng = np.random.default_rng(seed) + estimator_name = type(estimator).__name__ + search_path: List[Dict[str, float]] = [] + + common_kwargs: Dict[str, Any] = dict( + estimator=estimator, + n_units=n_units, + n_periods=n_periods, + treatment_fraction=treatment_fraction, + treatment_period=treatment_period, + sigma=sigma, + n_simulations=n_simulations, + alpha=alpha, + data_generator=data_generator, + data_generator_kwargs=data_generator_kwargs, + estimator_kwargs=estimator_kwargs, + result_extractor=result_extractor, + progress=False, + ) + + def _power_at(effect: float) -> float: + step_seed = int(master_rng.integers(0, 2**31)) + res = simulate_power(treatment_effect=effect, seed=step_seed, **common_kwargs) + pwr = float(res.power) + search_path.append({"effect_size": effect, "power": pwr}) + if progress: + print(f" MDE search: effect={effect:.4f}, power={pwr:.3f}") + return pwr + + # --- Bracket --- + if effect_range is not None: + lo, hi = effect_range + power_lo = _power_at(lo) + power_hi = _power_at(hi) + if power_lo >= power: + warnings.warn( + f"Power at effect={lo} is {power_lo:.2f} >= target {power}. " + f"Lower bound already exceeds target power. Returning lo as MDE.", + UserWarning, + ) + return SimulationMDEResults( + mde=lo, + power_at_mde=power_lo, + target_power=power, + alpha=alpha, + n_units=n_units, + n_simulations_per_step=n_simulations, + n_steps=len(search_path), + search_path=search_path, + estimator_name=estimator_name, + ) + if power_hi < power: + warnings.warn( + f"Target power {power} not bracketed: power at effect={hi} " + f"is {power_hi:.2f}. Upper bound may be too low.", + UserWarning, + ) + else: + lo = 0.0 + # Check that power at zero is below target (no inflated Type I error) + power_at_zero = _power_at(0.0) + if power_at_zero >= power: + warnings.warn( + f"Power at effect=0 is {power_at_zero:.2f} >= target {power}. " + f"This suggests inflated Type I error. Returning MDE=0.", + UserWarning, + ) + return SimulationMDEResults( + mde=0.0, + power_at_mde=power_at_zero, + target_power=power, + alpha=alpha, + n_units=n_units, + n_simulations_per_step=n_simulations, + n_steps=len(search_path), + search_path=search_path, + estimator_name=estimator_name, + ) + + hi = sigma + for _ in range(10): + if _power_at(hi) >= power: + break + hi *= 2 + else: + warnings.warn( + f"Could not bracket MDE (power at effect={hi} still below " + f"{power}). Returning best upper bound.", + UserWarning, + ) + + # --- Bisect --- + best_effect = hi + best_power = search_path[-1]["power"] if search_path else 0.0 + + for _ in range(max_steps): + mid = (lo + hi) / 2 + pwr = _power_at(mid) + + if pwr >= power: + hi = mid + best_effect = mid + best_power = pwr + else: + lo = mid + + # Convergence: effect range is tight or power is close enough + if hi - lo < max(tol * hi, 1e-6) or abs(pwr - power) < tol: + break + + return SimulationMDEResults( + mde=best_effect, + power_at_mde=best_power, + target_power=power, + alpha=alpha, + n_units=n_units, + n_simulations_per_step=n_simulations, + n_steps=len(search_path), + search_path=search_path, + estimator_name=estimator_name, + ) + + +def simulate_sample_size( + estimator: Any, + treatment_effect: float = 5.0, + n_periods: int = 4, + treatment_fraction: float = 0.5, + treatment_period: int = 2, + sigma: float = 1.0, + n_simulations: int = 200, + power: float = 0.80, + alpha: float = 0.05, + n_range: Optional[Tuple[int, int]] = None, + max_steps: int = 15, + seed: Optional[int] = None, + data_generator: Optional[Callable] = None, + data_generator_kwargs: Optional[Dict[str, Any]] = None, + estimator_kwargs: Optional[Dict[str, Any]] = None, + result_extractor: Optional[Callable] = None, + progress: bool = True, +) -> SimulationSampleSizeResults: + """ + Find the required sample size via simulation-based bisection search. + + Searches over ``n_units`` to find the smallest N that achieves the + target power, using ``simulate_power()`` at each step. + + Parameters + ---------- + estimator : estimator object + DiD estimator to use. + treatment_effect : float, default=5.0 + True treatment effect to simulate. + n_periods : int, default=4 + Number of time periods. + treatment_fraction : float, default=0.5 + Fraction of units that are treated. + treatment_period : int, default=2 + First post-treatment period (0-indexed). + sigma : float, default=1.0 + Residual standard deviation. + n_simulations : int, default=200 + Simulations per bisection step. + power : float, default=0.80 + Target power. + alpha : float, default=0.05 + Significance level. + n_range : tuple of (int, int), optional + ``(lo, hi)`` bracket for sample size. If None, auto-brackets. + max_steps : int, default=15 + Maximum bisection steps. + seed : int, optional + Random seed for reproducibility. + data_generator : callable, optional + Custom data generation function. + data_generator_kwargs : dict, optional + Additional keyword arguments for data generator. + estimator_kwargs : dict, optional + Additional keyword arguments for estimator.fit(). + result_extractor : callable, optional + Custom function to extract results from the estimator output. + Forwarded to ``simulate_power()``. + progress : bool, default=True + Whether to print progress updates. + + Returns + ------- + SimulationSampleSizeResults + Results including the required N and search diagnostics. + + Examples + -------- + >>> from diff_diff import simulate_sample_size, DifferenceInDifferences + >>> result = simulate_sample_size( + ... DifferenceInDifferences(), treatment_effect=5.0, n_simulations=100, seed=42 + ... ) + >>> print(f"Required N: {result.required_n}") + """ + master_rng = np.random.default_rng(seed) + estimator_name = type(estimator).__name__ + search_path: List[Dict[str, float]] = [] + + # Determine min_n from registry + registry = _get_registry() + profile = registry.get(estimator_name) + min_n = profile.min_n if profile is not None else 20 + + common_kwargs: Dict[str, Any] = dict( + estimator=estimator, + n_periods=n_periods, + treatment_effect=treatment_effect, + treatment_fraction=treatment_fraction, + treatment_period=treatment_period, + sigma=sigma, + n_simulations=n_simulations, + alpha=alpha, + data_generator=data_generator, + data_generator_kwargs=data_generator_kwargs, + estimator_kwargs=estimator_kwargs, + result_extractor=result_extractor, + progress=False, + ) + + def _power_at_n(n: int) -> float: + step_seed = int(master_rng.integers(0, 2**31)) + res = simulate_power(n_units=n, seed=step_seed, **common_kwargs) + pwr = float(res.power) + search_path.append({"n_units": float(n), "power": pwr}) + if progress: + print(f" Sample size search: n={n}, power={pwr:.3f}") + return pwr + + # --- Bracket --- + if n_range is not None: + lo, hi = n_range + power_lo = _power_at_n(lo) + if power_lo >= power: + warnings.warn( + f"Power at n={lo} is {power_lo:.2f} >= target {power}. " + f"Lower bound already achieves target power. Returning lo.", + UserWarning, + ) + return SimulationSampleSizeResults( + required_n=lo, + power_at_n=power_lo, + target_power=power, + alpha=alpha, + effect_size=treatment_effect, + n_simulations_per_step=n_simulations, + n_steps=len(search_path), + search_path=search_path, + estimator_name=estimator_name, + ) + power_hi = _power_at_n(hi) + if power_hi < power: + warnings.warn( + f"Target power {power} not bracketed: power at n={hi} " + f"is {power_hi:.2f}. Upper bound may be too low.", + UserWarning, + ) + else: + lo = min_n + power_lo = _power_at_n(lo) + if power_lo >= power: + warnings.warn( + f"Power at registry floor n={lo} is {power_lo:.2f} >= " + f"target {power}. No smaller sample sizes were evaluated. " + f"Pass n_range=(lo, hi) to search below this floor.", + UserWarning, + ) + return SimulationSampleSizeResults( + required_n=lo, + power_at_n=power_lo, + target_power=power, + alpha=alpha, + effect_size=treatment_effect, + n_simulations_per_step=n_simulations, + n_steps=len(search_path), + search_path=search_path, + estimator_name=estimator_name, + ) + hi = max(100, 2 * min_n) + for _ in range(10): + if _power_at_n(hi) >= power: + break + hi *= 2 + else: + warnings.warn( + f"Could not bracket required N (power at n={hi} still below " + f"{power}). Returning best upper bound.", + UserWarning, + ) + + # --- Bisect on integer n_units --- + best_n = hi + best_power = search_path[-1]["power"] if search_path else 0.0 + + for _ in range(max_steps): + if hi - lo <= 2: + break + mid = (lo + hi) // 2 + pwr = _power_at_n(mid) + + if pwr >= power: + hi = mid + best_n = mid + best_power = pwr + else: + lo = mid + + # Final answer is hi (conservative ceiling) — skip if already evaluated + if best_n != hi: + final_pwr = _power_at_n(hi) + if final_pwr >= power: + best_n = hi + best_power = final_pwr + + return SimulationSampleSizeResults( + required_n=best_n, + power_at_n=best_power, + target_power=power, + alpha=alpha, + effect_size=treatment_effect, + n_simulations_per_step=n_simulations, + n_steps=len(search_path), + search_path=search_path, + estimator_name=estimator_name, + ) + + def compute_mde( n_treated: int, n_control: int, diff --git a/docs/api/index.rst b/docs/api/index.rst index d139b76..f153299 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -148,10 +148,14 @@ Power analysis for study design: diff_diff.PowerAnalysis diff_diff.PowerResults diff_diff.SimulationPowerResults + diff_diff.SimulationMDEResults + diff_diff.SimulationSampleSizeResults diff_diff.compute_power diff_diff.compute_mde diff_diff.compute_sample_size diff_diff.simulate_power + diff_diff.simulate_mde + diff_diff.simulate_sample_size Pre-Trends Power Analysis ------------------------- diff --git a/docs/api/power.rst b/docs/api/power.rst index 0e17b75..52d57c3 100644 --- a/docs/api/power.rst +++ b/docs/api/power.rst @@ -30,9 +30,11 @@ Main class for analytical power calculations. .. autosummary:: - ~PowerAnalysis.compute_power - ~PowerAnalysis.compute_mde - ~PowerAnalysis.compute_sample_size + ~PowerAnalysis.power + ~PowerAnalysis.mde + ~PowerAnalysis.sample_size + ~PowerAnalysis.power_curve + ~PowerAnalysis.sample_size_curve Example ~~~~~~~ @@ -41,29 +43,19 @@ Example from diff_diff import PowerAnalysis - # Create power analysis object - pa = PowerAnalysis( - effect_size=0.5, - n_treated=100, - n_control=100, - n_pre=4, - n_post=4, - sigma=1.0, - rho=0.5, # Within-unit correlation - alpha=0.05 - ) + pa = PowerAnalysis(alpha=0.05, power=0.80) # Compute power - power = pa.compute_power() - print(f"Power: {power:.2%}") + result = pa.power(effect_size=0.5, n_treated=100, n_control=100, sigma=1.0) + print(f"Power: {result.power:.2%}") # Compute MDE at 80% power - mde = pa.compute_mde(power=0.80) - print(f"MDE: {mde:.3f}") + result = pa.mde(n_treated=100, n_control=100, sigma=1.0) + print(f"MDE: {result.mde:.3f}") # Required sample size - n = pa.compute_sample_size(power=0.80) - print(f"Required N per group: {n}") + result = pa.sample_size(effect_size=0.5, sigma=1.0) + print(f"Required N: {result.required_n}") PowerResults ------------ @@ -85,6 +77,26 @@ Results from simulation-based power analysis. :undoc-members: :show-inheritance: +SimulationMDEResults +-------------------- + +Results from simulation-based MDE search. + +.. autoclass:: diff_diff.SimulationMDEResults + :members: + :undoc-members: + :show-inheritance: + +SimulationSampleSizeResults +--------------------------- + +Results from simulation-based sample size search. + +.. autoclass:: diff_diff.SimulationSampleSizeResults + :members: + :undoc-members: + :show-inheritance: + Convenience Functions --------------------- @@ -116,6 +128,20 @@ Simulation-based power for any DiD estimator. .. autofunction:: diff_diff.simulate_power +simulate_mde +~~~~~~~~~~~~~ + +Simulation-based MDE for any DiD estimator. + +.. autofunction:: diff_diff.simulate_mde + +simulate_sample_size +~~~~~~~~~~~~~~~~~~~~ + +Simulation-based sample size for any DiD estimator. + +.. autofunction:: diff_diff.simulate_sample_size + Complete Example ---------------- @@ -125,8 +151,8 @@ Complete Example PowerAnalysis, compute_mde, simulate_power, + simulate_mde, DifferenceInDifferences, - plot_power_curve, ) # Quick MDE calculation @@ -145,20 +171,23 @@ Complete Example # Simulation-based power for DiD estimator sim_results = simulate_power( estimator=DifferenceInDifferences(), - effect_size=0.5, - n_treated=100, - n_control=100, - n_periods=8, - treatment_start=4, + treatment_effect=5.0, + n_units=100, + n_periods=4, + treatment_period=2, sigma=1.0, - n_simulations=1000 + n_simulations=20, ) print(f"Simulated power: {sim_results.power:.2%}") - # Power curve - pa = PowerAnalysis(n_treated=100, n_control=100, n_pre=4, n_post=4, sigma=1.0) - ax = plot_power_curve(pa, effect_range=(0, 1), n_points=50) - ax.figure.savefig('power_curve.png') + # Simulation-based MDE + mde_results = simulate_mde( + estimator=DifferenceInDifferences(), + n_units=100, + n_simulations=10, + max_steps=5, + ) + print(f"Simulated MDE: {mde_results.mde:.3f}") See Also -------- diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 54bfbac..76ceed8 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -1111,6 +1111,7 @@ Convergence criterion: stop when objective decrease < min_decrease² (default mi - **Single pre-period**: `compute_time_weights` returns `[1.0]` when `n_pre <= 1` (Frank-Wolfe on a 1-element simplex is trivial). - **Bootstrap with 0 control or 0 treated in resample**: Skip iteration (`continue`). If ALL bootstrap iterations fail, raises `ValueError`. If only 1 succeeds, warns and returns SE=0.0. If >5% failure rate, warns about reliability. - **Placebo with n_control <= n_treated**: Warns that not enough control units for placebo variance estimation, returns SE=0.0 and empty placebo effects array. The check is `n_control - n_treated < 1`. +- **Note:** Power analysis functions (`simulate_power`, `simulate_mde`, `simulate_sample_size`) raise `ValueError` for placebo variance when `n_control <= n_treated` (fail-fast before simulation). - **Negative weights attempted**: Frank-Wolfe operates on the simplex (non-negative, sum-to-1), so weights are always feasible by construction. The step size is clipped to [0, 1] and the move is toward a simplex vertex. - **Perfect pre-treatment fit**: Regularization (ζ² ||ω||²) prevents overfitting by penalizing weight concentration. - **Single treated unit**: Valid; placebo variance uses jackknife-style permutations of controls. @@ -1711,6 +1712,8 @@ n = 2(t_{α/2} + t_{1-κ})² σ² / MDE² - Very small effects: may require infeasibly large samples - High ICC: dramatically reduces effective sample size - Unequal allocation: optimal is often 50-50 but depends on costs +- **Note:** The simulation-based power registry (`simulate_power`, `simulate_mde`, `simulate_sample_size`) uses a single-cohort staggered DGP by default. Estimators configured with `control_group="not_yet_treated"`, `clean_control="strict"`, or `anticipation>0` will receive a `UserWarning` because the default DGP does not match their identification strategy. Users must supply `data_generator_kwargs` (e.g., `cohort_periods=[2, 4]`, `never_treated_frac=0.0`) or a custom `data_generator` to match the estimator design. +- **Note:** The `TripleDifference` registry adapter uses `generate_ddd_data`, a fixed 2×2×2 factorial DGP (group × partition × time). The `n_periods`, `treatment_period`, and `treatment_fraction` parameters are ignored — DDD always simulates 2 periods with balanced groups. `n_units` is mapped to `n_per_cell = max(2, n_units // 8)` (effective total N = `n_per_cell × 8`), so non-multiples of 8 are rounded down and values below 16 are clamped to 16. A `UserWarning` is emitted when simulation inputs differ from the effective DDD design. Passing `n_per_cell` in `data_generator_kwargs` suppresses the effective-N rounding warning but not warnings for ignored parameters (`n_periods`, `treatment_period`, `treatment_fraction`). **Reference implementation(s):** - R: `pwr` package (general), `DeclareDesign` (simulation-based) diff --git a/tests/test_doc_snippets.py b/tests/test_doc_snippets.py index 68b9ab4..65f62c5 100644 --- a/tests/test_doc_snippets.py +++ b/tests/test_doc_snippets.py @@ -36,6 +36,7 @@ "api/visualization.rst", "api/honest_did.rst", "api/pretrends.rst", + "api/power.rst", "python_comparison.rst", "r_comparison.rst", ] @@ -62,12 +63,8 @@ ) # Heuristic: skip ``::`` blocks that look like shell or prose, not Python. -_SHELL_HINTS_RE = re.compile( - r"^\s*(\$\s|#!|pip\s+install|maturin\s)", re.MULTILINE -) -_PROSE_HINT_RE = re.compile( - r"^[A-Z][a-z]+ [a-z]+ [a-z]+", re.MULTILINE # English prose sentence -) +_SHELL_HINTS_RE = re.compile(r"^\s*(\$\s|#!|pip\s+install|maturin\s)", re.MULTILINE) +_PROSE_HINT_RE = re.compile(r"^[A-Z][a-z]+ [a-z]+ [a-z]+", re.MULTILINE) # English prose sentence def _extract_snippets(rst_path: Path) -> List[Tuple[int, str]]: @@ -140,6 +137,7 @@ def _collect_cases() -> List[Tuple[str, str, Optional[str]]]: _CASES = _collect_cases() + # --------------------------------------------------------------------------- # Shared namespace builder # --------------------------------------------------------------------------- @@ -167,9 +165,7 @@ def _build_namespace() -> dict: # Synthetic datasets that doc snippets commonly reference rng = np.random.default_rng(42) - staggered = diff_diff.generate_staggered_data( - n_units=60, n_periods=10, seed=42 - ) + staggered = diff_diff.generate_staggered_data(n_units=60, n_periods=10, seed=42) # Add alias columns that doc snippets expect # Use a simple time split (not unit-specific) so basic 2x2 DID works mid = staggered["period"].median() @@ -219,16 +215,18 @@ def _build_namespace() -> dict: # ------------------------------------------------------------------ def _mock_load_card_krueger(**kwargs): n = 40 - return pd.DataFrame({ - "store_id": range(n), - "state": ["NJ"] * (n // 2) + ["PA"] * (n // 2), - "chain": (["bk", "kfc", "roys", "wendys"] * 10)[:n], - "emp_pre": rng.normal(20, 5, n), - "emp_post": rng.normal(21, 5, n), - "wage_pre": rng.normal(4.5, 0.3, n), - "wage_post": rng.normal(5.0, 0.3, n), - "treated": [1] * (n // 2) + [0] * (n // 2), - }) + return pd.DataFrame( + { + "store_id": range(n), + "state": ["NJ"] * (n // 2) + ["PA"] * (n // 2), + "chain": (["bk", "kfc", "roys", "wendys"] * 10)[:n], + "emp_pre": rng.normal(20, 5, n), + "emp_post": rng.normal(21, 5, n), + "wage_pre": rng.normal(4.5, 0.3, n), + "wage_post": rng.normal(5.0, 0.3, n), + "treated": [1] * (n // 2) + [0] * (n // 2), + } + ) def _mock_load_castle_doctrine(**kwargs): states = [f"S{i:02d}" for i in range(10)] @@ -236,17 +234,18 @@ def _mock_load_castle_doctrine(**kwargs): rows = [(s, y) for s in states for y in years] n = len(rows) ft = [0] * 55 + [2005] * 22 + [2007] * 22 + [2009] * 11 - return pd.DataFrame({ - "state": [r[0] for r in rows], - "year": [r[1] for r in rows], - "first_treat": ft[:n], - "homicide_rate": rng.normal(5, 1, n), - "population": rng.integers(500000, 5000000, n), - "income": rng.normal(30000, 5000, n), - "treated": [1 if ft[i] and r[1] >= ft[i] else 0 - for i, r in enumerate(rows)][:n], - "cohort": ft[:n], - }) + return pd.DataFrame( + { + "state": [r[0] for r in rows], + "year": [r[1] for r in rows], + "first_treat": ft[:n], + "homicide_rate": rng.normal(5, 1, n), + "population": rng.integers(500000, 5000000, n), + "income": rng.normal(30000, 5000, n), + "treated": [1 if ft[i] and r[1] >= ft[i] else 0 for i, r in enumerate(rows)][:n], + "cohort": ft[:n], + } + ) def _mock_load_divorce_laws(**kwargs): states = [f"S{i:02d}" for i in range(10)] @@ -254,17 +253,18 @@ def _mock_load_divorce_laws(**kwargs): rows = [(s, y) for s in states for y in years] n = len(rows) ft = [0] * 125 + [1970] * 50 + [1975] * 50 + [1980] * 25 - return pd.DataFrame({ - "state": [r[0] for r in rows], - "year": [r[1] for r in rows], - "first_treat": ft[:n], - "divorce_rate": rng.normal(4, 1, n), - "female_lfp": rng.normal(50, 5, n), - "suicide_rate": rng.normal(5, 2, n), - "treated": [1 if ft[i] and r[1] >= ft[i] else 0 - for i, r in enumerate(rows)][:n], - "cohort": ft[:n], - }) + return pd.DataFrame( + { + "state": [r[0] for r in rows], + "year": [r[1] for r in rows], + "first_treat": ft[:n], + "divorce_rate": rng.normal(4, 1, n), + "female_lfp": rng.normal(50, 5, n), + "suicide_rate": rng.normal(5, 2, n), + "treated": [1 if ft[i] and r[1] >= ft[i] else 0 for i, r in enumerate(rows)][:n], + "cohort": ft[:n], + } + ) def _mock_load_mpdta(**kwargs): counties = list(range(1, 21)) @@ -272,14 +272,16 @@ def _mock_load_mpdta(**kwargs): rows = [(c, y) for c in counties for y in years] n = len(rows) ft = ([0] * 25 + [2004] * 25 + [2006] * 25 + [2007] * 25)[:n] - return pd.DataFrame({ - "countyreal": [r[0] for r in rows], - "year": [r[1] for r in rows], - "lpop": rng.normal(10, 1, n), - "lemp": rng.normal(8, 0.5, n), - "first_treat": ft, - "treat": [1 if f != 0 else 0 for f in ft], - }) + return pd.DataFrame( + { + "countyreal": [r[0] for r in rows], + "year": [r[1] for r in rows], + "lpop": rng.normal(10, 1, n), + "lemp": rng.normal(8, 0.5, n), + "first_treat": ft, + "treat": [1 if f != 0 else 0 for f in ft], + } + ) _dataset_dispatch = { "card_krueger": _mock_load_card_krueger, @@ -303,6 +305,7 @@ def _mock_list_datasets(): # Inject mocks into namespace so `from diff_diff.datasets import ...` works import types + mock_datasets_mod = types.ModuleType("diff_diff.datasets") mock_datasets_mod.load_card_krueger = _mock_load_card_krueger mock_datasets_mod.load_castle_doctrine = _mock_load_castle_doctrine @@ -311,6 +314,7 @@ def _mock_list_datasets(): mock_datasets_mod.load_dataset = _mock_load_dataset mock_datasets_mod.list_datasets = _mock_list_datasets import sys + sys.modules["diff_diff.datasets"] = mock_datasets_mod diff_diff.datasets = mock_datasets_mod @@ -333,6 +337,7 @@ def _restore_datasets_module(): """Restore diff_diff.datasets after each test to prevent mock leaking.""" import sys as _sys import diff_diff as _dd + orig_mod = _sys.modules.get("diff_diff.datasets") orig_attr = getattr(_dd, "datasets", None) yield diff --git a/tests/test_power.py b/tests/test_power.py index d012cda..5008109 100644 --- a/tests/test_power.py +++ b/tests/test_power.py @@ -1,20 +1,52 @@ """Tests for power analysis module.""" +import warnings + import numpy as np import pandas as pd import pytest from diff_diff import ( + TROP, + CallawaySantAnna, DifferenceInDifferences, + EfficientDiD, + ImputationDiD, + MultiPeriodDiD, PowerAnalysis, PowerResults, + SimulationMDEResults, SimulationPowerResults, + SimulationSampleSizeResults, + StackedDiD, + SunAbraham, + SyntheticDiD, + TripleDifference, + TwoStageDiD, + TwoWayFixedEffects, compute_mde, compute_power, compute_sample_size, + simulate_mde, simulate_power, + simulate_sample_size, +) +from diff_diff.power import ( + MAX_SAMPLE_SIZE, + _basic_dgp_kwargs, + _basic_fit_kwargs, + _ddd_dgp_kwargs, + _ddd_fit_kwargs, + _extract_multiperiod, + _extract_simple, + _extract_staggered, + _factor_dgp_kwargs, + _get_registry, + _staggered_dgp_kwargs, + _staggered_fit_kwargs, + _trop_fit_kwargs, ) -from diff_diff.power import MAX_SAMPLE_SIZE +from diff_diff.prep import generate_did_data class TestPowerAnalysis: @@ -80,12 +112,7 @@ def test_mde_decreases_with_sample_size(self): def test_power_calculation(self): """Test power calculation.""" pa = PowerAnalysis(alpha=0.05) - result = pa.power( - effect_size=0.5, - n_treated=50, - n_control=50, - sigma=1.0 - ) + result = pa.power(effect_size=0.5, n_treated=50, n_control=50, sigma=1.0) assert isinstance(result, PowerResults) assert 0 < result.power < 1 @@ -95,12 +122,8 @@ def test_power_increases_with_effect_size(self): """Test that power increases with effect size.""" pa = PowerAnalysis() - result_small = pa.power( - effect_size=0.2, n_treated=50, n_control=50, sigma=1.0 - ) - result_large = pa.power( - effect_size=0.8, n_treated=50, n_control=50, sigma=1.0 - ) + result_small = pa.power(effect_size=0.2, n_treated=50, n_control=50, sigma=1.0) + result_large = pa.power(effect_size=0.8, n_treated=50, n_control=50, sigma=1.0) assert result_large.power > result_small.power @@ -108,12 +131,8 @@ def test_power_increases_with_sample_size(self): """Test that power increases with sample size.""" pa = PowerAnalysis() - result_small = pa.power( - effect_size=0.5, n_treated=25, n_control=25, sigma=1.0 - ) - result_large = pa.power( - effect_size=0.5, n_treated=100, n_control=100, sigma=1.0 - ) + result_small = pa.power(effect_size=0.5, n_treated=25, n_control=25, sigma=1.0) + result_large = pa.power(effect_size=0.5, n_treated=100, n_control=100, sigma=1.0) assert result_large.power > result_small.power @@ -140,14 +159,8 @@ def test_panel_design(self): pa = PowerAnalysis(power=0.80) # Panel with multiple periods should have smaller MDE - result_2period = pa.mde( - n_treated=50, n_control=50, sigma=1.0, - n_pre=1, n_post=1 - ) - result_6period = pa.mde( - n_treated=50, n_control=50, sigma=1.0, - n_pre=3, n_post=3 - ) + result_2period = pa.mde(n_treated=50, n_control=50, sigma=1.0, n_pre=1, n_post=1) + result_6period = pa.mde(n_treated=50, n_control=50, sigma=1.0, n_pre=3, n_post=3) # More periods should reduce MDE (more data) assert result_6period.mde < result_2period.mde @@ -157,14 +170,8 @@ def test_icc_effect(self): """Test that intra-cluster correlation affects power.""" pa = PowerAnalysis(power=0.80) - result_no_icc = pa.mde( - n_treated=50, n_control=50, sigma=1.0, - n_pre=3, n_post=3, rho=0.0 - ) - result_with_icc = pa.mde( - n_treated=50, n_control=50, sigma=1.0, - n_pre=3, n_post=3, rho=0.5 - ) + result_no_icc = pa.mde(n_treated=50, n_control=50, sigma=1.0, n_pre=3, n_post=3, rho=0.0) + result_with_icc = pa.mde(n_treated=50, n_control=50, sigma=1.0, n_pre=3, n_post=3, rho=0.5) # Higher ICC should increase MDE (less independent information) assert result_with_icc.mde > result_no_icc.mde @@ -173,8 +180,7 @@ def test_power_curve(self): """Test power curve generation.""" pa = PowerAnalysis() curve = pa.power_curve( - n_treated=50, n_control=50, sigma=1.0, - effect_sizes=[0.1, 0.2, 0.3, 0.5, 0.7, 1.0] + n_treated=50, n_control=50, sigma=1.0, effect_sizes=[0.1, 0.2, 0.3, 0.5, 0.7, 1.0] ) assert isinstance(curve, pd.DataFrame) @@ -196,8 +202,7 @@ def test_sample_size_curve(self): """Test sample size curve generation.""" pa = PowerAnalysis() curve = pa.sample_size_curve( - effect_size=0.5, sigma=1.0, - sample_sizes=[20, 50, 100, 150, 200] + effect_size=0.5, sigma=1.0, sample_sizes=[20, 50, 100, 150, 200] ) assert isinstance(curve, pd.DataFrame) @@ -258,22 +263,14 @@ def test_one_sided_power_calculation(self): pa_two = PowerAnalysis(alternative="two-sided") # For positive effect, 'greater' should have higher power than two-sided - result_greater = pa_greater.power( - effect_size=0.5, n_treated=50, n_control=50, sigma=1.0 - ) - result_two = pa_two.power( - effect_size=0.5, n_treated=50, n_control=50, sigma=1.0 - ) + result_greater = pa_greater.power(effect_size=0.5, n_treated=50, n_control=50, sigma=1.0) + result_two = pa_two.power(effect_size=0.5, n_treated=50, n_control=50, sigma=1.0) assert result_greater.power > result_two.power # For negative effect, 'less' should have higher power - result_less = pa_less.power( - effect_size=-0.5, n_treated=50, n_control=50, sigma=1.0 - ) - result_two_neg = pa_two.power( - effect_size=-0.5, n_treated=50, n_control=50, sigma=1.0 - ) + result_less = pa_less.power(effect_size=-0.5, n_treated=50, n_control=50, sigma=1.0) + result_two_neg = pa_two.power(effect_size=-0.5, n_treated=50, n_control=50, sigma=1.0) assert result_less.power > result_two_neg.power @@ -282,12 +279,8 @@ def test_negative_effect_size(self): pa = PowerAnalysis() # Power should work the same for negative effects (symmetric) - result_pos = pa.power( - effect_size=0.5, n_treated=50, n_control=50, sigma=1.0 - ) - result_neg = pa.power( - effect_size=-0.5, n_treated=50, n_control=50, sigma=1.0 - ) + result_pos = pa.power(effect_size=0.5, n_treated=50, n_control=50, sigma=1.0) + result_neg = pa.power(effect_size=-0.5, n_treated=50, n_control=50, sigma=1.0) # Two-sided test should have same power for positive and negative effects assert abs(result_pos.power - result_neg.power) < 0.01 @@ -297,20 +290,14 @@ def test_extreme_icc(self): pa = PowerAnalysis(power=0.80) # Test with very high ICC (0.99) - result_extreme = pa.mde( - n_treated=50, n_control=50, sigma=1.0, - n_pre=5, n_post=5, rho=0.99 - ) + result_extreme = pa.mde(n_treated=50, n_control=50, sigma=1.0, n_pre=5, n_post=5, rho=0.99) - result_moderate = pa.mde( - n_treated=50, n_control=50, sigma=1.0, - n_pre=5, n_post=5, rho=0.5 - ) + result_moderate = pa.mde(n_treated=50, n_control=50, sigma=1.0, n_pre=5, n_post=5, rho=0.5) # Extreme ICC should have higher MDE (less independent info) assert result_extreme.mde > result_moderate.mde # MDE should still be finite and reasonable - assert result_extreme.mde < float('inf') + assert result_extreme.mde < float("inf") assert result_extreme.mde > 0 @@ -326,12 +313,7 @@ def test_compute_mde(self): def test_compute_power(self): """Test compute_power convenience function.""" - power = compute_power( - effect_size=0.5, - n_treated=50, - n_control=50, - sigma=1.0 - ) + power = compute_power(effect_size=0.5, n_treated=50, n_control=50, sigma=1.0) assert isinstance(power, float) assert 0 < power < 1 @@ -353,12 +335,8 @@ def test_convenience_functions_consistency(self): assert mde_class == mde_func # Power - power_class = pa.power( - effect_size=0.5, n_treated=50, n_control=50, sigma=1.0 - ).power - power_func = compute_power( - effect_size=0.5, n_treated=50, n_control=50, sigma=1.0 - ) + power_class = pa.power(effect_size=0.5, n_treated=50, n_control=50, sigma=1.0).power + power_func = compute_power(effect_size=0.5, n_treated=50, n_control=50, sigma=1.0) assert power_class == power_func # Sample size @@ -563,13 +541,16 @@ class Result: return Result() # Test with low failure rate (should not warn) + from diff_diff.prep import generate_did_data + estimator = FailingEstimator(fail_rate=0.0) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - results = simulate_power( + simulate_power( estimator=estimator, n_simulations=10, progress=False, + data_generator=generate_did_data, ) # Should have completed successfully without warning assert len([x for x in w if "simulations" in str(x.message)]) == 0 @@ -583,10 +564,12 @@ def test_plot_power_curve_dataframe(self): pytest.importorskip("matplotlib") from diff_diff.visualization import plot_power_curve - df = pd.DataFrame({ - "effect_size": [0.1, 0.2, 0.3, 0.5, 0.7, 1.0], - "power": [0.1, 0.2, 0.4, 0.7, 0.9, 0.99] - }) + df = pd.DataFrame( + { + "effect_size": [0.1, 0.2, 0.3, 0.5, 0.7, 1.0], + "power": [0.1, 0.2, 0.4, 0.7, 0.9, 0.99], + } + ) ax = plot_power_curve(df, show=False) assert ax is not None @@ -597,10 +580,7 @@ def test_plot_power_curve_manual_data(self): from diff_diff.visualization import plot_power_curve ax = plot_power_curve( - effect_sizes=[0.1, 0.2, 0.3, 0.5], - powers=[0.1, 0.3, 0.6, 0.9], - mde=0.25, - show=False + effect_sizes=[0.1, 0.2, 0.3, 0.5], powers=[0.1, 0.3, 0.6, 0.9], mde=0.25, show=False ) assert ax is not None @@ -609,10 +589,9 @@ def test_plot_power_curve_sample_size(self): pytest.importorskip("matplotlib") from diff_diff.visualization import plot_power_curve - df = pd.DataFrame({ - "sample_size": [20, 50, 100, 150, 200], - "power": [0.2, 0.5, 0.8, 0.9, 0.95] - }) + df = pd.DataFrame( + {"sample_size": [20, 50, 100, 150, 200], "power": [0.2, 0.5, 0.8, 0.9, 0.95]} + ) ax = plot_power_curve(df, show=False) assert ax is not None @@ -626,10 +605,7 @@ def test_plot_validates_input(self): plot_power_curve(show=False) # No data provided with pytest.raises(ValueError): - plot_power_curve( - effect_sizes=[1, 2, 3], - show=False # Missing powers - ) + plot_power_curve(effect_sizes=[1, 2, 3], show=False) # Missing powers class TestEdgeCases: @@ -648,15 +624,11 @@ def test_extreme_power_values(self): pa = PowerAnalysis() # Zero effect should give ~alpha power - result_zero = pa.power( - effect_size=0.0, n_treated=50, n_control=50, sigma=1.0 - ) + result_zero = pa.power(effect_size=0.0, n_treated=50, n_control=50, sigma=1.0) assert result_zero.power < 0.10 # Huge effect should give ~1.0 power - result_huge = pa.power( - effect_size=100.0, n_treated=50, n_control=50, sigma=1.0 - ) + result_huge = pa.power(effect_size=100.0, n_treated=50, n_control=50, sigma=1.0) assert result_huge.power > 0.99 def test_unbalanced_design(self): @@ -689,3 +661,1115 @@ def test_max_sample_size_constant(self): # Verify constant is the expected value assert MAX_SAMPLE_SIZE == 2**31 - 1 + + +# --------------------------------------------------------------------------- +# Registry tests +# --------------------------------------------------------------------------- + + +class TestEstimatorRegistry: + """Tests for the estimator registry.""" + + EXPECTED_ESTIMATORS = [ + "DifferenceInDifferences", + "MultiPeriodDiD", + "CallawaySantAnna", + "SunAbraham", + "ImputationDiD", + "TwoStageDiD", + "StackedDiD", + "EfficientDiD", + "TROP", + "SyntheticDiD", + "TripleDifference", + ] + + def test_all_estimators_registered(self): + """Every supported estimator has a registry entry.""" + registry = _get_registry() + for name in self.EXPECTED_ESTIMATORS: + assert name in registry, f"{name} missing from registry" + + def test_bacon_excluded(self): + """BaconDecomposition is diagnostic-only and should not be in registry.""" + registry = _get_registry() + assert "BaconDecomposition" not in registry + + def test_dgp_kwargs_builders_return_dicts(self): + """Each DGP kwargs builder returns a non-empty dict.""" + params = dict( + n_units=50, + n_periods=4, + treatment_effect=5.0, + treatment_fraction=0.5, + treatment_period=2, + sigma=1.0, + ) + for builder in [ + _basic_dgp_kwargs, + _staggered_dgp_kwargs, + _factor_dgp_kwargs, + _ddd_dgp_kwargs, + ]: + result = builder(**params) + assert isinstance(result, dict) + assert len(result) > 0 + + def test_fit_kwargs_builders_return_dicts(self): + """Each fit kwargs builder returns a dict with 'outcome'.""" + dummy_df = pd.DataFrame({"period": [0, 1, 2, 3]}) + for builder in [ + _basic_fit_kwargs, + _staggered_fit_kwargs, + _ddd_fit_kwargs, + _trop_fit_kwargs, + ]: + result = builder(dummy_df, 50, 4, 2) + assert isinstance(result, dict) + assert "outcome" in result + + def test_extract_simple(self): + """_extract_simple extracts from .att/.se/.p_value/.conf_int.""" + + class MockResult: + att = 3.0 + se = 0.5 + p_value = 0.01 + conf_int = (2.0, 4.0) + + att, se, p, ci = _extract_simple(MockResult()) + assert att == 3.0 + assert se == 0.5 + assert p == 0.01 + assert ci == (2.0, 4.0) + + def test_extract_multiperiod(self): + """_extract_multiperiod extracts from avg_* attributes.""" + + class MockResult: + avg_att = 4.0 + avg_se = 0.6 + avg_p_value = 0.001 + avg_conf_int = (2.8, 5.2) + + att, se, p, ci = _extract_multiperiod(MockResult()) + assert att == 4.0 + assert se == 0.6 + assert p == 0.001 + assert ci == (2.8, 5.2) + + def test_extract_staggered_analytical(self): + """_extract_staggered handles analytical result objects.""" + + class MockResult: + overall_att = 2.0 + overall_se = 0.3 + overall_p_value = 0.02 + overall_conf_int = (1.4, 2.6) + + att, se, p, ci = _extract_staggered(MockResult()) + assert att == 2.0 + assert se == 0.3 + assert p == 0.02 + assert ci == (1.4, 2.6) + + def test_extract_staggered_bootstrap_fallback(self): + """_extract_staggered falls back to bootstrap attribute names.""" + + class MockBootstrapResult: + overall_att = 2.0 + overall_att_se = 0.4 + overall_att_p_value = 0.03 + overall_att_ci = (1.2, 2.8) + + att, se, p, ci = _extract_staggered(MockBootstrapResult()) + assert att == 2.0 + assert se == 0.4 + assert p == 0.03 + assert ci == (1.2, 2.8) + + def test_continuous_did_not_in_registry(self): + """ContinuousDiD is not in registry and raises without custom data_generator.""" + from diff_diff import ContinuousDiD + + registry = _get_registry() + assert "ContinuousDiD" not in registry + + with pytest.raises(ValueError, match="not in registry"): + simulate_power( + ContinuousDiD(), + n_simulations=5, + progress=False, + ) + + def test_twfe_in_registry(self): + """TwoWayFixedEffects is in the registry.""" + registry = _get_registry() + assert "TwoWayFixedEffects" in registry + + def test_unknown_estimator_raises_without_data_generator(self): + """Unknown estimator without data_generator raises ValueError.""" + + class UnknownEstimator: + pass + + with pytest.raises(ValueError, match="not in registry"): + simulate_power( + UnknownEstimator(), + n_simulations=5, + progress=False, + ) + + +# --------------------------------------------------------------------------- +# Estimator coverage tests for simulate_power +# --------------------------------------------------------------------------- + + +class TestEstimatorCoverage: + """Verify simulate_power works for each registered estimator.""" + + def _assert_valid_result(self, result, expected_name): + assert 0 <= result.power <= 1 + assert result.estimator_name == expected_name + assert np.isfinite(result.mean_estimate) + assert result.n_simulations > 0 + assert result.coverage >= 0 + + def test_did(self): + result = simulate_power( + DifferenceInDifferences(), + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "DifferenceInDifferences") + + def test_multiperiod(self): + result = simulate_power( + MultiPeriodDiD(), + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "MultiPeriodDiD") + + def test_callaway_santanna(self): + result = simulate_power( + CallawaySantAnna(), + n_units=60, + n_periods=6, + treatment_period=3, + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "CallawaySantAnna") + + def test_sun_abraham(self): + result = simulate_power( + SunAbraham(), + n_units=60, + n_periods=6, + treatment_period=3, + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "SunAbraham") + + def test_imputation_did(self): + result = simulate_power( + ImputationDiD(), + n_units=60, + n_periods=6, + treatment_period=3, + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "ImputationDiD") + + def test_two_stage_did(self): + result = simulate_power( + TwoStageDiD(), + n_units=60, + n_periods=6, + treatment_period=3, + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "TwoStageDiD") + + def test_stacked_did(self): + result = simulate_power( + StackedDiD(), + n_units=60, + n_periods=6, + treatment_period=3, + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "StackedDiD") + + def test_efficient_did(self): + result = simulate_power( + EfficientDiD(), + n_units=60, + n_periods=6, + treatment_period=3, + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "EfficientDiD") + + def test_triple_difference(self): + result = simulate_power( + TripleDifference(), + n_units=80, + n_periods=2, + treatment_period=1, + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "TripleDifference") + + def test_ddd_warns_ignored_params(self): + """TripleDifference warns when simulation params don't match DDD design.""" + with pytest.warns(UserWarning, match="n_periods=6 is ignored"): + simulate_power( + TripleDifference(), + n_units=80, + n_periods=6, + treatment_period=3, + treatment_fraction=0.3, + n_simulations=2, + seed=42, + progress=False, + ) + + def test_ddd_warns_nonaligned_n_units(self): + """TripleDifference warns when n_units doesn't map cleanly to 8 cells.""" + with pytest.warns(UserWarning, match="effective sample size is 64"): + simulate_power( + TripleDifference(), + n_units=65, + n_periods=2, + treatment_period=1, + n_simulations=2, + seed=42, + progress=False, + ) + + def test_ddd_small_n_units_warns(self): + """TripleDifference warns when n_units < 16 (clamped to 16).""" + with pytest.warns(UserWarning, match="effective sample size is 16"): + simulate_power( + TripleDifference(), + n_units=10, + n_periods=2, + treatment_period=1, + n_simulations=2, + seed=42, + progress=False, + ) + + def test_ddd_no_warn_aligned(self): + """No warning when n_units is a multiple of 8 and defaults match DDD.""" + with warnings.catch_warnings(): + warnings.simplefilter("error") + simulate_power( + TripleDifference(), + n_units=80, + n_periods=2, + treatment_period=1, + treatment_fraction=0.5, + n_simulations=2, + seed=42, + progress=False, + ) + + def test_ddd_no_warn_custom_dgp(self): + """Custom data_generator bypasses the DDD compat check.""" + + def custom_dgp(**kwargs): + from diff_diff.prep_dgp import generate_ddd_data + + return generate_ddd_data(n_per_cell=10, seed=kwargs.get("seed")) + + with warnings.catch_warnings(): + warnings.simplefilter("error") + simulate_power( + TripleDifference(), + n_units=65, + n_periods=6, + data_generator=custom_dgp, + estimator_kwargs=dict( + outcome="outcome", + group="group", + partition="partition", + time="time", + ), + n_simulations=2, + seed=42, + progress=False, + ) + + def test_ddd_no_warn_n_per_cell_override(self): + """n_per_cell override suppresses rounding warning but not ignored-param warnings.""" + with pytest.warns(UserWarning, match="n_periods=6 is ignored"): + simulate_power( + TripleDifference(), + n_units=80, + n_periods=6, + treatment_period=1, + data_generator_kwargs=dict(n_per_cell=10), + n_simulations=2, + seed=42, + progress=False, + ) + + def test_ddd_n_per_cell_suppresses_rounding(self): + """n_per_cell override suppresses effective-N rounding warning.""" + with warnings.catch_warnings(): + warnings.simplefilter("error") + simulate_power( + TripleDifference(), + n_units=80, + n_periods=2, + treatment_period=1, + data_generator_kwargs=dict(n_per_cell=10), + n_simulations=2, + seed=42, + progress=False, + ) + + @pytest.mark.slow + def test_ddd_mde(self): + """simulate_mde works for TripleDifference.""" + result = simulate_mde( + TripleDifference(), + n_units=80, + n_periods=2, + treatment_period=1, + n_simulations=5, + effect_range=(0.5, 5.0), + seed=42, + progress=False, + ) + assert isinstance(result, SimulationMDEResults) + assert result.mde > 0 + + @pytest.mark.slow + def test_ddd_sample_size(self): + """simulate_sample_size works for TripleDifference.""" + result = simulate_sample_size( + TripleDifference(), + n_periods=2, + treatment_period=1, + n_simulations=5, + n_range=(64, 200), + seed=42, + progress=False, + ) + assert isinstance(result, SimulationSampleSizeResults) + assert result.required_n > 0 + + @pytest.mark.slow + def test_trop(self): + result = simulate_power( + TROP(), + n_units=50, + n_periods=6, + treatment_period=3, + treatment_fraction=0.3, + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "TROP") + + @pytest.mark.slow + def test_synthetic_did(self): + result = simulate_power( + SyntheticDiD(), + n_units=50, + n_periods=6, + treatment_period=3, + treatment_fraction=0.3, + n_simulations=10, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "SyntheticDiD") + + def test_sdid_placebo_rejects_high_fraction(self): + """SyntheticDiD placebo variance raises when n_control <= n_treated.""" + with pytest.raises(ValueError, match="placebo variance requires more control"): + simulate_power( + SyntheticDiD(), + treatment_fraction=0.5, + n_simulations=5, + seed=42, + progress=False, + ) + + @pytest.mark.slow + def test_sdid_placebo_boundary_fraction(self): + """treatment_fraction=0.49 with 50 units gives n_control=26 > n_treated=24.""" + result = simulate_power( + SyntheticDiD(), + treatment_fraction=0.49, + n_units=50, + n_periods=6, + treatment_period=3, + n_simulations=5, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "SyntheticDiD") + + @pytest.mark.slow + def test_sdid_bootstrap_allows_high_fraction(self): + """Bootstrap variance method bypasses the placebo constraint.""" + result = simulate_power( + SyntheticDiD(variance_method="bootstrap"), + treatment_fraction=0.5, + n_units=50, + n_periods=6, + treatment_period=3, + n_simulations=5, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "SyntheticDiD") + assert result.power >= 0 + + def test_sdid_mde_rejects_high_fraction(self): + """simulate_mde raises for SyntheticDiD placebo with high treatment_fraction.""" + with pytest.raises(ValueError, match="placebo variance requires more control"): + simulate_mde( + SyntheticDiD(), + treatment_fraction=0.5, + n_simulations=5, + seed=42, + progress=False, + ) + + def test_sdid_sample_size_rejects_high_fraction(self): + """simulate_sample_size raises for SyntheticDiD placebo with high fraction.""" + with pytest.raises(ValueError, match="placebo variance requires more control"): + simulate_sample_size( + SyntheticDiD(), + treatment_fraction=0.5, + n_simulations=5, + seed=42, + progress=False, + ) + + def test_sdid_placebo_rejects_n_treated_override(self): + """SDID placebo raises when data_generator_kwargs overrides n_treated.""" + with pytest.raises(ValueError, match="placebo variance requires more control"): + simulate_power( + SyntheticDiD(), + n_units=50, + treatment_fraction=0.3, + data_generator_kwargs=dict(n_treated=30), + n_simulations=5, + seed=42, + progress=False, + ) + + def test_sdid_mde_rejects_n_treated_override(self): + """simulate_mde raises when kwargs override makes n_control <= n_treated.""" + with pytest.raises(ValueError, match="placebo variance requires more control"): + simulate_mde( + SyntheticDiD(), + n_units=50, + treatment_fraction=0.3, + data_generator_kwargs=dict(n_treated=30), + n_simulations=5, + seed=42, + progress=False, + ) + + def test_sdid_sample_size_rejects_n_treated_override(self): + """simulate_sample_size raises when kwargs override is infeasible.""" + with pytest.raises(ValueError, match="placebo variance requires more control"): + simulate_sample_size( + SyntheticDiD(), + treatment_fraction=0.3, + data_generator_kwargs=dict(n_treated=30), + n_range=(50, 100), + n_simulations=5, + seed=42, + progress=False, + ) + + @pytest.mark.slow + def test_sdid_mde(self): + """simulate_mde works for SyntheticDiD with valid treatment_fraction.""" + result = simulate_mde( + SyntheticDiD(), + treatment_fraction=0.3, + n_units=50, + n_periods=6, + treatment_period=3, + n_simulations=5, + effect_range=(0.5, 3.0), + seed=42, + progress=False, + ) + assert isinstance(result, SimulationMDEResults) + assert result.mde > 0 + + @pytest.mark.slow + def test_sdid_sample_size(self): + """simulate_sample_size works for SyntheticDiD with valid fraction.""" + result = simulate_sample_size( + SyntheticDiD(), + treatment_fraction=0.3, + n_periods=6, + treatment_period=3, + n_simulations=5, + n_range=(30, 80), + seed=42, + progress=False, + ) + assert isinstance(result, SimulationSampleSizeResults) + assert result.required_n > 0 + + @pytest.mark.slow + def test_twfe(self): + result = simulate_power( + TwoWayFixedEffects(), + n_simulations=5, + seed=42, + progress=False, + ) + self._assert_valid_result(result, "TwoWayFixedEffects") + + @pytest.mark.slow + def test_twfe_mde(self): + result = simulate_mde( + TwoWayFixedEffects(), + n_simulations=5, + effect_range=(0.5, 5.0), + seed=42, + progress=False, + ) + assert isinstance(result, SimulationMDEResults) + assert result.mde > 0 + + @pytest.mark.slow + def test_twfe_sample_size(self): + result = simulate_sample_size( + TwoWayFixedEffects(), + n_simulations=5, + n_range=(20, 100), + seed=42, + progress=False, + ) + assert isinstance(result, SimulationSampleSizeResults) + assert result.required_n > 0 + + @pytest.mark.slow + def test_custom_fallback_unregistered_estimator(self): + """Unregistered estimator works with custom data_generator and estimator_kwargs.""" + + class _UnregisteredEstimator: + """Unregistered wrapper for testing custom fallback.""" + + def __init__(self): + self._inner = DifferenceInDifferences() + + def fit(self, data, **kwargs): + return self._inner.fit(data, **kwargs) + + result = simulate_power( + _UnregisteredEstimator(), + data_generator=generate_did_data, + estimator_kwargs=dict(outcome="outcome", treatment="treated", time="post"), + n_simulations=5, + seed=42, + progress=False, + ) + assert 0 <= result.power <= 1 + assert result.n_simulations > 0 + + def test_custom_fallback_missing_kwargs_raises(self): + """Unregistered estimator with no estimator_kwargs fails on fit.""" + + class _UnregisteredEstimator: + def __init__(self): + self._inner = DifferenceInDifferences() + + def fit(self, data, **kwargs): + return self._inner.fit(data, **kwargs) + + with pytest.raises((ValueError, TypeError, RuntimeError)): + simulate_power( + _UnregisteredEstimator(), + data_generator=generate_did_data, + n_simulations=5, + seed=42, + progress=False, + ) + + @pytest.mark.slow + def test_custom_result_extractor(self): + """Custom result_extractor works for unregistered estimator.""" + + class _UnregisteredEstimator: + def __init__(self): + self._inner = DifferenceInDifferences() + + def fit(self, data, **kwargs): + return self._inner.fit(data, **kwargs) + + def _custom_extractor(result): + return (result.att, result.se, result.p_value, result.conf_int) + + result = simulate_power( + _UnregisteredEstimator(), + data_generator=generate_did_data, + estimator_kwargs=dict(outcome="outcome", treatment="treated", time="post"), + result_extractor=_custom_extractor, + n_simulations=5, + seed=42, + progress=False, + ) + assert 0 <= result.power <= 1 + assert result.n_simulations > 0 + + @pytest.mark.slow + def test_custom_result_extractor_mde_forwarding(self): + """result_extractor forwards correctly through simulate_mde.""" + + class _UnregisteredEstimator: + def __init__(self): + self._inner = DifferenceInDifferences() + + def fit(self, data, **kwargs): + return self._inner.fit(data, **kwargs) + + def _custom_extractor(result): + return (result.att, result.se, result.p_value, result.conf_int) + + result = simulate_mde( + _UnregisteredEstimator(), + data_generator=generate_did_data, + estimator_kwargs=dict(outcome="outcome", treatment="treated", time="post"), + result_extractor=_custom_extractor, + n_simulations=5, + effect_range=(0.5, 5.0), + seed=42, + progress=False, + ) + assert isinstance(result, SimulationMDEResults) + assert result.mde > 0 + + # -- Staggered DGP compatibility warnings -- + + def test_staggered_dgp_warns_not_yet_treated(self): + """Auto DGP warns when CS has control_group='not_yet_treated'.""" + with pytest.warns(UserWarning, match="not_yet_treated"): + simulate_power( + CallawaySantAnna(control_group="not_yet_treated"), + n_simulations=3, + seed=42, + progress=False, + ) + + def test_staggered_dgp_warns_anticipation(self): + """Auto DGP warns when staggered estimator has anticipation > 0.""" + with pytest.warns(UserWarning, match="anticipation=1"): + simulate_power( + CallawaySantAnna(anticipation=1), + n_simulations=3, + seed=42, + progress=False, + ) + + def test_staggered_dgp_warns_strict_clean_control(self): + """Auto DGP warns when StackedDiD has clean_control='strict'.""" + with pytest.warns(UserWarning, match="strict"): + simulate_power( + StackedDiD(clean_control="strict"), + n_simulations=3, + seed=42, + progress=False, + ) + + def test_staggered_dgp_no_warn_custom_dgp_bypasses_check(self): + """Custom data_generator bypasses DGP compat check entirely.""" + from diff_diff.prep import generate_staggered_data + + def _custom_staggered(**kwargs): + # Adapt simulate_power's standard kwargs to generate_staggered_data + return generate_staggered_data( + n_units=kwargs["n_units"], + n_periods=kwargs["n_periods"], + treatment_effect=kwargs["treatment_effect"], + cohort_periods=[2, 4], + never_treated_frac=0.0, + noise_sd=kwargs["noise_sd"], + seed=kwargs["seed"], + ) + + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + simulate_power( + CallawaySantAnna(control_group="not_yet_treated"), + data_generator=_custom_staggered, + n_periods=6, + treatment_period=3, + estimator_kwargs=dict( + outcome="outcome", + unit="unit", + time="period", + first_treat="first_treat", + ), + n_simulations=3, + seed=42, + progress=False, + ) + + def test_staggered_dgp_no_warn_with_dgp_kwargs_override(self): + """data_generator_kwargs with cohort_periods suppresses warning.""" + with warnings.catch_warnings(): + warnings.simplefilter("error", UserWarning) + result = simulate_power( + CallawaySantAnna(control_group="not_yet_treated"), + n_periods=6, + treatment_period=3, + data_generator_kwargs=dict(cohort_periods=[2, 4], never_treated_frac=0.0), + n_simulations=3, + seed=42, + progress=False, + ) + assert 0 <= result.power <= 1 + + @pytest.mark.slow + def test_cs_not_yet_treated_with_matching_dgp(self): + """CS with control_group='not_yet_treated' and multi-cohort DGP.""" + result = simulate_power( + CallawaySantAnna(control_group="not_yet_treated"), + n_units=60, + n_periods=6, + treatment_period=3, + data_generator_kwargs=dict(cohort_periods=[2, 4], never_treated_frac=0.0), + n_simulations=10, + seed=42, + progress=False, + ) + assert 0 <= result.power <= 1 + assert result.n_simulations > 0 + + @pytest.mark.slow + def test_stacked_did_strict_with_matching_dgp(self): + """StackedDiD with clean_control='strict' and multi-cohort DGP.""" + result = simulate_power( + StackedDiD(clean_control="strict", kappa_pre=1, kappa_post=1), + n_units=80, + n_periods=8, + treatment_period=4, + data_generator_kwargs=dict(cohort_periods=[3, 5]), + n_simulations=10, + seed=42, + progress=False, + ) + assert 0 <= result.power <= 1 + assert result.n_simulations > 0 + + +# --------------------------------------------------------------------------- +# simulate_mde tests +# --------------------------------------------------------------------------- + + +class TestSimulateMDE: + """Tests for simulate_mde function.""" + + def test_basic_mde(self): + """MDE found for DiD, power at MDE close to target.""" + result = simulate_mde( + DifferenceInDifferences(), + n_units=100, + sigma=1.0, + n_simulations=50, + seed=42, + progress=False, + ) + assert isinstance(result, SimulationMDEResults) + assert result.mde > 0 + assert result.power_at_mde >= result.target_power - 0.10 + + def test_result_methods(self): + """summary(), to_dict(), to_dataframe() work.""" + result = simulate_mde( + DifferenceInDifferences(), + n_simulations=30, + seed=42, + progress=False, + ) + summary = result.summary() + assert "MDE" in summary or "Minimum" in summary + + d = result.to_dict() + assert "mde" in d + assert "estimator_name" in d + + df = result.to_dataframe() + assert isinstance(df, pd.DataFrame) + assert len(df) == 1 + + def test_monotonicity_in_search_path(self): + """The search path records plausible effect_size / power pairs.""" + result = simulate_mde( + DifferenceInDifferences(), + n_simulations=50, + seed=42, + progress=False, + ) + assert len(result.search_path) > 0 + for step in result.search_path: + assert "effect_size" in step + assert "power" in step + assert 0 <= step["power"] <= 1 + + def test_convergence_within_max_steps(self): + """Search terminates within max_steps.""" + result = simulate_mde( + DifferenceInDifferences(), + n_simulations=30, + max_steps=10, + seed=42, + progress=False, + ) + # n_steps includes bracketing steps + bisection + assert result.n_steps <= 25 # generous bound + + def test_custom_data_generator(self): + """Works with user-provided DGP.""" + from diff_diff.prep import generate_did_data + + result = simulate_mde( + DifferenceInDifferences(), + n_simulations=30, + seed=42, + progress=False, + data_generator=generate_did_data, + ) + assert result.mde > 0 + + def test_small_sigma_gives_small_mde(self): + """Small noise → small MDE.""" + result = simulate_mde( + DifferenceInDifferences(), + n_units=100, + sigma=0.1, + n_simulations=50, + seed=42, + progress=False, + ) + assert result.mde < 1.0 + + def test_large_sigma_gives_large_mde(self): + """Large noise → large MDE.""" + result = simulate_mde( + DifferenceInDifferences(), + n_units=50, + sigma=10.0, + n_simulations=50, + seed=42, + progress=False, + ) + assert result.mde > 1.0 + + def test_explicit_effect_range(self): + """Explicit effect_range evaluates endpoints and populates search_path.""" + result = simulate_mde( + DifferenceInDifferences(), + n_units=100, + sigma=1.0, + n_simulations=30, + effect_range=(0.5, 5.0), + seed=42, + progress=False, + ) + assert result.mde > 0 + assert result.power_at_mde > 0 + assert len(result.search_path) > 0 + + def test_unbracketed_effect_range_warns(self): + """Tiny effect_range that cannot bracket target power warns.""" + with pytest.warns(UserWarning, match="not bracketed"): + simulate_mde( + DifferenceInDifferences(), + n_units=50, + sigma=10.0, + n_simulations=30, + effect_range=(0.0, 0.001), + seed=42, + progress=False, + ) + + +# --------------------------------------------------------------------------- +# simulate_sample_size tests +# --------------------------------------------------------------------------- + + +class TestSimulateSampleSize: + """Tests for simulate_sample_size function.""" + + def test_basic_sample_size(self): + """Required N found for DiD, power at N close to target.""" + result = simulate_sample_size( + DifferenceInDifferences(), + treatment_effect=5.0, + sigma=1.0, + n_simulations=50, + seed=42, + progress=False, + ) + assert isinstance(result, SimulationSampleSizeResults) + assert result.required_n > 0 + assert result.power_at_n >= result.target_power - 0.10 + + def test_result_methods(self): + """summary(), to_dict(), to_dataframe() work.""" + result = simulate_sample_size( + DifferenceInDifferences(), + treatment_effect=5.0, + n_simulations=30, + seed=42, + progress=False, + ) + summary = result.summary() + assert "Sample Size" in summary or "Required" in summary + + d = result.to_dict() + assert "required_n" in d + assert "estimator_name" in d + + df = result.to_dataframe() + assert isinstance(df, pd.DataFrame) + assert len(df) == 1 + + def test_monotonicity_in_search_path(self): + """The search path records plausible n_units / power pairs.""" + result = simulate_sample_size( + DifferenceInDifferences(), + treatment_effect=5.0, + n_simulations=50, + seed=42, + progress=False, + ) + assert len(result.search_path) > 0 + for step in result.search_path: + assert "n_units" in step + assert "power" in step + assert 0 <= step["power"] <= 1 + + def test_custom_data_generator(self): + """Works with user-provided DGP.""" + from diff_diff.prep import generate_did_data + + result = simulate_sample_size( + DifferenceInDifferences(), + treatment_effect=5.0, + n_simulations=30, + seed=42, + progress=False, + data_generator=generate_did_data, + ) + assert result.required_n > 0 + + def test_large_effect_gives_small_n(self): + """Large effect → small N.""" + result = simulate_sample_size( + DifferenceInDifferences(), + treatment_effect=20.0, + sigma=1.0, + n_simulations=50, + seed=42, + progress=False, + ) + assert result.required_n <= 100 + + def test_small_effect_gives_large_n(self): + """Small effect → large N.""" + result = simulate_sample_size( + DifferenceInDifferences(), + treatment_effect=0.5, + sigma=5.0, + n_simulations=50, + seed=42, + progress=False, + ) + assert result.required_n >= 50 + + def test_explicit_n_range(self): + """Explicit n_range evaluates endpoints and populates search_path.""" + result = simulate_sample_size( + DifferenceInDifferences(), + treatment_effect=5.0, + sigma=1.0, + n_simulations=30, + n_range=(20, 200), + seed=42, + progress=False, + ) + assert result.required_n > 0 + assert result.power_at_n > 0 + assert len(result.search_path) > 0 + + def test_unbracketed_n_range_warns(self): + """Tiny n_range that cannot bracket target power warns.""" + with pytest.warns(UserWarning, match="not bracketed"): + simulate_sample_size( + DifferenceInDifferences(), + treatment_effect=0.01, + sigma=10.0, + n_simulations=30, + n_range=(20, 22), + seed=42, + progress=False, + ) + + def test_lo_already_sufficient_explicit(self): + """When lo already meets power, return lo immediately with warning.""" + with pytest.warns(UserWarning, match="Lower bound already achieves"): + result = simulate_sample_size( + DifferenceInDifferences(), + treatment_effect=50.0, + sigma=0.1, + n_simulations=50, + n_range=(20, 200), + seed=42, + progress=False, + ) + assert result.required_n == 20 + assert result.power_at_n >= 0.80 + + def test_lo_already_sufficient_auto(self): + """Auto-bracket warns and returns min_n when effect overwhelmingly large.""" + with pytest.warns(UserWarning, match="registry floor"): + result = simulate_sample_size( + DifferenceInDifferences(), + treatment_effect=50.0, + sigma=0.1, + n_simulations=50, + seed=42, + progress=False, + ) + # min_n for DifferenceInDifferences is 20 + assert result.required_n == 20 + assert result.power_at_n >= 0.80