From e43266054ca34ec027f0df92e93c60d3eb59ff15 Mon Sep 17 00:00:00 2001 From: igerber Date: Wed, 18 Mar 2026 14:57:01 -0400 Subject: [PATCH 1/4] Rename TROP method="twostep" to method="local" with deprecation Rename the per-observation TROP method from "twostep" to "local", forming a natural local/global pair. Both "twostep" and "joint" are now deprecated aliases (removal in v3.0). Also rename internal _joint_* methods to _global_* and Rust exports for consistency. Co-Authored-By: Claude Opus 4.6 (1M context) --- CHANGELOG.md | 4 + CLAUDE.md | 2 +- README.md | 6 +- diff_diff/_backend.py | 32 +++--- diff_diff/trop.py | 123 +++++++++++++---------- docs/api/trop.rst | 19 ++-- docs/choosing_estimator.rst | 2 +- docs/methodology/REGISTRY.md | 31 +++--- docs/troubleshooting.rst | 20 ++++ docs/tutorials/10_trop.ipynb | 6 +- rust/src/lib.rs | 8 +- rust/src/trop.rs | 4 +- tests/test_rust_backend.py | 112 ++++++++++----------- tests/test_trop.py | 182 +++++++++++++++++++---------------- 14 files changed, 305 insertions(+), 246 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 98c0844..27d18e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed +- Rename TROP `method="twostep"` to `method="local"`; `"twostep"` deprecated, removal in v3.0 +- Rename internal TROP `_joint_*` methods to `_global_*` for consistency + ## [2.7.1] - 2026-03-15 ### Changed diff --git a/CLAUDE.md b/CLAUDE.md index 8fc7397..29b1407 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -122,7 +122,7 @@ category (`Methodology/Correctness`, `Performance`, or `Testing/Docs`): `threshold = 0.40 if n_boot < 100 else 0.15`. - **`assert_nan_inference()`** from conftest.py: Use to validate ALL inference fields are NaN-consistent. Don't check individual fields separately. 
-- **Slow tests**: TROP methodology/joint-method tests, Sun-Abraham bootstrap, and +- **Slow tests**: TROP methodology/global-method tests, Sun-Abraham bootstrap, and TROP-parity tests are marked `@pytest.mark.slow` and excluded by default via `addopts`. `test_trop.py` uses per-class markers (not file-level) so that validation, API, and solver tests still run in the pure Python CI fallback. Run `pytest -m ''` to include diff --git a/README.md b/README.md index ae28d3c..3467264 100644 --- a/README.md +++ b/README.md @@ -1517,7 +1517,7 @@ trop = TROP( ```python TROP( - method='twostep', # Estimation method: 'twostep' (default) or 'joint' + method='local', # Estimation method: 'local' (default) or 'global' lambda_time_grid=None, # Time decay grid (default: [0, 0.1, 0.5, 1, 2, 5]) lambda_unit_grid=None, # Unit distance grid (default: [0, 0.1, 0.5, 1, 2, 5]) lambda_nn_grid=None, # Nuclear norm grid (default: [0, 0.01, 0.1, 1, 10]) @@ -1530,8 +1530,8 @@ TROP( ``` **Estimation methods:** -- `'twostep'` (default): Per-observation model fitting following Algorithm 2 of the paper. Computes observation-specific weights and fits a model for each treated observation, then averages the individual treatment effects. More flexible but computationally intensive. -- `'joint'`: Joint weighted least squares optimization. Estimates a single scalar treatment effect τ along with fixed effects and optional low-rank factor adjustment. Faster but assumes homogeneous treatment effects. +- `'local'` (default): Per-observation model fitting following Algorithm 2 of the paper. Computes observation-specific weights and fits a model for each treated observation, then averages the individual treatment effects. More flexible but computationally intensive. +- `'global'`: Global weighted least squares optimization. Fits a single model on control observations with global weights, then computes per-observation treatment effects as residuals. 
Faster but uses global rather than observation-specific weights. **Convenience function:** diff --git a/diff_diff/_backend.py b/diff_diff/_backend.py index 0c6e99f..2065e17 100644 --- a/diff_diff/_backend.py +++ b/diff_diff/_backend.py @@ -23,13 +23,13 @@ project_simplex as _rust_project_simplex, solve_ols as _rust_solve_ols, compute_robust_vcov as _rust_compute_robust_vcov, - # TROP estimator acceleration (twostep method) + # TROP estimator acceleration (local method) compute_unit_distance_matrix as _rust_unit_distance_matrix, loocv_grid_search as _rust_loocv_grid_search, bootstrap_trop_variance as _rust_bootstrap_trop_variance, - # TROP estimator acceleration (joint method) - loocv_grid_search_joint as _rust_loocv_grid_search_joint, - bootstrap_trop_variance_joint as _rust_bootstrap_trop_variance_joint, + # TROP estimator acceleration (global method) + loocv_grid_search_global as _rust_loocv_grid_search_global, + bootstrap_trop_variance_global as _rust_bootstrap_trop_variance_global, # SDID weights (Frank-Wolfe matching R's synthdid) compute_sdid_unit_weights as _rust_sdid_unit_weights, compute_time_weights as _rust_compute_time_weights, @@ -46,13 +46,13 @@ _rust_project_simplex = None _rust_solve_ols = None _rust_compute_robust_vcov = None - # TROP estimator acceleration (twostep method) + # TROP estimator acceleration (local method) _rust_unit_distance_matrix = None _rust_loocv_grid_search = None _rust_bootstrap_trop_variance = None - # TROP estimator acceleration (joint method) - _rust_loocv_grid_search_joint = None - _rust_bootstrap_trop_variance_joint = None + # TROP estimator acceleration (global method) + _rust_loocv_grid_search_global = None + _rust_bootstrap_trop_variance_global = None # SDID weights (Frank-Wolfe matching R's synthdid) _rust_sdid_unit_weights = None _rust_compute_time_weights = None @@ -69,13 +69,13 @@ _rust_project_simplex = None _rust_solve_ols = None _rust_compute_robust_vcov = None - # TROP estimator acceleration (twostep method) + # 
TROP estimator acceleration (local method) _rust_unit_distance_matrix = None _rust_loocv_grid_search = None _rust_bootstrap_trop_variance = None - # TROP estimator acceleration (joint method) - _rust_loocv_grid_search_joint = None - _rust_bootstrap_trop_variance_joint = None + # TROP estimator acceleration (global method) + _rust_loocv_grid_search_global = None + _rust_bootstrap_trop_variance_global = None # SDID weights (Frank-Wolfe matching R's synthdid) _rust_sdid_unit_weights = None _rust_compute_time_weights = None @@ -118,13 +118,13 @@ def rust_backend_info(): '_rust_project_simplex', '_rust_solve_ols', '_rust_compute_robust_vcov', - # TROP estimator acceleration (twostep method) + # TROP estimator acceleration (local method) '_rust_unit_distance_matrix', '_rust_loocv_grid_search', '_rust_bootstrap_trop_variance', - # TROP estimator acceleration (joint method) - '_rust_loocv_grid_search_joint', - '_rust_bootstrap_trop_variance_joint', + # TROP estimator acceleration (global method) + '_rust_loocv_grid_search_global', + '_rust_bootstrap_trop_variance_global', # SDID weights (Frank-Wolfe matching R's synthdid) '_rust_sdid_unit_weights', '_rust_compute_time_weights', diff --git a/diff_diff/trop.py b/diff_diff/trop.py index abb21f9..8f2c1f3 100644 --- a/diff_diff/trop.py +++ b/diff_diff/trop.py @@ -31,8 +31,8 @@ _rust_unit_distance_matrix, _rust_loocv_grid_search, _rust_bootstrap_trop_variance, - _rust_loocv_grid_search_joint, - _rust_bootstrap_trop_variance_joint, + _rust_loocv_grid_search_global, + _rust_bootstrap_trop_variance_global, ) from diff_diff.trop_results import ( _LAMBDA_INF, @@ -63,10 +63,10 @@ class TROP: Parameters ---------- - method : str, default='twostep' + method : str, default='local' Estimation method to use: - - 'twostep': Per-observation model fitting following Algorithm 2 of + - 'local': Per-observation model fitting following Algorithm 2 of Athey et al. (2025). 
Computes observation-specific weights and fits a model for each treated observation, averaging the individual treatment effects. More flexible but computationally intensive. @@ -77,10 +77,11 @@ class TROP: treatment effects as residuals: tau_it = Y_it - mu - alpha_i - beta_t - L_it for treated cells. ATT is the mean of these effects. For the paper's full - per-treated-cell estimator, use ``method='twostep'``. + per-treated-cell estimator, use ``method='local'``. - - 'joint': Deprecated alias for 'global'. Will be removed in a - future version. + - 'twostep': Deprecated alias for 'local'. Will be removed in v3.0. + + - 'joint': Deprecated alias for 'global'. Will be removed in v3.0. lambda_time_grid : list, optional Grid of time weight decay parameters. 0.0 = uniform weights (disabled). @@ -138,7 +139,7 @@ class TROP: def __init__( self, - method: str = "twostep", + method: str = "local", lambda_time_grid: Optional[List[float]] = None, lambda_unit_grid: Optional[List[float]] = None, lambda_nn_grid: Optional[List[float]] = None, @@ -149,16 +150,24 @@ def __init__( seed: Optional[int] = None, ): # Validate method parameter - # 'global' is the preferred name; 'joint' is a deprecated alias - valid_methods = ("twostep", "joint", "global") + # 'local'/'global' are preferred; 'twostep'/'joint' are deprecated aliases + valid_methods = ("local", "twostep", "joint", "global") if method not in valid_methods: raise ValueError( f"method must be one of {valid_methods}, got '{method}'" ) + if method == "twostep": + warnings.warn( + "method='twostep' is deprecated and will be removed in v3.0. " + "Use method='local' instead.", + FutureWarning, + stacklevel=2, + ) + method = "local" if method == "joint": warnings.warn( - "method='joint' is deprecated and will be removed in a future " - "version. Use method='global' instead.", + "method='joint' is deprecated and will be removed in v3.0. 
" + "Use method='global' instead.", FutureWarning, stacklevel=2, ) @@ -556,7 +565,7 @@ def _cycling_parameter_search( # Joint estimation method # ========================================================================= - def _compute_joint_weights( + def _compute_global_weights( self, Y: np.ndarray, D: np.ndarray, @@ -567,7 +576,7 @@ def _compute_joint_weights( n_periods: int, ) -> np.ndarray: """ - Compute distance-based weights for joint estimation. + Compute distance-based weights for global estimation. Following the reference implementation, weights are computed based on: - Time distance: distance to center of treated block @@ -655,7 +664,7 @@ def _compute_joint_weights( return delta - def _solve_joint_model( + def _solve_global_model( self, Y: np.ndarray, delta: np.ndarray, @@ -668,10 +677,10 @@ def _solve_joint_model( """ n_periods, n_units = Y.shape if lambda_nn >= 1e10: - mu, alpha, beta = self._solve_joint_no_lowrank(Y, delta) + mu, alpha, beta = self._solve_global_no_lowrank(Y, delta) L = np.zeros((n_periods, n_units)) else: - mu, alpha, beta, L = self._solve_joint_with_lowrank( + mu, alpha, beta, L = self._solve_global_with_lowrank( Y, delta, lambda_nn, self.max_iter, self.tol ) return mu, alpha, beta, L @@ -718,7 +727,7 @@ def _extract_posthoc_tau( return att, treatment_effects, tau_values - def _loocv_score_joint( + def _loocv_score_global( self, Y: np.ndarray, D: np.ndarray, @@ -731,12 +740,12 @@ def _loocv_score_joint( n_periods: int, ) -> float: """ - Compute LOOCV score for joint method with specific parameter combination. + Compute LOOCV score for global method with specific parameter combination. Following paper's Equation 5: Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² - For joint method, we exclude each control observation, fit the joint model + For global method, we exclude each control observation, fit the global model on remaining data, and compute the pseudo-treatment effect at the excluded obs. 
Parameters @@ -766,7 +775,7 @@ def _loocv_score_joint( LOOCV score (sum of squared pseudo-treatment effects). """ # Compute global weights (same for all LOOCV iterations) - delta = self._compute_joint_weights( + delta = self._compute_global_weights( Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods ) @@ -779,7 +788,7 @@ def _loocv_score_joint( delta_ex[t_ex, i_ex] = 0.0 try: - mu, alpha, beta, L = self._solve_joint_model(Y, delta_ex, lambda_nn) + mu, alpha, beta, L = self._solve_global_model(Y, delta_ex, lambda_nn) # Pseudo treatment effect: τ = Y - μ - α - β - L if np.isfinite(Y[t_ex, i_ex]): @@ -796,7 +805,7 @@ def _loocv_score_joint( return tau_sq_sum - def _solve_joint_no_lowrank( + def _solve_global_no_lowrank( self, Y: np.ndarray, delta: np.ndarray, @@ -806,7 +815,7 @@ def _solve_joint_no_lowrank( Solves: min Σ (1-W)*δ_{it}(Y_{it} - μ - α_i - β_t)² - The (1-W) masking is already applied to delta by _compute_joint_weights, + The (1-W) masking is already applied to delta by _compute_global_weights, so treated observations have zero weight and do not affect the fit. Parameters @@ -880,7 +889,7 @@ def _solve_joint_no_lowrank( return float(mu), alpha, beta - def _solve_joint_with_lowrank( + def _solve_global_with_lowrank( self, Y: np.ndarray, delta: np.ndarray, @@ -893,7 +902,7 @@ def _solve_joint_with_lowrank( Solves: min Σ (1-W)*δ_{it}(Y_{it} - μ - α_i - β_t - L_{it})² + λ_nn||L||_* - The (1-W) masking is already applied to delta by _compute_joint_weights, + The (1-W) masking is already applied to delta by _compute_global_weights, so treated observations have zero weight and do not affect the fit. 
Parameters @@ -942,7 +951,7 @@ def _solve_joint_with_lowrank( # Step 1: Fix L, solve for (mu, alpha, beta) Y_adj = Y_safe - L - mu, alpha, beta = self._solve_joint_no_lowrank(Y_adj, delta_masked) + mu, alpha, beta = self._solve_global_no_lowrank(Y_adj, delta_masked) # Step 2: Fix (mu, alpha, beta), update L with FISTA acceleration R = Y_safe - mu - alpha[np.newaxis, :] - beta[:, np.newaxis] @@ -981,11 +990,11 @@ def _solve_joint_with_lowrank( # Final re-solve with converged L (match Rust behavior) Y_adj = Y_safe - L - mu, alpha, beta = self._solve_joint_no_lowrank(Y_adj, delta_masked) + mu, alpha, beta = self._solve_global_no_lowrank(Y_adj, delta_masked) return mu, alpha, beta, L - def _fit_joint( + def _fit_global( self, data: pd.DataFrame, outcome: str, @@ -1024,10 +1033,10 @@ def _fit_joint( (fixed `treated_periods` across resamples). The treatment timing is inferred from the data once and held constant for all bootstrap iterations. For staggered adoption designs where treatment timing varies - across units, use `method="twostep"` which computes observation-specific + across units, use `method="local"` which computes observation-specific weights that naturally handle heterogeneous timing. 
""" - # Data setup (same as twostep method) + # Data setup (same as local method) all_units = sorted(data[unit].unique()) all_periods = sorted(data[time].unique()) @@ -1097,7 +1106,7 @@ def _fit_joint( if n_pre_periods < 2: raise ValueError("Need at least 2 pre-treatment periods") - # Check for staggered adoption (joint method requires simultaneous treatment) + # Check for staggered adoption (global method requires simultaneous treatment) # Use only observed periods (skip missing) to avoid false positives on unbalanced panels first_treat_by_unit = [] for i in treated_unit_idx: @@ -1115,7 +1124,7 @@ def _fit_joint( raise ValueError( f"method='global' requires simultaneous treatment adoption, but your data " f"shows staggered adoption (units first treated at periods {unique_starts}). " - f"Use method='twostep' which properly handles staggered adoption designs." + f"Use method='local' which properly handles staggered adoption designs." ) # LOOCV grid search for tuning parameters @@ -1124,7 +1133,7 @@ def _fit_joint( best_score = np.inf control_mask = D == 0 - if HAS_RUST_BACKEND and _rust_loocv_grid_search_joint is not None: + if HAS_RUST_BACKEND and _rust_loocv_grid_search_global is not None: try: # Prepare inputs for Rust function control_mask_u8 = control_mask.astype(np.uint8) @@ -1133,7 +1142,7 @@ def _fit_joint( lambda_unit_arr = np.array(self.lambda_unit_grid, dtype=np.float64) lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64) - result = _rust_loocv_grid_search_joint( + result = _rust_loocv_grid_search_global( Y, D.astype(np.float64), control_mask_u8, lambda_time_arr, lambda_unit_arr, lambda_nn_arr, self.max_iter, self.tol, @@ -1170,7 +1179,7 @@ def _fit_joint( except Exception as e: # Fall back to Python implementation on error logger.debug( - "Rust LOOCV grid search (joint) failed, falling back to Python: %s", e + "Rust LOOCV grid search (global) failed, falling back to Python: %s", e ) best_lambda = None best_score = np.inf @@ -1193,7 +1202,7 @@ 
def _fit_joint( ln = 1e10 if np.isinf(lambda_nn_val) else lambda_nn_val try: - score = self._loocv_score_joint( + score = self._loocv_score_global( Y, D, control_obs, lt, lu, ln, treated_periods, n_units, n_periods ) @@ -1223,11 +1232,11 @@ def _fit_joint( lambda_nn = 1e10 # Compute final weights and fit - delta = self._compute_joint_weights( + delta = self._compute_global_weights( Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods ) - mu, alpha, beta, L = self._solve_joint_model(Y, delta, lambda_nn) + mu, alpha, beta, L = self._solve_global_model(Y, delta, lambda_nn) # Post-hoc tau extraction (per paper Eq. 2) att, treatment_effects, tau_values = self._extract_posthoc_tau( @@ -1258,7 +1267,7 @@ def _fit_joint( # Bootstrap variance estimation effective_lambda = (lambda_time, lambda_unit, lambda_nn) - se, bootstrap_dist = self._bootstrap_variance_joint( + se, bootstrap_dist = self._bootstrap_variance_global( data, outcome, treatment, unit, time, effective_lambda, treated_periods ) @@ -1300,7 +1309,7 @@ def _fit_joint( self.is_fitted_ = True return self.results_ - def _bootstrap_variance_joint( + def _bootstrap_variance_global( self, data: pd.DataFrame, outcome: str, @@ -1311,7 +1320,7 @@ def _bootstrap_variance_joint( treated_periods: int, ) -> Tuple[float, np.ndarray]: """ - Compute bootstrap standard error for joint method. + Compute bootstrap standard error for global method. Uses Rust backend when available for parallel bootstrap (5-15x speedup). 
@@ -1340,7 +1349,7 @@ def _bootstrap_variance_joint( lambda_time, lambda_unit, lambda_nn = optimal_lambda # Try Rust backend for parallel bootstrap (5-15x speedup) - if HAS_RUST_BACKEND and _rust_bootstrap_trop_variance_joint is not None: + if HAS_RUST_BACKEND and _rust_bootstrap_trop_variance_global is not None: try: # Create matrices for Rust function all_units = sorted(data[unit].unique()) @@ -1359,7 +1368,7 @@ def _bootstrap_variance_joint( .values ) - bootstrap_estimates, se = _rust_bootstrap_trop_variance_joint( + bootstrap_estimates, se = _rust_bootstrap_trop_variance_global( Y, D, lambda_time, lambda_unit, lambda_nn, self.n_bootstrap, self.max_iter, self.tol, @@ -1378,7 +1387,7 @@ def _bootstrap_variance_joint( except Exception as e: logger.debug( - "Rust bootstrap (joint) failed, falling back to Python: %s", e + "Rust bootstrap (global) failed, falling back to Python: %s", e ) # Python fallback implementation @@ -1419,7 +1428,7 @@ def _bootstrap_variance_joint( ], ignore_index=True) try: - tau = self._fit_joint_with_fixed_lambda( + tau = self._fit_global_with_fixed_lambda( boot_data, outcome, treatment, unit, time, optimal_lambda, treated_periods ) @@ -1441,7 +1450,7 @@ def _bootstrap_variance_joint( se = np.std(bootstrap_estimates, ddof=1) return float(se), bootstrap_estimates - def _fit_joint_with_fixed_lambda( + def _fit_global_with_fixed_lambda( self, data: pd.DataFrame, outcome: str, @@ -1478,12 +1487,12 @@ def _fit_joint_with_fixed_lambda( ) # Compute weights (includes (1-W) masking) - delta = self._compute_joint_weights( + delta = self._compute_global_weights( Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods ) # Fit model on control data and extract post-hoc tau - mu, alpha, beta, L = self._solve_joint_model(Y, delta, lambda_nn) + mu, alpha, beta, L = self._solve_global_model(Y, delta, lambda_nn) att, _, _ = self._extract_posthoc_tau(Y, D, mu, alpha, beta, L) return att @@ -1541,9 +1550,9 @@ def fit( # Dispatch based on estimation 
method if self.method == "global": - return self._fit_joint(data, outcome, treatment, unit, time) + return self._fit_global(data, outcome, treatment, unit, time) - # Below is the twostep method (default) + # Below is the local method (default) # Get unique units and periods all_units = sorted(data[unit].unique()) all_periods = sorted(data[time].unique()) @@ -2660,10 +2669,18 @@ def get_params(self) -> Dict[str, Any]: def set_params(self, **params) -> "TROP": """Set estimator parameters.""" for key, value in params.items(): + if key == "method" and value == "twostep": + warnings.warn( + "method='twostep' is deprecated and will be removed in " + "v3.0. Use method='local' instead.", + FutureWarning, + stacklevel=2, + ) + value = "local" if key == "method" and value == "joint": warnings.warn( - "method='joint' is deprecated and will be removed in a " - "future version. Use method='global' instead.", + "method='joint' is deprecated and will be removed in " + "v3.0. Use method='global' instead.", FutureWarning, stacklevel=2, ) diff --git a/docs/api/trop.rst b/docs/api/trop.rst index cfcdc8c..2610624 100644 --- a/docs/api/trop.rst +++ b/docs/api/trop.rst @@ -100,7 +100,7 @@ Estimation Methods TROP supports two estimation methods via the ``method`` parameter: -**Two-Step Method** (``method='twostep'``, default) +**Local Method** (``method='local'``, default) The default method follows Algorithm 2 from the paper: @@ -124,7 +124,7 @@ the estimator is consistent if any one of the three components A computationally efficient adaptation using the ``(1-W)`` masking principle from Eq. 2. Fits a single global model rather than per-treated-cell models. For the paper's full per-treated-cell estimator (Algorithm 2), use -``method='twostep'``. +``method='local'``. 1. 
**Compute weights**: Distance-based unit and time weights computed once (distance to center of treated block, RMSE to average treated trajectory), @@ -145,15 +145,16 @@ For the paper's full per-treated-cell estimator (Algorithm 2), use The global method is **faster** (single optimization vs N_treated optimizations). Treatment effects are **heterogeneous** per-observation residuals; ATT is their mean. -``method='joint'`` is a deprecated alias for ``method='global'`` and will be -removed in a future version. +``method='twostep'`` is a deprecated alias for ``method='local'`` and will be +removed in v3.0. ``method='joint'`` is a deprecated alias for ``method='global'`` +and will be removed in v3.0. .. list-table:: :header-rows: 1 :widths: 20 40 40 * - Feature - - Two-Step (default) + - Local (default) - Global * - Treatment effect - Per-observation τ_{it} (per-obs models) @@ -168,7 +169,7 @@ removed in a future version. - Observation-specific - Global (center of treated block) -Use ``method='twostep'`` for observation-specific weight optimization. +Use ``method='local'`` for observation-specific weight optimization. Use ``method='global'`` for faster estimation with global weights. Example Usage @@ -228,9 +229,9 @@ Using the global method for faster estimation:: unit='unit_id', time='period') # Compare methods - trop_twostep = TROP(method='twostep', ...) # Default (per-observation) - results_twostep = trop_twostep.fit(data, ...) - print(f"Two-step ATT: {results_twostep.att:.3f}") + trop_local = TROP(method='local', ...) # Default (per-observation) + results_local = trop_local.fit(data, ...) 
+ print(f"Local ATT: {results_local.att:.3f}") print(f"Global ATT: {results_global.att:.3f}") Examining factor structure:: diff --git a/docs/choosing_estimator.rst b/docs/choosing_estimator.rst index 3670f9a..8a13c85 100644 --- a/docs/choosing_estimator.rst +++ b/docs/choosing_estimator.rst @@ -405,7 +405,7 @@ exponential unit distance weights, and time decay weights with LOOCV tuning. .. note:: TROP is computationally intensive. Use ``method='global'`` for faster - estimation at the cost of some flexibility vs. ``method='twostep'``. + estimation at the cost of some flexibility vs. ``method='local'``. Bacon Decomposition ~~~~~~~~~~~~~~~~~~~ diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 54bfbac..e846543 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -1275,14 +1275,14 @@ Optimization (Equation 2): ``` (α̂, β̂, L̂) = argmin_{α,β,L} Σ_j Σ_s θ_s^{i,t} ω_j^{i,t} (1-W_js)(Y_js - α_j - β_s - L_js)² + λ_nn ||L||_* ``` -Solved via alternating minimization. For α, β (or μ, α, β, τ in joint): weighted least +Solved via alternating minimization. For α, β (or μ, α, β, τ in global): weighted least squares (closed form). For L: proximal gradient with step size η = 1/(2·max(W)): ``` Gradient step: G = L + (W/max(W)) ⊙ (R - L) Proximal step: L = U × soft_threshold(Σ, η·λ_nn) × V' (SVD of G = UΣV') ``` -where R is the residual after removing fixed effects (and τ·D in joint mode). -Both the twostep and global solvers use FISTA/Nesterov acceleration for the +where R is the residual after removing fixed effects (and τ·D in global mode). +Both the local and global solvers use FISTA/Nesterov acceleration for the inner L update (O(1/k²) convergence rate, up to 20 inner iterations per outer alternating step). 
@@ -1372,12 +1372,13 @@ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² ### TROP Global Estimation Method -**Method**: `method="global"` in TROP estimator (`method="joint"` is a deprecated alias) +**Method**: `method="global"` in TROP estimator (`method="joint"` is a deprecated alias; +`method="twostep"` is a deprecated alias for `method="local"`) **Approach**: Computationally efficient adaptation using the (1-W) masking principle from Eq. 2. Fits a single global model on control data, then extracts treatment effects as post-hoc residuals. For the paper's full -per-treated-cell estimator (Algorithm 2), use `method='twostep'`. +per-treated-cell estimator (Algorithm 2), use `method='local'`. **Objective function** (Equation G1): ``` @@ -1400,7 +1401,7 @@ ATT = mean(τ̂_{it}) over all treated observations Treatment effects are **heterogeneous** per-observation values. ATT is their mean. -**Weight computation** (differs from twostep): +**Weight computation** (differs from local): - Time weights: δ_time(t) = exp(-λ_time × |t - center|) where center = T - treated_periods/2 - Unit weights: δ_unit(i) = exp(-λ_unit × RMSE(i, treated_avg)) where RMSE is computed over pre-treatment periods comparing to average treated trajectory @@ -1424,7 +1425,7 @@ Treatment effects are **heterogeneous** per-observation values. ATT is their mea 3. **Post-hoc**: Extract τ̂_{it} = Y_{it} - μ̂ - α̂_i - β̂_t - L̂_{it} for treated cells -**LOOCV parameter selection** (unified with twostep, Equation 5): +**LOOCV parameter selection** (unified with local, Equation 5): Following paper's Equation 5 and footnote 2: ``` Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² @@ -1441,14 +1442,14 @@ For global method, LOOCV works as follows: 3. Select λ combination that minimizes Q(λ) **Rust acceleration**: The LOOCV grid search is parallelized in Rust for 5-10x speedup. 
-- `loocv_grid_search_joint()` - Parallel LOOCV across all λ combinations -- `bootstrap_trop_variance_joint()` - Parallel bootstrap variance estimation +- `loocv_grid_search_global()` - Parallel LOOCV across all λ combinations +- `bootstrap_trop_variance_global()` - Parallel bootstrap variance estimation -**Key differences from twostep method**: +**Key differences from local method**: - Global weights (distance to treated block center) vs. per-observation weights - Single model fit per λ combination vs. N_treated fits - Treatment effects are post-hoc residuals from a single global model (global) - vs. post-hoc residuals from per-observation models (twostep) + vs. post-hoc residuals from per-observation models (local) - Both use (1-W) masking (control-only fitting) - Faster computation for large panels @@ -1457,14 +1458,14 @@ For global method, LOOCV works as follows: to receive treatment at the same time. A `ValueError` is raised if staggered adoption is detected (units first treated at different periods). Treatment timing is inferred once and held constant for bootstrap variance estimation. - For staggered adoption designs, use `method="twostep"`. + For staggered adoption designs, use `method="local"`. **Reference**: Adapted from reference implementation. See also Athey et al. (2025). **Edge Cases (treated NaN outcomes):** - **Partial NaN**: When some treated outcomes Y_{it} are NaN/missing: - `_extract_posthoc_tau()` (global) skips these cells; only finite τ̂ values are averaged - - Twostep loop skips NaN outcomes entirely (no model fit, no tau appended) + - Local loop skips NaN outcomes entirely (no model fit, no tau appended) - `n_treated_obs` in results reflects valid (finite) count, not total D==1 count - `df_trop = max(1, n_valid_treated - 1)` uses valid count - Warning issued when n_valid_treated < total treated count @@ -1475,13 +1476,15 @@ For global method, LOOCV works as follows: iterations succeed. `safe_inference()` propagates NaN downstream. 
**Requirements checklist:** -- [x] Same LOOCV framework as twostep (Equation 5) +- [x] Same LOOCV framework as local (Equation 5) - [x] Global weight computation using treated block center - [x] (1-W) masking for control-only fitting (per paper Eq. 2) - [x] Alternating minimization for nuclear norm penalty - [x] Returns ATT = mean of per-observation post-hoc τ̂_{it} - [x] Rust acceleration for LOOCV and bootstrap +- **Note:** `method="twostep"` renamed to `method="local"` and `method="joint"` renamed to `method="global"` to form a natural local/global pair. Both old names are deprecated aliases, removal planned for v3.0. + --- # Diagnostics & Sensitivity diff --git a/docs/troubleshooting.rst b/docs/troubleshooting.rst index 3fa925a..ade5c89 100644 --- a/docs/troubleshooting.rst +++ b/docs/troubleshooting.rst @@ -582,6 +582,26 @@ inaccurate with missing observations. Deprecation Warnings -------------------- +"method='twostep' is deprecated" +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +**Problem:** TROP emits a ``FutureWarning`` that ``method='twostep'`` is +deprecated. + +**Causes:** + +1. Code uses the old ``method='twostep'`` parameter name + +**Solutions:** + +.. 
code-block:: python + + # Old (deprecated) + trop = TROP(method='twostep') + + # New (use 'local' instead) + trop = TROP(method='local') + "method='joint' is deprecated" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/tutorials/10_trop.ipynb b/docs/tutorials/10_trop.ipynb index 8844660..5c8d9d5 100644 --- a/docs/tutorials/10_trop.ipynb +++ b/docs/tutorials/10_trop.ipynb @@ -598,14 +598,14 @@ }, { "cell_type": "code", - "source": "# Compare estimation methods\nprint(\"Estimation method comparison:\")\nprint(\"=\"*60)\n\nimport time\n\n# Two-step method (default)\nstart = time.time()\ntrop_twostep = TROP(\n method='twostep',\n lambda_time_grid=[0.0, 1.0],\n lambda_unit_grid=[0.0, 1.0], \n lambda_nn_grid=[0.0, 0.1],\n n_bootstrap=20,\n seed=42\n)\nresults_twostep = trop_twostep.fit(\n df,\n outcome='outcome',\n treatment='treated',\n unit='unit',\n time='period'\n)\ntwostep_time = time.time() - start\n\n# Global method\nstart = time.time()\ntrop_global = TROP(\n method='global',\n lambda_time_grid=[0.0, 1.0],\n lambda_unit_grid=[0.0, 1.0], \n lambda_nn_grid=[0.0, 0.1],\n n_bootstrap=20,\n seed=42\n)\nresults_global = trop_global.fit(\n df,\n outcome='outcome',\n treatment='treated',\n unit='unit',\n time='period'\n)\nglobal_time = time.time() - start\n\nprint(f\"\\n{'Method':<15} {'ATT':>10} {'SE':>10} {'Time (s)':>12}\")\nprint(\"-\"*60)\nprint(f\"{'Two-step':<15} {results_twostep.att:>10.4f} {results_twostep.se:>10.4f} {twostep_time:>12.2f}\")\nprint(f\"{'Global':<15} {results_global.att:>10.4f} {results_global.se:>10.4f} {global_time:>12.2f}\")\nprint(f\"\\nTrue ATT: {true_att}\")\nprint(f\"Two-step bias: {results_twostep.att - true_att:.4f}\")\nprint(f\"Global bias: {results_global.att - true_att:.4f}\")", + "source": "# Compare estimation methods\nprint(\"Estimation method comparison:\")\nprint(\"=\"*60)\n\nimport time\n\n# Local method (default)\nstart = time.time()\ntrop_local = TROP(\n method='local',\n lambda_time_grid=[0.0, 1.0],\n lambda_unit_grid=[0.0, 
1.0], \n lambda_nn_grid=[0.0, 0.1],\n n_bootstrap=20,\n seed=42\n)\nresults_local = trop_local.fit(\n df,\n outcome='outcome',\n treatment='treated',\n unit='unit',\n time='period'\n)\nlocal_time = time.time() - start\n\n# Global method\nstart = time.time()\ntrop_global = TROP(\n method='global',\n lambda_time_grid=[0.0, 1.0],\n lambda_unit_grid=[0.0, 1.0], \n lambda_nn_grid=[0.0, 0.1],\n n_bootstrap=20,\n seed=42\n)\nresults_global = trop_global.fit(\n df,\n outcome='outcome',\n treatment='treated',\n unit='unit',\n time='period'\n)\nglobal_time = time.time() - start\n\nprint(f\"\\n{'Method':<15} {'ATT':>10} {'SE':>10} {'Time (s)':>12}\")\nprint(\"-\"*60)\nprint(f\"{'Local':<15} {results_local.att:>10.4f} {results_local.se:>10.4f} {local_time:>12.2f}\")\nprint(f\"{'Global':<15} {results_global.att:>10.4f} {results_global.se:>10.4f} {global_time:>12.2f}\")\nprint(f\"\\nTrue ATT: {true_att}\")\nprint(f\"Local bias: {results_local.att - true_att:.4f}\")\nprint(f\"Global bias: {results_global.att - true_att:.4f}\")", "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", - "source": "## 10. Estimation Methods: Two-Step vs Global\n\nTROP supports two estimation methods via the `method` parameter:\n\n**Two-Step Method** (`method='twostep'`, default):\n- Follows Algorithm 2 from the paper\n- Computes observation-specific weights for each treated observation\n- Fits a model per treated observation, then averages the individual effects\n- More flexible, allows for heterogeneous treatment effects\n- Computationally intensive (N_treated optimizations)\n\n**Global Method** (`method='global'`):\n- Fits a single model on control data using (1-W) masked weights (per paper Eq. 
2)\n- Extracts per-observation treatment effects as post-hoc residuals: τ_it = Y_it - μ - α_i - β_t - L_it\n- ATT = mean(τ_it) over treated observations\n- Faster (single optimization) with global weights\n\nNote: `method='joint'` is a deprecated alias for `method='global'`.", + "source": "## 10. Estimation Methods: Local vs Global\n\nTROP supports two estimation methods via the `method` parameter:\n\n**Local Method** (`method='local'`, default):\n- Follows Algorithm 2 from the paper\n- Computes observation-specific weights for each treated observation\n- Fits a model per treated observation, then averages the individual effects\n- More flexible, allows for heterogeneous treatment effects\n- Computationally intensive (N_treated optimizations)\n\n**Global Method** (`method='global'`):\n- Fits a single model on control data using (1-W) masked weights (per paper Eq. 2)\n- Extracts per-observation treatment effects as post-hoc residuals: τ_it = Y_it - μ - α_i - β_t - L_it\n- ATT = mean(τ_it) over treated observations\n- Faster (single optimization) with global weights\n\nNote: `method='twostep'` is a deprecated alias for `method='local'`, and `method='joint'` is a deprecated alias for `method='global'`. Both will be removed in v3.0.", "metadata": {} }, { @@ -638,7 +638,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": "## Summary\n\nKey takeaways for TROP:\n\n1. **Best use cases**: Factor confounding, unobserved time-varying confounders with interactive effects\n2. **Factor estimation**: Nuclear norm regularization with LOOCV for tuning\n3. **Three tuning parameters**: λ_time, λ_unit, λ_nn selected automatically via LOOCV\n4. **Unit weights**: Exponential distance-based weighting of control units, where distance is computed as RMS outcome difference on control periods excluding the target period\n5. **Time weights**: Exponential decay weighting of pre-treatment periods\n6. 
**Weights**: Importance weights controlling relative contribution of observations (higher = more relevant)\n7. **Estimation methods**:\n - `method='twostep'` (default): Per-observation estimation, allows heterogeneous effects\n - `method='global'`: Single model with (1-W) masking, post-hoc heterogeneous effects, faster\n\n**When to use TROP vs SDID**:\n- Use **SDID** when parallel trends is plausible and factors are not a concern\n- Use **TROP** when you suspect factor confounding (regional shocks, economic cycles, latent factors)\n- Running both provides a useful robustness check\n\n**When to use twostep vs global method**:\n- Use **twostep** (default) for maximum flexibility with per-observation weights\n- Use **global** for faster estimation with global weights\n\n**Reference**:\n- Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust Panel Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536" + "source": "## Summary\n\nKey takeaways for TROP:\n\n1. **Best use cases**: Factor confounding, unobserved time-varying confounders with interactive effects\n2. **Factor estimation**: Nuclear norm regularization with LOOCV for tuning\n3. **Three tuning parameters**: λ_time, λ_unit, λ_nn selected automatically via LOOCV\n4. **Unit weights**: Exponential distance-based weighting of control units, where distance is computed as RMS outcome difference on control periods excluding the target period\n5. **Time weights**: Exponential decay weighting of pre-treatment periods\n6. **Weights**: Importance weights controlling relative contribution of observations (higher = more relevant)\n7. 
**Estimation methods**:\n - `method='local'` (default): Per-observation estimation, allows heterogeneous effects\n - `method='global'`: Single model with (1-W) masking, post-hoc heterogeneous effects, faster\n\n**When to use TROP vs SDID**:\n- Use **SDID** when parallel trends is plausible and factors are not a concern\n- Use **TROP** when you suspect factor confounding (regional shocks, economic cycles, latent factors)\n- Running both provides a useful robustness check\n\n**When to use local vs global method**:\n- Use **local** (default) for maximum flexibility with per-observation weights\n- Use **global** for faster estimation with global weights\n\n**Reference**:\n- Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust Panel Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536" }, { "cell_type": "code", diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 9507e1c..6dd2fdd 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -42,14 +42,14 @@ fn _rust_backend(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(linalg::solve_ols, m)?)?; m.add_function(wrap_pyfunction!(linalg::compute_robust_vcov, m)?)?; - // TROP estimator acceleration (twostep method) + // TROP estimator acceleration (local method) m.add_function(wrap_pyfunction!(trop::compute_unit_distance_matrix, m)?)?; m.add_function(wrap_pyfunction!(trop::loocv_grid_search, m)?)?; m.add_function(wrap_pyfunction!(trop::bootstrap_trop_variance, m)?)?; - // TROP estimator acceleration (joint method) - m.add_function(wrap_pyfunction!(trop::loocv_grid_search_joint, m)?)?; - m.add_function(wrap_pyfunction!(trop::bootstrap_trop_variance_joint, m)?)?; + // TROP estimator acceleration (global method) + m.add_function(wrap_pyfunction!(trop::loocv_grid_search_global, m)?)?; + m.add_function(wrap_pyfunction!(trop::bootstrap_trop_variance_global, m)?)?; // Diagnostics m.add_function(wrap_pyfunction!(rust_backend_info, m)?)?; diff --git a/rust/src/trop.rs b/rust/src/trop.rs 
index 997bf36..8993a3e 100644 --- a/rust/src/trop.rs +++ b/rust/src/trop.rs @@ -1522,7 +1522,7 @@ fn loocv_score_joint( #[pyfunction] #[pyo3(signature = (y, d, control_mask, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, max_iter, tol))] #[allow(clippy::too_many_arguments)] -pub fn loocv_grid_search_joint<'py>( +pub fn loocv_grid_search_global<'py>( _py: Python<'py>, y: PyReadonlyArray2<'py, f64>, d: PyReadonlyArray2<'py, f64>, @@ -1651,7 +1651,7 @@ pub fn loocv_grid_search_joint<'py>( #[pyfunction] #[pyo3(signature = (y, d, lambda_time, lambda_unit, lambda_nn, n_bootstrap, max_iter, tol, seed))] #[allow(clippy::too_many_arguments)] -pub fn bootstrap_trop_variance_joint<'py>( +pub fn bootstrap_trop_variance_global<'py>( py: Python<'py>, y: PyReadonlyArray2<'py, f64>, d: PyReadonlyArray2<'py, f64>, diff --git a/tests/test_rust_backend.py b/tests/test_rust_backend.py index 1921339..312347f 100644 --- a/tests/test_rust_backend.py +++ b/tests/test_rust_backend.py @@ -1151,12 +1151,12 @@ def test_trop_produces_valid_results(self): @pytest.mark.slow @pytest.mark.skipif(not HAS_RUST_BACKEND, reason="Rust backend not available") -class TestTROPJointRustBackend: - """Test suite for TROP joint method Rust backend functions.""" +class TestTROPGlobalRustBackend: + """Test suite for TROP global method Rust backend functions.""" - def test_loocv_grid_search_joint_returns_valid_result(self): - """Test loocv_grid_search_joint returns valid tuning parameters.""" - from diff_diff._rust_backend import loocv_grid_search_joint + def test_loocv_grid_search_global_returns_valid_result(self): + """Test loocv_grid_search_global returns valid tuning parameters.""" + from diff_diff._rust_backend import loocv_grid_search_global np.random.seed(42) n_periods, n_units = 10, 20 @@ -1173,7 +1173,7 @@ def test_loocv_grid_search_joint_returns_valid_result(self): lambda_unit_grid = np.array([0.0, 1.0]) lambda_nn_grid = np.array([0.0, 0.1]) - result = loocv_grid_search_joint( + result = 
loocv_grid_search_global( Y, D, control_mask, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, 100, 1e-6, @@ -1192,9 +1192,9 @@ def test_loocv_grid_search_joint_returns_valid_result(self): assert n_attempted > 0 assert best_score >= 0 or np.isinf(best_score) - def test_loocv_grid_search_joint_reproducible(self): - """Test loocv_grid_search_joint is deterministic (no subsampling).""" - from diff_diff._rust_backend import loocv_grid_search_joint + def test_loocv_grid_search_global_reproducible(self): + """Test loocv_grid_search_global is deterministic (no subsampling).""" + from diff_diff._rust_backend import loocv_grid_search_global np.random.seed(42) n_periods, n_units = 8, 15 @@ -1210,12 +1210,12 @@ def test_loocv_grid_search_joint_reproducible(self): lambda_unit_grid = np.array([0.0, 0.5]) lambda_nn_grid = np.array([0.0, 0.1]) - result1 = loocv_grid_search_joint( + result1 = loocv_grid_search_global( Y, D, control_mask, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, 50, 1e-6, ) - result2 = loocv_grid_search_joint( + result2 = loocv_grid_search_global( Y, D, control_mask, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, 50, 1e-6, @@ -1224,9 +1224,9 @@ def test_loocv_grid_search_joint_reproducible(self): # Without subsampling, results should be deterministic assert result1[:4] == result2[:4] - def test_bootstrap_trop_variance_joint_shape(self): - """Test bootstrap_trop_variance_joint returns valid output.""" - from diff_diff._rust_backend import bootstrap_trop_variance_joint + def test_bootstrap_trop_variance_global_shape(self): + """Test bootstrap_trop_variance_global returns valid output.""" + from diff_diff._rust_backend import bootstrap_trop_variance_global np.random.seed(42) n_periods, n_units = 8, 15 @@ -1237,7 +1237,7 @@ def test_bootstrap_trop_variance_joint_shape(self): D = np.zeros((n_periods, n_units)) D[-n_post:, :n_treated] = 1.0 - estimates, se = bootstrap_trop_variance_joint( + estimates, se = bootstrap_trop_variance_global( Y, D, 0.5, 
0.5, 0.1, 50, 50, 1e-6, 42 ) @@ -1246,9 +1246,9 @@ def test_bootstrap_trop_variance_joint_shape(self): assert isinstance(se, float) assert se >= 0 - def test_bootstrap_trop_variance_joint_reproducible(self): - """Test bootstrap_trop_variance_joint is reproducible.""" - from diff_diff._rust_backend import bootstrap_trop_variance_joint + def test_bootstrap_trop_variance_global_reproducible(self): + """Test bootstrap_trop_variance_global is reproducible.""" + from diff_diff._rust_backend import bootstrap_trop_variance_global np.random.seed(42) n_periods, n_units = 8, 15 @@ -1259,10 +1259,10 @@ def test_bootstrap_trop_variance_joint_reproducible(self): D = np.zeros((n_periods, n_units)) D[-n_post:, :n_treated] = 1.0 - est1, se1 = bootstrap_trop_variance_joint( + est1, se1 = bootstrap_trop_variance_global( Y, D, 0.5, 0.5, 0.1, 50, 50, 1e-6, 42 ) - est2, se2 = bootstrap_trop_variance_joint( + est2, se2 = bootstrap_trop_variance_global( Y, D, 0.5, 0.5, 0.1, 50, 50, 1e-6, 42 ) @@ -1272,11 +1272,11 @@ def test_bootstrap_trop_variance_joint_reproducible(self): @pytest.mark.slow @pytest.mark.skipif(not HAS_RUST_BACKEND, reason="Rust backend not available") -class TestTROPJointRustVsNumpy: - """Tests comparing TROP joint Rust and NumPy implementations.""" +class TestTROPGlobalRustVsNumpy: + """Tests comparing TROP global Rust and NumPy implementations.""" - def test_trop_joint_produces_valid_results(self): - """Test TROP joint with Rust backend produces valid results.""" + def test_trop_global_produces_valid_results(self): + """Test TROP global with Rust backend produces valid results.""" import pandas as pd from diff_diff import TROP @@ -1305,7 +1305,7 @@ def test_trop_joint_produces_valid_results(self): df = pd.DataFrame(data) trop = TROP( - method="joint", + method="global", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], @@ -1327,8 +1327,8 @@ def test_trop_joint_produces_valid_results(self): assert results.lambda_unit in [0.0, 1.0] 
assert results.lambda_nn in [0.0, 0.1] - def test_trop_joint_and_twostep_agree_in_direction(self): - """Test joint and twostep methods agree on treatment effect direction.""" + def test_trop_global_and_local_agree_in_direction(self): + """Test global and local methods agree on treatment effect direction.""" import pandas as pd from diff_diff import TROP @@ -1356,33 +1356,33 @@ def test_trop_joint_and_twostep_agree_in_direction(self): df = pd.DataFrame(data) - # Fit with joint method - trop_joint = TROP( - method="joint", + # Fit with global method + trop_global = TROP( + method="global", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], n_bootstrap=20, seed=42 ) - results_joint = trop_joint.fit(df, 'outcome', 'treated', 'unit', 'time') + results_global = trop_global.fit(df, 'outcome', 'treated', 'unit', 'time') - # Fit with twostep method - trop_twostep = TROP( - method="twostep", + # Fit with local method + trop_local = TROP( + method="local", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], n_bootstrap=20, seed=42 ) - results_twostep = trop_twostep.fit(df, 'outcome', 'treated', 'unit', 'time') + results_local = trop_local.fit(df, 'outcome', 'treated', 'unit', 'time') # Both should have same sign (both positive for true_effect=2.0) - assert np.sign(results_joint.att) == np.sign(results_twostep.att) + assert np.sign(results_global.att) == np.sign(results_local.att) - def test_trop_joint_handles_nan_outcomes(self): - """Test TROP joint method handles NaN outcome values gracefully.""" + def test_trop_global_handles_nan_outcomes(self): + """Test TROP global method handles NaN outcome values gracefully.""" import pandas as pd from diff_diff import TROP @@ -1423,7 +1423,7 @@ def test_trop_joint_handles_nan_outcomes(self): assert n_nan > 0, "Should have introduced some NaN values" trop = TROP( - method="joint", + method="global", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], 
lambda_nn_grid=[0.0, 0.1], @@ -1440,7 +1440,7 @@ def test_trop_joint_handles_nan_outcomes(self): # ATT should still be positive (true effect is positive) assert results.att > 0, f"ATT {results.att:.2f} should be positive" - def test_trop_joint_no_valid_pre_unit_gets_zero_weight(self): + def test_trop_global_no_valid_pre_unit_gets_zero_weight(self): """Test that units with no valid pre-period data get zero weight. When a control unit has all NaN values in the pre-treatment period, @@ -1488,9 +1488,9 @@ def test_trop_joint_no_valid_pre_unit_gets_zero_weight(self): unit_pre_data = df[(df['unit'] == control_unit_with_no_pre) & (df['time'] < (n_periods - n_post))] assert unit_pre_data['outcome'].isna().all(), "Control unit should have all NaN in pre-period" - # Fit with joint method - should handle gracefully + # Fit with global method - should handle gracefully trop = TROP( - method="joint", + method="global", lambda_time_grid=[0.5, 1.0], lambda_unit_grid=[0.5, 1.0], lambda_nn_grid=[0.0], @@ -1509,7 +1509,7 @@ def test_trop_joint_no_valid_pre_unit_gets_zero_weight(self): assert abs(results.att - true_effect) < 1.5, \ f"ATT {results.att:.2f} should be close to true effect {true_effect}" - def test_trop_joint_nan_exclusion_rust_python_parity(self): + def test_trop_global_nan_exclusion_rust_python_parity(self): """Test Rust and Python backends produce matching results with NaN data. 
This verifies that when data contains NaN values: @@ -1559,7 +1559,7 @@ def test_trop_joint_nan_exclusion_rust_python_parity(self): # Common TROP parameters trop_params = dict( - method="joint", + method="global", lambda_time_grid=[0.5, 1.0], lambda_unit_grid=[0.5, 1.0], lambda_nn_grid=[0.0], @@ -1578,8 +1578,8 @@ def test_trop_joint_nan_exclusion_rust_python_parity(self): trop_module = sys.modules['diff_diff.trop'] with patch.object(trop_module, 'HAS_RUST_BACKEND', False), \ - patch.object(trop_module, '_rust_loocv_grid_search_joint', None), \ - patch.object(trop_module, '_rust_bootstrap_trop_variance_joint', None): + patch.object(trop_module, '_rust_loocv_grid_search_global', None), \ + patch.object(trop_module, '_rust_bootstrap_trop_variance_global', None): trop_python = TROP(**trop_params) results_python = trop_python.fit(df.copy(), 'outcome', 'treated', 'unit', 'time') @@ -1599,7 +1599,7 @@ def test_trop_joint_nan_exclusion_rust_python_parity(self): assert results_rust.att > 0, f"Rust ATT {results_rust.att} should be positive" assert results_python.att > 0, f"Python ATT {results_python.att} should be positive" - def test_trop_joint_treated_pre_nan_rust_python_parity(self): + def test_trop_global_treated_pre_nan_rust_python_parity(self): """Test Rust/Python parity when treated units have pre-period NaN. When all treated units have NaN at a pre-period, average_treated[t] = NaN. 
@@ -1649,7 +1649,7 @@ def test_trop_joint_treated_pre_nan_rust_python_parity(self): # Common TROP parameters trop_params = dict( - method="joint", + method="global", lambda_time_grid=[1.0], lambda_unit_grid=[1.0], lambda_nn_grid=[0.0], @@ -1668,8 +1668,8 @@ def test_trop_joint_treated_pre_nan_rust_python_parity(self): trop_module = sys.modules['diff_diff.trop'] with patch.object(trop_module, 'HAS_RUST_BACKEND', False), \ - patch.object(trop_module, '_rust_loocv_grid_search_joint', None), \ - patch.object(trop_module, '_rust_bootstrap_trop_variance_joint', None): + patch.object(trop_module, '_rust_loocv_grid_search_global', None), \ + patch.object(trop_module, '_rust_bootstrap_trop_variance_global', None): trop_python = TROP(**trop_params) results_python = trop_python.fit(df.copy(), 'outcome', 'treated', 'unit', 'time') @@ -1684,7 +1684,7 @@ def test_trop_joint_treated_pre_nan_rust_python_parity(self): f"Rust ATT ({results_rust.att:.3f}) and Python ATT ({results_python.att:.3f}) " \ f"differ by {att_diff:.3f}, should be < 0.5" - def test_trop_joint_solver_parity_no_lowrank(self): + def test_trop_global_solver_parity_no_lowrank(self): """Test Rust/Python solver parity for no-lowrank path (lambda_nn >= 1e10). Both backends should produce matching (mu, alpha, beta) at atol=1e-6. 
@@ -1732,8 +1732,8 @@ def test_trop_joint_solver_parity_no_lowrank(self): # Python-only backend trop_module = sys.modules['diff_diff.trop'] with patch.object(trop_module, 'HAS_RUST_BACKEND', False), \ - patch.object(trop_module, '_rust_loocv_grid_search_joint', None), \ - patch.object(trop_module, '_rust_bootstrap_trop_variance_joint', None): + patch.object(trop_module, '_rust_loocv_grid_search_global', None), \ + patch.object(trop_module, '_rust_bootstrap_trop_variance_global', None): trop_python = TROP(**trop_params) results_python = trop_python.fit(df.copy(), 'outcome', 'treated', 'unit', 'time') @@ -1754,7 +1754,7 @@ def test_trop_joint_solver_parity_no_lowrank(self): assert abs(r_val - p_val) < 1e-6, \ f"Time effect mismatch for {key}: Rust={r_val:.8f}, Python={p_val:.8f}" - def test_trop_joint_solver_parity_with_lowrank(self): + def test_trop_global_solver_parity_with_lowrank(self): """Test Rust/Python solver parity for with-lowrank path (finite lambda_nn). Both backends should produce matching (mu, alpha, beta) at atol=1e-6. 
@@ -1803,8 +1803,8 @@ def test_trop_joint_solver_parity_with_lowrank(self): # Python-only backend trop_module = sys.modules['diff_diff.trop'] with patch.object(trop_module, 'HAS_RUST_BACKEND', False), \ - patch.object(trop_module, '_rust_loocv_grid_search_joint', None), \ - patch.object(trop_module, '_rust_bootstrap_trop_variance_joint', None): + patch.object(trop_module, '_rust_loocv_grid_search_global', None), \ + patch.object(trop_module, '_rust_bootstrap_trop_variance_global', None): trop_python = TROP(**trop_params) results_python = trop_python.fit(df.copy(), 'outcome', 'treated', 'unit', 'time') diff --git a/tests/test_trop.py b/tests/test_trop.py index 7deabd7..da4b1f0 100644 --- a/tests/test_trop.py +++ b/tests/test_trop.py @@ -2601,7 +2601,7 @@ class TestTROPNuclearNormSolver: def test_proximal_step_size_correctness(self): """Verify L converges to prox_{λ/2}(R) for uniform weights.""" - trop_est = TROP(method="joint", n_bootstrap=2) + trop_est = TROP(method="global", n_bootstrap=2) # Small problem with known solution rng = np.random.default_rng(42) @@ -2631,7 +2631,7 @@ def test_lowrank_objective_decreases(self): delta = rng.uniform(0.5, 2.0, (6, 4)) lambda_nn = 0.3 - trop_est = TROP(method="joint", n_bootstrap=2) + trop_est = TROP(method="global", n_bootstrap=2) L = np.zeros_like(R) objectives = [] @@ -2655,14 +2655,14 @@ def test_lowrank_objective_decreases(self): f"Objective increased at step {k}: {objectives[k]} > {objectives[k-1]}" ) - def test_twostep_nonuniform_weights_objective(self): + def test_local_nonuniform_weights_objective(self): """Verify objective decreases with non-uniform weights (W_max < 1).""" rng = np.random.default_rng(123) R = rng.normal(0, 1, (6, 4)) W = rng.uniform(0.1, 0.8, (6, 4)) lambda_nn = 0.3 - trop_est = TROP(method="twostep", n_bootstrap=2) + trop_est = TROP(method="local", n_bootstrap=2) # Initial objective with L=0 L_init = np.zeros_like(R) @@ -2704,7 +2704,7 @@ def test_zero_weights_no_division_error(self): W = 
np.zeros((6, 4)) L_init = rng.normal(0, 1, (6, 4)) - trop_est = TROP(method="twostep", n_bootstrap=2) + trop_est = TROP(method="local", n_bootstrap=2) result = trop_est._weighted_nuclear_norm_solve( Y=Y, W=W, @@ -2719,18 +2719,18 @@ def test_zero_weights_no_division_error(self): @pytest.mark.slow -class TestTROPJointMethod: - """Tests for TROP method='joint'. +class TestTROPGlobalMethod: + """Tests for TROP method='global'. - The joint method estimates a single scalar treatment effect τ via - weighted least squares, as opposed to the twostep method which + The global method fits one model on control data and takes treatment effects + as post-hoc residuals; the local method refits per treated observation and computes per-observation effects. """ - def test_joint_basic(self, simple_panel_data): - """Joint method runs and produces reasonable ATT.""" + def test_global_basic(self, simple_panel_data): + """Global method runs and produces reasonable ATT.""" trop_est = TROP( - method="joint", + method="global", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], @@ -2753,10 +2753,10 @@ def test_joint_basic(self, simple_panel_data): # ATT should be positive (true effect is 3.0) assert results.att > 0 - def test_joint_no_lowrank(self, simple_panel_data): - """Joint method with lambda_nn=inf (no low-rank).""" + def test_global_no_lowrank(self, simple_panel_data): + """Global method with lambda_nn=inf (no low-rank).""" trop_est = TROP( - method="joint", + method="global", lambda_time_grid=[0.0], lambda_unit_grid=[0.0], lambda_nn_grid=[float('inf')], # Disable low-rank @@ -2777,11 +2777,11 @@ def test_joint_no_lowrank(self, simple_panel_data): # Factor matrix should be all zeros assert np.allclose(results.factor_matrix, 0.0) - def test_joint_with_lowrank(self, factor_dgp_data, ci_params): - """Joint method with finite lambda_nn (with low-rank).""" + def test_global_with_lowrank(self, factor_dgp_data, ci_params): + """Global method with finite lambda_nn (with 
low-rank).""" n_boot = ci_params.bootstrap(20) trop_est = TROP( - method="joint", + method="global", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1, 1.0], @@ -2801,18 +2801,18 @@ def test_joint_with_lowrank(self, factor_dgp_data, ci_params): # Should produce non-zero factor matrix if low-rank is used # (depends on which lambda_nn is selected) - def test_joint_matches_direction(self, simple_panel_data): - """Joint method sign/magnitude roughly matches twostep.""" - # Fit with twostep - trop_twostep = TROP( - method="twostep", + def test_global_matches_direction(self, simple_panel_data): + """Global method sign/magnitude roughly matches local.""" + # Fit with local + trop_local = TROP( + method="local", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], n_bootstrap=10, seed=42, ) - results_twostep = trop_twostep.fit( + results_local = trop_local.fit( simple_panel_data, outcome="outcome", treatment="treated", @@ -2820,16 +2820,16 @@ def test_joint_matches_direction(self, simple_panel_data): time="period", ) - # Fit with joint - trop_joint = TROP( - method="joint", + # Fit with global + trop_global = TROP( + method="global", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], n_bootstrap=10, seed=42, ) - results_joint = trop_joint.fit( + results_global = trop_global.fit( simple_panel_data, outcome="outcome", treatment="treated", @@ -2838,11 +2838,11 @@ def test_joint_matches_direction(self, simple_panel_data): ) # Both should have positive ATT (true effect is 3.0) - assert results_twostep.att > 0 - assert results_joint.att > 0 + assert results_local.att > 0 + assert results_global.att > 0 # Signs should match - assert np.sign(results_twostep.att) == np.sign(results_joint.att) + assert np.sign(results_local.att) == np.sign(results_global.att) def test_method_parameter_validation(self): """Invalid method raises ValueError.""" @@ -2865,24 +2865,38 @@ def 
test_method_in_get_params_joint_deprecated(self): def test_method_in_set_params(self): """method parameter can be set via set_params().""" - trop_est = TROP(method="twostep") - assert trop_est.method == "twostep" + trop_est = TROP(method="local") + assert trop_est.method == "local" trop_est.set_params(method="global") assert trop_est.method == "global" def test_method_set_params_joint_deprecated(self): """'joint' alias maps to 'global' via set_params().""" - trop_est = TROP(method="twostep") + trop_est = TROP(method="local") with pytest.warns(FutureWarning, match="deprecated"): trop_est.set_params(method="joint") assert trop_est.method == "global" - def test_joint_bootstrap_variance(self, simple_panel_data, ci_params): - """Joint method bootstrap variance estimation works.""" + def test_method_in_get_params_twostep_deprecated(self): + """'twostep' alias maps to 'local' in get_params().""" + with pytest.warns(FutureWarning, match="deprecated"): + trop_est = TROP(method="twostep") + params = trop_est.get_params() + assert params["method"] == "local" + + def test_method_set_params_twostep_deprecated(self): + """'twostep' alias maps to 'local' via set_params().""" + trop_est = TROP(method="global") + with pytest.warns(FutureWarning, match="deprecated"): + trop_est.set_params(method="twostep") + assert trop_est.method == "local" + + def test_global_bootstrap_variance(self, simple_panel_data, ci_params): + """Global method bootstrap variance estimation works.""" n_boot = ci_params.bootstrap(20) trop_est = TROP( - method="joint", + method="global", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], @@ -2901,11 +2915,11 @@ def test_joint_bootstrap_variance(self, simple_panel_data, ci_params): assert results.n_bootstrap == n_boot assert results.bootstrap_distribution is not None - def test_joint_confidence_interval(self, simple_panel_data, ci_params): - """Joint method produces valid confidence intervals.""" + def 
test_global_confidence_interval(self, simple_panel_data, ci_params): + """Global method produces valid confidence intervals.""" n_boot = ci_params.bootstrap(30) trop_est = TROP( - method="joint", + method="global", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], @@ -2925,14 +2939,14 @@ def test_joint_confidence_interval(self, simple_panel_data, ci_params): assert lower < results.att < upper assert lower < upper - def test_joint_loocv_selects_from_grid(self, simple_panel_data): - """Joint method LOOCV selects tuning parameters from the grid.""" + def test_global_loocv_selects_from_grid(self, simple_panel_data): + """Global method LOOCV selects tuning parameters from the grid.""" grid_time = [0.0, 0.5, 1.0] grid_unit = [0.0, 0.5, 1.0] grid_nn = [0.0, 0.1] trop_est = TROP( - method="joint", + method="global", lambda_time_grid=grid_time, lambda_unit_grid=grid_unit, lambda_nn_grid=grid_nn, @@ -2954,10 +2968,10 @@ def test_joint_loocv_selects_from_grid(self, simple_panel_data): # LOOCV score should be computed assert np.isfinite(results.loocv_score) or np.isnan(results.loocv_score) - def test_joint_loocv_score_internal(self, simple_panel_data): - """Test the internal _loocv_score_joint method produces valid scores.""" + def test_global_loocv_score_internal(self, simple_panel_data): + """Test the internal _loocv_score_global method produces valid scores.""" trop_est = TROP( - method="joint", + method="global", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], @@ -2992,21 +3006,21 @@ def test_joint_loocv_score_internal(self, simple_panel_data): treated_periods = 3 # From fixture: n_post = 3 # Score should be finite - score = trop_est._loocv_score_joint( + score = trop_est._loocv_score_global( Y, D, control_obs, 0.0, 0.0, 0.0, treated_periods, n_units, n_periods ) assert np.isfinite(score) or np.isinf(score), "Score should be finite or inf" # Score with larger lambda_nn should still work - score2 = 
trop_est._loocv_score_joint( + score2 = trop_est._loocv_score_global( Y, D, control_obs, 1.0, 1.0, 0.1, treated_periods, n_units, n_periods ) assert np.isfinite(score2) or np.isinf(score2), "Score should be finite or inf" - def test_joint_handles_nan_outcomes(self, simple_panel_data): - """Joint method handles NaN outcome values gracefully.""" + def test_global_handles_nan_outcomes(self, simple_panel_data): + """Global method handles NaN outcome values gracefully.""" # Introduce NaN in some control observations data = simple_panel_data.copy() control_mask = data['treated'] == 0 @@ -3018,7 +3032,7 @@ def test_joint_handles_nan_outcomes(self, simple_panel_data): data.loc[nan_indices, 'outcome'] = np.nan trop_est = TROP( - method="joint", + method="global", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], @@ -3039,8 +3053,8 @@ def test_joint_handles_nan_outcomes(self, simple_panel_data): # ATT should be positive (true effect is 3.0) assert results.att > 0, "ATT should be positive" - def test_joint_with_lowrank_handles_nan(self, simple_panel_data): - """Joint method with low-rank handles NaN values correctly.""" + def test_global_with_lowrank_handles_nan(self, simple_panel_data): + """Global method with low-rank handles NaN values correctly.""" # Introduce NaN in some control observations data = simple_panel_data.copy() control_mask = data['treated'] == 0 @@ -3052,7 +3066,7 @@ def test_joint_with_lowrank_handles_nan(self, simple_panel_data): data.loc[nan_indices, 'outcome'] = np.nan trop_est = TROP( - method="joint", + method="global", lambda_time_grid=[0.0], lambda_unit_grid=[0.0], lambda_nn_grid=[0.1], # Finite lambda_nn enables low-rank @@ -3071,7 +3085,7 @@ def test_joint_with_lowrank_handles_nan(self, simple_panel_data): assert np.isfinite(results.att), "ATT should be finite with NaN data" assert np.isfinite(results.se), "SE should be finite with NaN data" - def test_joint_nan_exclusion_behavior(self, simple_panel_data): + def 
test_global_nan_exclusion_behavior(self, simple_panel_data): """Verify NaN observations are truly excluded from estimation. This tests the PR #113 fix: NaN observations should not contribute @@ -3098,7 +3112,7 @@ def test_joint_nan_exclusion_behavior(self, simple_panel_data): # Fit on both versions with identical settings trop_nan = TROP( - method="joint", + method="global", lambda_time_grid=[1.0], lambda_unit_grid=[1.0], lambda_nn_grid=[0.0], # Disable low-rank for cleaner comparison @@ -3106,7 +3120,7 @@ def test_joint_nan_exclusion_behavior(self, simple_panel_data): seed=42, ) trop_dropped = TROP( - method="joint", + method="global", lambda_time_grid=[1.0], lambda_unit_grid=[1.0], lambda_nn_grid=[0.0], @@ -3136,7 +3150,7 @@ def test_joint_nan_exclusion_behavior(self, simple_panel_data): f"({results_dropped.att:.4f}) - true NaN exclusion" ) - def test_joint_unit_no_valid_pre_gets_zero_weight(self, simple_panel_data): + def test_global_unit_no_valid_pre_gets_zero_weight(self, simple_panel_data): """Verify units with no valid pre-period data get zero weight. This tests the PR #113 fix: units with no valid pre-period observations @@ -3159,7 +3173,7 @@ def test_joint_unit_no_valid_pre_gets_zero_weight(self, simple_panel_data): data.loc[mask, 'outcome'] = np.nan trop_est = TROP( - method="joint", + method="global", lambda_time_grid=[1.0], lambda_unit_grid=[1.0], # Non-zero lambda_unit to use distance weighting lambda_nn_grid=[0.0], @@ -3179,8 +3193,8 @@ def test_joint_unit_no_valid_pre_gets_zero_weight(self, simple_panel_data): assert np.isfinite(results.att), "ATT should be finite even with unit having no pre-period data" assert np.isfinite(results.se), "SE should be finite" - def test_joint_treated_pre_nan_handling(self, simple_panel_data): - """Verify joint method handles NaN in treated units during pre-periods. + def test_global_treated_pre_nan_handling(self, simple_panel_data): + """Verify global method handles NaN in treated units during pre-periods. 
When all treated units have NaN at a pre-period, average_treated[t] = NaN. This period should be excluded from unit distance calculation (both numerator @@ -3211,7 +3225,7 @@ def test_joint_treated_pre_nan_handling(self, simple_panel_data): assert n_nan == len(treated_units), f"Should have {len(treated_units)} NaN, got {n_nan}" trop_est = TROP( - method="joint", + method="global", lambda_time_grid=[1.0], lambda_unit_grid=[1.0], lambda_nn_grid=[0.0], @@ -3230,7 +3244,7 @@ def test_joint_treated_pre_nan_handling(self, simple_panel_data): assert np.isfinite(results.att), f"ATT should be finite, got {results.att}" assert np.isfinite(results.se), f"SE should be finite, got {results.se}" - def test_joint_rejects_staggered_adoption(self): + def test_global_rejects_staggered_adoption(self): """Global method raises ValueError for staggered adoption data. The global method assumes all treated units receive treatment at the @@ -3310,7 +3324,7 @@ def test_global_uses_control_only_weights(self, simple_panel_data): treated_periods = np.sum(np.any(D == 1, axis=1)) - delta = trop_est._compute_joint_weights( + delta = trop_est._compute_global_weights( Y, D, 1.0, 1.0, int(treated_periods), n_units, n_periods ) @@ -3404,10 +3418,10 @@ def test_global_treated_outcome_does_not_affect_fit(self, simple_panel_data): ) # Compute weights and fit with original data - delta = trop_est._compute_joint_weights( + delta = trop_est._compute_global_weights( Y, D, 1.0, 1.0, treated_periods, n_units, n_periods ) - mu1, alpha1, beta1, L1 = trop_est._solve_joint_with_lowrank( + mu1, alpha1, beta1, L1 = trop_est._solve_global_with_lowrank( Y, delta, 0.1, 100, 1e-6 ) @@ -3416,10 +3430,10 @@ def test_global_treated_outcome_does_not_affect_fit(self, simple_panel_data): Y_perturbed[D == 1] += 1000.0 # Recompute (same weights since (1-W) zeroes treated cells) - delta2 = trop_est._compute_joint_weights( + delta2 = trop_est._compute_global_weights( Y_perturbed, D, 1.0, 1.0, treated_periods, n_units, n_periods 
) - mu2, alpha2, beta2, L2 = trop_est._solve_joint_with_lowrank( + mu2, alpha2, beta2, L2 = trop_est._solve_global_with_lowrank( Y_perturbed, delta2, 0.1, 100, 1e-6 ) @@ -3479,8 +3493,8 @@ def test_global_n_treated_obs_partial_nan(self): f"Expected {total_treated - n_nan}, got {results.n_treated_obs}" assert np.isfinite(results.att) - def test_twostep_n_treated_obs_partial_nan(self): - """Twostep method: n_treated_obs reflects only finite outcomes.""" + def test_local_n_treated_obs_partial_nan(self): + """Local method: n_treated_obs reflects only finite outcomes.""" df = self._make_panel() treated_mask = (df['treated'] == 1) @@ -3492,7 +3506,7 @@ def test_twostep_n_treated_obs_partial_nan(self): total_treated = int(treated_mask.sum()) trop_est = TROP( - method="twostep", + method="local", lambda_time_grid=[1.0], lambda_unit_grid=[1.0], lambda_nn_grid=[np.inf], @@ -3507,8 +3521,8 @@ def test_twostep_n_treated_obs_partial_nan(self): f"Expected {total_treated - n_nan}, got {results.n_treated_obs}" assert np.isfinite(results.att) - def test_twostep_nan_treated_not_poison_att(self): - """Twostep: NaN treated outcomes don't poison ATT via np.mean.""" + def test_local_nan_treated_not_poison_att(self): + """Local: NaN treated outcomes don't poison ATT via np.mean.""" df = self._make_panel(effect=3.0) # Make ONE treated outcome NaN @@ -3517,7 +3531,7 @@ def test_twostep_nan_treated_not_poison_att(self): df.loc[first_treated_idx, 'outcome'] = np.nan trop_est = TROP( - method="twostep", + method="local", lambda_time_grid=[1.0], lambda_unit_grid=[1.0], lambda_nn_grid=[np.inf], @@ -3558,14 +3572,14 @@ def test_global_all_treated_nan_warns(self): assert results.n_treated_obs == 0 assert np.isnan(results.att) - def test_twostep_all_treated_nan_warns(self): - """Twostep method warns when all treated outcomes are NaN.""" + def test_local_all_treated_nan_warns(self): + """Local method warns when all treated outcomes are NaN.""" df = self._make_panel() df.loc[df['treated'] == 1, 
'outcome'] = np.nan trop_est = TROP( - method="twostep", + method="local", lambda_time_grid=[1.0], lambda_unit_grid=[1.0], lambda_nn_grid=[np.inf], @@ -3602,15 +3616,15 @@ def test_global_bootstrap_zero_draws_returns_nan_se(self): ) # Disable Rust backend so Python fallback path is tested, - # then patch _fit_joint_with_fixed_lambda to always raise + # then patch _fit_global_with_fixed_lambda to always raise trop_module = sys.modules['diff_diff.trop'] with patch.object(trop_module, 'HAS_RUST_BACKEND', False), \ - patch.object(trop_module, '_rust_bootstrap_trop_variance_joint', None), \ - patch.object(TROP, '_fit_joint_with_fixed_lambda', + patch.object(trop_module, '_rust_bootstrap_trop_variance_global', None), \ + patch.object(TROP, '_fit_global_with_fixed_lambda', side_effect=ValueError("forced failure")): with warnings.catch_warnings(): warnings.simplefilter("ignore") - se, dist = trop_est._bootstrap_variance_joint( + se, dist = trop_est._bootstrap_variance_global( df, 'outcome', 'treated', 'unit', 'time', (1.0, 1.0, 1e10), 3, ) @@ -3618,14 +3632,14 @@ def test_global_bootstrap_zero_draws_returns_nan_se(self): assert np.isnan(se), f"SE should be NaN when 0 draws succeed, got {se}" assert len(dist) == 0 - def test_twostep_bootstrap_zero_draws_returns_nan_se(self): - """Twostep bootstrap with 0 successful draws returns NaN SE, not 0.0.""" + def test_local_bootstrap_zero_draws_returns_nan_se(self): + """Local bootstrap with 0 successful draws returns NaN SE, not 0.0.""" from unittest.mock import patch df = TestTROPNValidTreated._make_panel() trop_est = TROP( - method="twostep", + method="local", lambda_time_grid=[1.0], lambda_unit_grid=[1.0], lambda_nn_grid=[np.inf], From 1e6bcd195169ee531f3888570142b8e48b1f65de Mon Sep 17 00:00:00 2001 From: igerber Date: Wed, 18 Mar 2026 15:09:29 -0400 Subject: [PATCH 2/4] Update stale Rust doc comments and test docstring per review Fix P3 items from code review: - Replace remaining "joint"/"twostep" wording in Rust doc comments 
and section headers with "global"/"local" - Rewrite TestTROPGlobalMethod docstring to match current global method semantics (single control-only fit with post-hoc residuals) Co-Authored-By: Claude Opus 4.6 (1M context) --- rust/src/trop.rs | 20 ++++++++++---------- tests/test_trop.py | 8 +++++--- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/rust/src/trop.rs b/rust/src/trop.rs index 8993a3e..78bc1b2 100644 --- a/rust/src/trop.rs +++ b/rust/src/trop.rs @@ -1081,12 +1081,12 @@ pub fn bootstrap_trop_variance<'py>( } // ============================================================================ -// Joint method implementation +// Global method implementation // ============================================================================ -/// Compute global weights for joint method estimation. +/// Compute global weights for global method estimation. /// -/// Unlike twostep (which computes per-observation weights), joint uses global +/// Unlike local (which computes per-observation weights), global uses /// weights based on: /// - Time weights: distance to center of treated block /// - Unit weights: RMSE to average treated trajectory over pre-periods @@ -1196,7 +1196,7 @@ fn compute_joint_weights( delta } -/// Solve joint TWFE via weighted least squares (no low-rank, no tau). +/// Solve global TWFE via weighted least squares (no low-rank, no tau). /// /// Minimizes: min Σ δ_{it}(Y_{it} - μ - α_i - β_t)² /// @@ -1315,7 +1315,7 @@ fn solve_joint_no_lowrank( Some((mu, alpha, beta)) } -/// Solve joint TWFE + low-rank via alternating minimization (no tau). +/// Solve global TWFE + low-rank via alternating minimization (no tau). /// /// Minimizes: min Σ δ_{it}(Y_{it} - μ - α_i - β_t - L_{it})² + λ_nn||L||_* /// @@ -1422,12 +1422,12 @@ fn solve_joint_with_lowrank( Some((mu, alpha, beta, l)) } -/// Compute LOOCV score for joint method with specific parameter combination. +/// Compute LOOCV score for global method with specific parameter combination. 
/// /// Following paper's Equation 5: /// Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² /// -/// For joint method, we exclude each control observation, fit the joint model +/// For global method, we exclude each control observation, fit the global model /// on remaining data, and compute the pseudo-treatment effect at the excluded obs. /// /// # Returns @@ -1502,7 +1502,7 @@ fn loocv_score_joint( } } -/// Perform LOOCV grid search for joint method using parallel grid search. +/// Perform LOOCV grid search for global method using parallel grid search. /// /// Evaluates all combinations of (lambda_time, lambda_unit, lambda_nn) in parallel /// and returns the combination with lowest LOOCV score. @@ -1630,7 +1630,7 @@ pub fn loocv_grid_search_global<'py>( Ok((best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed)) } -/// Compute bootstrap variance estimation for TROP joint method in parallel. +/// Compute bootstrap variance estimation for TROP global method in parallel. /// /// Performs unit-level block bootstrap, parallelizing across bootstrap iterations. /// Uses stratified sampling to preserve treated/control unit ratio. @@ -1737,7 +1737,7 @@ pub fn bootstrap_trop_variance_global<'py>( } } - // Compute weights and fit joint model + // Compute weights and fit global model let delta = compute_joint_weights( &y_boot.view(), &d_boot.view(), diff --git a/tests/test_trop.py b/tests/test_trop.py index da4b1f0..eafa356 100644 --- a/tests/test_trop.py +++ b/tests/test_trop.py @@ -2722,9 +2722,11 @@ def test_zero_weights_no_division_error(self): class TestTROPGlobalMethod: """Tests for TROP method='global'. - The global method estimates a single scalar treatment effect τ via - weighted least squares, as opposed to the local method which - computes per-observation effects. + The global method fits a single model on control data with global + weights, then extracts per-observation treatment effects as + residuals (τ_it = Y_it - μ - α_i - β_t - L_it). 
ATT is the mean + of these effects. The local method instead fits a separate model + per treated observation with observation-specific weights. """ def test_global_basic(self, simple_panel_data): From fa25f700c3fad52a7551451e9e9cedc1f173496b Mon Sep 17 00:00:00 2001 From: igerber Date: Wed, 18 Mar 2026 15:19:51 -0400 Subject: [PATCH 3/4] =?UTF-8?q?Fix=20P2:=20registry=20optimization=20summa?= =?UTF-8?q?ry=20no=20longer=20claims=20global=20solves=20for=20=CF=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The generic TROP optimization summary incorrectly described the global path as solving for (μ, α, β, τ) with τ·D in the residual. The actual implementation solves for (μ, α, β, L) on control data and extracts τ_it post-hoc as residuals. Clarify the summary and point to the dedicated Global section for details. Co-Authored-By: Claude Opus 4.6 (1M context) --- docs/methodology/REGISTRY.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index e846543..de4f46a 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -1275,13 +1275,15 @@ Optimization (Equation 2): ``` (α̂, β̂, L̂) = argmin_{α,β,L} Σ_j Σ_s θ_s^{i,t} ω_j^{i,t} (1-W_js)(Y_js - α_j - β_s - L_js)² + λ_nn ||L||_* ``` -Solved via alternating minimization. For α, β (or μ, α, β, τ in global): weighted least -squares (closed form). For L: proximal gradient with step size η = 1/(2·max(W)): +Solved via alternating minimization. For α, β: weighted least squares (closed form). +The global solver adds an intercept μ and solves for (μ, α, β, L) on control data only, +extracting τ_it post-hoc as residuals (see Global section below). 
+For L: proximal gradient with step size η = 1/(2·max(W)): ``` Gradient step: G = L + (W/max(W)) ⊙ (R - L) Proximal step: L = U × soft_threshold(Σ, η·λ_nn) × V' (SVD of G = UΣV') ``` -where R is the residual after removing fixed effects (and τ·D in global mode). +where R is the residual after removing fixed effects. Both the local and global solvers use FISTA/Nesterov acceleration for the inner L update (O(1/k²) convergence rate, up to 20 inner iterations per outer alternating step). From 528faf51ab03feb822633ca8c9daf68976acd08a Mon Sep 17 00:00:00 2001 From: igerber Date: Wed, 18 Mar 2026 16:02:24 -0400 Subject: [PATCH 4/4] Fix P3: stale test docstring and Rust private naming note - Rewrite test_global_method_alias docstring to match its actual assertion (validates method="global" produces valid ATT) - Add comment in Rust explaining that only exported #[pyfunction] names were renamed; private helpers retain *_joint* names Co-Authored-By: Claude Opus 4.6 (1M context) --- rust/src/trop.rs | 4 ++++ tests/test_trop.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/rust/src/trop.rs b/rust/src/trop.rs index 78bc1b2..e26528e 100644 --- a/rust/src/trop.rs +++ b/rust/src/trop.rs @@ -1082,6 +1082,10 @@ pub fn bootstrap_trop_variance<'py>( // ============================================================================ // Global method implementation +// +// Note: Only the #[pyfunction] exports were renamed (joint → global) to match +// the Python public API. The private Rust helpers below retain their original +// `*_joint*` names to keep the Rust-only rename scope minimal. // ============================================================================ /// Compute global weights for global method estimation. 
diff --git a/tests/test_trop.py b/tests/test_trop.py index eafa356..f8207f4 100644 --- a/tests/test_trop.py +++ b/tests/test_trop.py @@ -3275,7 +3275,7 @@ def test_global_rejects_staggered_adoption(self): trop.fit(df, 'outcome', 'treated', 'unit', 'time') def test_global_method_alias(self, simple_panel_data): - """method='global' works and produces same results as deprecated 'joint'.""" + """method='global' runs and produces a valid positive ATT.""" trop_est = TROP( method="global", lambda_time_grid=[0.0, 1.0],