diff --git a/CHANGELOG.md b/CHANGELOG.md index 98c0844..27d18e5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed +- Rename TROP `method="twostep"` to `method="local"`; `"twostep"` deprecated, removal in v3.0 +- Rename internal TROP `_joint_*` methods to `_global_*` for consistency + ## [2.7.1] - 2026-03-15 ### Changed diff --git a/CLAUDE.md b/CLAUDE.md index 8fc7397..29b1407 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -122,7 +122,7 @@ category (`Methodology/Correctness`, `Performance`, or `Testing/Docs`): `threshold = 0.40 if n_boot < 100 else 0.15`. - **`assert_nan_inference()`** from conftest.py: Use to validate ALL inference fields are NaN-consistent. Don't check individual fields separately. -- **Slow tests**: TROP methodology/joint-method tests, Sun-Abraham bootstrap, and +- **Slow tests**: TROP methodology/global-method tests, Sun-Abraham bootstrap, and TROP-parity tests are marked `@pytest.mark.slow` and excluded by default via `addopts`. `test_trop.py` uses per-class markers (not file-level) so that validation, API, and solver tests still run in the pure Python CI fallback. Run `pytest -m ''` to include diff --git a/README.md b/README.md index ae28d3c..3467264 100644 --- a/README.md +++ b/README.md @@ -1517,7 +1517,7 @@ trop = TROP( ```python TROP( - method='twostep', # Estimation method: 'twostep' (default) or 'joint' + method='local', # Estimation method: 'local' (default) or 'global' lambda_time_grid=None, # Time decay grid (default: [0, 0.1, 0.5, 1, 2, 5]) lambda_unit_grid=None, # Unit distance grid (default: [0, 0.1, 0.5, 1, 2, 5]) lambda_nn_grid=None, # Nuclear norm grid (default: [0, 0.01, 0.1, 1, 10]) @@ -1530,8 +1530,8 @@ TROP( ``` **Estimation methods:** -- `'twostep'` (default): Per-observation model fitting following Algorithm 2 of the paper. Computes observation-specific weights and fits a model for each treated observation, then averages the individual treatment effects. More flexible but computationally intensive. -- `'joint'`: Joint weighted least squares optimization. Estimates a single scalar treatment effect τ along with fixed effects and optional low-rank factor adjustment. Faster but assumes homogeneous treatment effects. +- `'local'` (default): Per-observation model fitting following Algorithm 2 of the paper. Computes observation-specific weights and fits a model for each treated observation, then averages the individual treatment effects. More flexible but computationally intensive. +- `'global'`: Global weighted least squares optimization. Fits a single model on control observations with global weights, then computes per-observation treatment effects as residuals. Faster but uses global rather than observation-specific weights. **Convenience function:** diff --git a/diff_diff/_backend.py b/diff_diff/_backend.py index 0c6e99f..2065e17 100644 --- a/diff_diff/_backend.py +++ b/diff_diff/_backend.py @@ -23,13 +23,13 @@ project_simplex as _rust_project_simplex, solve_ols as _rust_solve_ols, compute_robust_vcov as _rust_compute_robust_vcov, - # TROP estimator acceleration (twostep method) + # TROP estimator acceleration (local method) compute_unit_distance_matrix as _rust_unit_distance_matrix, loocv_grid_search as _rust_loocv_grid_search, bootstrap_trop_variance as _rust_bootstrap_trop_variance, - # TROP estimator acceleration (joint method) - loocv_grid_search_joint as _rust_loocv_grid_search_joint, - bootstrap_trop_variance_joint as _rust_bootstrap_trop_variance_joint, + # TROP estimator acceleration (global method) + loocv_grid_search_global as _rust_loocv_grid_search_global, + bootstrap_trop_variance_global as _rust_bootstrap_trop_variance_global, # SDID weights (Frank-Wolfe matching R's synthdid) compute_sdid_unit_weights as _rust_sdid_unit_weights, compute_time_weights as _rust_compute_time_weights, @@ -46,13 +46,13 @@ _rust_project_simplex = None _rust_solve_ols = None _rust_compute_robust_vcov = None - # TROP estimator acceleration (twostep method) + # TROP estimator acceleration (local method) _rust_unit_distance_matrix = None _rust_loocv_grid_search = None _rust_bootstrap_trop_variance = None - # TROP estimator acceleration (joint method) - _rust_loocv_grid_search_joint = None - _rust_bootstrap_trop_variance_joint = None + # TROP estimator acceleration (global method) + _rust_loocv_grid_search_global = None + _rust_bootstrap_trop_variance_global = None # SDID weights (Frank-Wolfe matching R's synthdid) _rust_sdid_unit_weights = None _rust_compute_time_weights = None @@ -69,13 +69,13 @@ _rust_project_simplex = None _rust_solve_ols = None _rust_compute_robust_vcov = None - # TROP estimator acceleration (twostep method) + # TROP estimator acceleration (local method) _rust_unit_distance_matrix = None _rust_loocv_grid_search = None _rust_bootstrap_trop_variance = None - # TROP estimator acceleration (joint method) - _rust_loocv_grid_search_joint = None - _rust_bootstrap_trop_variance_joint = None + # TROP estimator acceleration (global method) + _rust_loocv_grid_search_global = None + _rust_bootstrap_trop_variance_global = None # SDID weights (Frank-Wolfe matching R's synthdid) _rust_sdid_unit_weights = None _rust_compute_time_weights = None @@ -118,13 +118,13 @@ def rust_backend_info(): '_rust_project_simplex', '_rust_solve_ols', '_rust_compute_robust_vcov', - # TROP estimator acceleration (twostep method) + # TROP estimator acceleration (local method) '_rust_unit_distance_matrix', '_rust_loocv_grid_search', '_rust_bootstrap_trop_variance', - # TROP estimator acceleration (joint method) - '_rust_loocv_grid_search_joint', - '_rust_bootstrap_trop_variance_joint', + # TROP estimator acceleration (global method) + '_rust_loocv_grid_search_global', + '_rust_bootstrap_trop_variance_global', # SDID weights (Frank-Wolfe matching R's synthdid) '_rust_sdid_unit_weights', '_rust_compute_time_weights', diff --git a/diff_diff/trop.py b/diff_diff/trop.py index abb21f9..8f2c1f3 100644 --- a/diff_diff/trop.py +++ b/diff_diff/trop.py @@ -31,8 +31,8 @@ _rust_unit_distance_matrix, _rust_loocv_grid_search, _rust_bootstrap_trop_variance, - _rust_loocv_grid_search_joint, - _rust_bootstrap_trop_variance_joint, + _rust_loocv_grid_search_global, + _rust_bootstrap_trop_variance_global, ) from diff_diff.trop_results import ( _LAMBDA_INF, @@ -63,10 +63,10 @@ class TROP: Parameters ---------- - method : str, default='twostep' + method : str, default='local' Estimation method to use: - - 'twostep': Per-observation model fitting following Algorithm 2 of + - 'local': Per-observation model fitting following Algorithm 2 of Athey et al. (2025). Computes observation-specific weights and fits a model for each treated observation, averaging the individual treatment effects. More flexible but computationally intensive. @@ -77,10 +77,11 @@ class TROP: treatment effects as residuals: tau_it = Y_it - mu - alpha_i - beta_t - L_it for treated cells. ATT is the mean of these effects. For the paper's full - per-treated-cell estimator, use ``method='twostep'``. + per-treated-cell estimator, use ``method='local'``. - - 'joint': Deprecated alias for 'global'. Will be removed in a - future version. + - 'twostep': Deprecated alias for 'local'. Will be removed in v3.0. + + - 'joint': Deprecated alias for 'global'. Will be removed in v3.0. lambda_time_grid : list, optional Grid of time weight decay parameters. 0.0 = uniform weights (disabled). @@ -138,7 +139,7 @@ class TROP: def __init__( self, - method: str = "twostep", + method: str = "local", lambda_time_grid: Optional[List[float]] = None, lambda_unit_grid: Optional[List[float]] = None, lambda_nn_grid: Optional[List[float]] = None, @@ -149,16 +150,24 @@ def __init__( seed: Optional[int] = None, ): # Validate method parameter - # 'global' is the preferred name; 'joint' is a deprecated alias - valid_methods = ("twostep", "joint", "global") + # 'local'/'global' are preferred; 'twostep'/'joint' are deprecated aliases + valid_methods = ("local", "twostep", "joint", "global") if method not in valid_methods: raise ValueError( f"method must be one of {valid_methods}, got '{method}'" ) + if method == "twostep": + warnings.warn( + "method='twostep' is deprecated and will be removed in v3.0. " + "Use method='local' instead.", + FutureWarning, + stacklevel=2, + ) + method = "local" if method == "joint": warnings.warn( - "method='joint' is deprecated and will be removed in a future " - "version. Use method='global' instead.", + "method='joint' is deprecated and will be removed in v3.0. " + "Use method='global' instead.", FutureWarning, stacklevel=2, ) @@ -556,7 +565,7 @@ def _cycling_parameter_search( # Joint estimation method # ========================================================================= - def _compute_joint_weights( + def _compute_global_weights( self, Y: np.ndarray, D: np.ndarray, @@ -567,7 +576,7 @@ def _compute_joint_weights( n_periods: int, ) -> np.ndarray: """ - Compute distance-based weights for joint estimation. + Compute distance-based weights for global estimation. Following the reference implementation, weights are computed based on: - Time distance: distance to center of treated block @@ -655,7 +664,7 @@ def _compute_joint_weights( return delta - def _solve_joint_model( + def _solve_global_model( self, Y: np.ndarray, delta: np.ndarray, @@ -668,10 +677,10 @@ def _solve_joint_model( """ n_periods, n_units = Y.shape if lambda_nn >= 1e10: - mu, alpha, beta = self._solve_joint_no_lowrank(Y, delta) + mu, alpha, beta = self._solve_global_no_lowrank(Y, delta) L = np.zeros((n_periods, n_units)) else: - mu, alpha, beta, L = self._solve_joint_with_lowrank( + mu, alpha, beta, L = self._solve_global_with_lowrank( Y, delta, lambda_nn, self.max_iter, self.tol ) return mu, alpha, beta, L @@ -718,7 +727,7 @@ def _extract_posthoc_tau( return att, treatment_effects, tau_values - def _loocv_score_joint( + def _loocv_score_global( self, Y: np.ndarray, D: np.ndarray, @@ -731,12 +740,12 @@ def _loocv_score_joint( n_periods: int, ) -> float: """ - Compute LOOCV score for joint method with specific parameter combination. + Compute LOOCV score for global method with specific parameter combination. Following paper's Equation 5: Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² - For joint method, we exclude each control observation, fit the joint model + For global method, we exclude each control observation, fit the global model on remaining data, and compute the pseudo-treatment effect at the excluded obs. Parameters @@ -766,7 +775,7 @@ def _loocv_score_joint( LOOCV score (sum of squared pseudo-treatment effects). """ # Compute global weights (same for all LOOCV iterations) - delta = self._compute_joint_weights( + delta = self._compute_global_weights( Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods ) @@ -779,7 +788,7 @@ def _loocv_score_joint( delta_ex[t_ex, i_ex] = 0.0 try: - mu, alpha, beta, L = self._solve_joint_model(Y, delta_ex, lambda_nn) + mu, alpha, beta, L = self._solve_global_model(Y, delta_ex, lambda_nn) # Pseudo treatment effect: τ = Y - μ - α - β - L if np.isfinite(Y[t_ex, i_ex]): @@ -796,7 +805,7 @@ def _loocv_score_joint( return tau_sq_sum - def _solve_joint_no_lowrank( + def _solve_global_no_lowrank( self, Y: np.ndarray, delta: np.ndarray, @@ -806,7 +815,7 @@ def _solve_joint_no_lowrank( Solves: min Σ (1-W)*δ_{it}(Y_{it} - μ - α_i - β_t)² - The (1-W) masking is already applied to delta by _compute_joint_weights, + The (1-W) masking is already applied to delta by _compute_global_weights, so treated observations have zero weight and do not affect the fit. Parameters @@ -880,7 +889,7 @@ def _solve_joint_no_lowrank( return float(mu), alpha, beta - def _solve_joint_with_lowrank( + def _solve_global_with_lowrank( self, Y: np.ndarray, delta: np.ndarray, @@ -893,7 +902,7 @@ def _solve_joint_with_lowrank( Solves: min Σ (1-W)*δ_{it}(Y_{it} - μ - α_i - β_t - L_{it})² + λ_nn||L||_* - The (1-W) masking is already applied to delta by _compute_joint_weights, + The (1-W) masking is already applied to delta by _compute_global_weights, so treated observations have zero weight and do not affect the fit. Parameters @@ -942,7 +951,7 @@ def _solve_joint_with_lowrank( # Step 1: Fix L, solve for (mu, alpha, beta) Y_adj = Y_safe - L - mu, alpha, beta = self._solve_joint_no_lowrank(Y_adj, delta_masked) + mu, alpha, beta = self._solve_global_no_lowrank(Y_adj, delta_masked) # Step 2: Fix (mu, alpha, beta), update L with FISTA acceleration R = Y_safe - mu - alpha[np.newaxis, :] - beta[:, np.newaxis] @@ -981,11 +990,11 @@ def _solve_joint_with_lowrank( # Final re-solve with converged L (match Rust behavior) Y_adj = Y_safe - L - mu, alpha, beta = self._solve_joint_no_lowrank(Y_adj, delta_masked) + mu, alpha, beta = self._solve_global_no_lowrank(Y_adj, delta_masked) return mu, alpha, beta, L - def _fit_joint( + def _fit_global( self, data: pd.DataFrame, outcome: str, @@ -1024,10 +1033,10 @@ def _fit_joint( (fixed `treated_periods` across resamples). The treatment timing is inferred from the data once and held constant for all bootstrap iterations. For staggered adoption designs where treatment timing varies - across units, use `method="twostep"` which computes observation-specific + across units, use `method="local"` which computes observation-specific weights that naturally handle heterogeneous timing. """ - # Data setup (same as twostep method) + # Data setup (same as local method) all_units = sorted(data[unit].unique()) all_periods = sorted(data[time].unique()) @@ -1097,7 +1106,7 @@ def _fit_joint( if n_pre_periods < 2: raise ValueError("Need at least 2 pre-treatment periods") - # Check for staggered adoption (joint method requires simultaneous treatment) + # Check for staggered adoption (global method requires simultaneous treatment) # Use only observed periods (skip missing) to avoid false positives on unbalanced panels first_treat_by_unit = [] for i in treated_unit_idx: @@ -1115,7 +1124,7 @@ def _fit_joint( raise ValueError( f"method='global' requires simultaneous treatment adoption, but your data " f"shows staggered adoption (units first treated at periods {unique_starts}). " - f"Use method='twostep' which properly handles staggered adoption designs." + f"Use method='local' which properly handles staggered adoption designs." ) # LOOCV grid search for tuning parameters @@ -1124,7 +1133,7 @@ def _fit_joint( best_score = np.inf control_mask = D == 0 - if HAS_RUST_BACKEND and _rust_loocv_grid_search_joint is not None: + if HAS_RUST_BACKEND and _rust_loocv_grid_search_global is not None: try: # Prepare inputs for Rust function control_mask_u8 = control_mask.astype(np.uint8) @@ -1133,7 +1142,7 @@ def _fit_joint( lambda_unit_arr = np.array(self.lambda_unit_grid, dtype=np.float64) lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64) - result = _rust_loocv_grid_search_joint( + result = _rust_loocv_grid_search_global( Y, D.astype(np.float64), control_mask_u8, lambda_time_arr, lambda_unit_arr, lambda_nn_arr, self.max_iter, self.tol, @@ -1170,7 +1179,7 @@ def _fit_joint( except Exception as e: # Fall back to Python implementation on error logger.debug( - "Rust LOOCV grid search (joint) failed, falling back to Python: %s", e + "Rust LOOCV grid search (global) failed, falling back to Python: %s", e ) best_lambda = None best_score = np.inf @@ -1193,7 +1202,7 @@ def _fit_joint( ln = 1e10 if np.isinf(lambda_nn_val) else lambda_nn_val try: - score = self._loocv_score_joint( + score = self._loocv_score_global( Y, D, control_obs, lt, lu, ln, treated_periods, n_units, n_periods ) @@ -1223,11 +1232,11 @@ def _fit_joint( lambda_nn = 1e10 # Compute final weights and fit - delta = self._compute_joint_weights( + delta = self._compute_global_weights( Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods ) - mu, alpha, beta, L = self._solve_joint_model(Y, delta, lambda_nn) + mu, alpha, beta, L = self._solve_global_model(Y, delta, lambda_nn) # Post-hoc tau extraction (per paper Eq. 2) att, treatment_effects, tau_values = self._extract_posthoc_tau( @@ -1258,7 +1267,7 @@ def _fit_joint( # Bootstrap variance estimation effective_lambda = (lambda_time, lambda_unit, lambda_nn) - se, bootstrap_dist = self._bootstrap_variance_joint( + se, bootstrap_dist = self._bootstrap_variance_global( data, outcome, treatment, unit, time, effective_lambda, treated_periods ) @@ -1300,7 +1309,7 @@ def _fit_joint( self.is_fitted_ = True return self.results_ - def _bootstrap_variance_joint( + def _bootstrap_variance_global( self, data: pd.DataFrame, outcome: str, @@ -1311,7 +1320,7 @@ def _bootstrap_variance_joint( treated_periods: int, ) -> Tuple[float, np.ndarray]: """ - Compute bootstrap standard error for joint method. + Compute bootstrap standard error for global method. Uses Rust backend when available for parallel bootstrap (5-15x speedup). @@ -1340,7 +1349,7 @@ def _bootstrap_variance_joint( lambda_time, lambda_unit, lambda_nn = optimal_lambda # Try Rust backend for parallel bootstrap (5-15x speedup) - if HAS_RUST_BACKEND and _rust_bootstrap_trop_variance_joint is not None: + if HAS_RUST_BACKEND and _rust_bootstrap_trop_variance_global is not None: try: # Create matrices for Rust function all_units = sorted(data[unit].unique()) @@ -1359,7 +1368,7 @@ def _bootstrap_variance_joint( .values ) - bootstrap_estimates, se = _rust_bootstrap_trop_variance_joint( + bootstrap_estimates, se = _rust_bootstrap_trop_variance_global( Y, D, lambda_time, lambda_unit, lambda_nn, self.n_bootstrap, self.max_iter, self.tol, @@ -1378,7 +1387,7 @@ def _bootstrap_variance_joint( except Exception as e: logger.debug( - "Rust bootstrap (joint) failed, falling back to Python: %s", e + "Rust bootstrap (global) failed, falling back to Python: %s", e ) # Python fallback implementation @@ -1419,7 +1428,7 @@ def _bootstrap_variance_joint( ], ignore_index=True) try: - tau = self._fit_joint_with_fixed_lambda( + tau = self._fit_global_with_fixed_lambda( boot_data, outcome, treatment, unit, time, optimal_lambda, treated_periods ) @@ -1441,7 +1450,7 @@ def _bootstrap_variance_joint( se = np.std(bootstrap_estimates, ddof=1) return float(se), bootstrap_estimates - def _fit_joint_with_fixed_lambda( + def _fit_global_with_fixed_lambda( self, data: pd.DataFrame, outcome: str, @@ -1478,12 +1487,12 @@ def _fit_joint_with_fixed_lambda( ) # Compute weights (includes (1-W) masking) - delta = self._compute_joint_weights( + delta = self._compute_global_weights( Y, D, lambda_time, lambda_unit, treated_periods, n_units, n_periods ) # Fit model on control data and extract post-hoc tau - mu, alpha, beta, L = self._solve_joint_model(Y, delta, lambda_nn) + mu, alpha, beta, L = self._solve_global_model(Y, delta, lambda_nn) att, _, _ = self._extract_posthoc_tau(Y, D, mu, alpha, beta, L) return att @@ -1541,9 +1550,9 @@ def fit( # Dispatch based on estimation method if self.method == "global": - return self._fit_joint(data, outcome, treatment, unit, time) + return self._fit_global(data, outcome, treatment, unit, time) - # Below is the twostep method (default) + # Below is the local method (default) # Get unique units and periods all_units = sorted(data[unit].unique()) all_periods = sorted(data[time].unique()) @@ -2660,10 +2669,18 @@ def get_params(self) -> Dict[str, Any]: def set_params(self, **params) -> "TROP": """Set estimator parameters.""" for key, value in params.items(): + if key == "method" and value == "twostep": + warnings.warn( + "method='twostep' is deprecated and will be removed in " + "v3.0. Use method='local' instead.", + FutureWarning, + stacklevel=2, + ) + value = "local" if key == "method" and value == "joint": warnings.warn( - "method='joint' is deprecated and will be removed in a " - "future version. Use method='global' instead.", + "method='joint' is deprecated and will be removed in " + "v3.0. Use method='global' instead.", FutureWarning, stacklevel=2, ) diff --git a/docs/api/trop.rst b/docs/api/trop.rst index cfcdc8c..2610624 100644 --- a/docs/api/trop.rst +++ b/docs/api/trop.rst @@ -100,7 +100,7 @@ Estimation Methods TROP supports two estimation methods via the ``method`` parameter: -**Two-Step Method** (``method='twostep'``, default) +**Local Method** (``method='local'``, default) The default method follows Algorithm 2 from the paper: @@ -124,7 +124,7 @@ the estimator is consistent if any one of the three components A computationally efficient adaptation using the ``(1-W)`` masking principle from Eq. 2. Fits a single global model rather than per-treated-cell models. For the paper's full per-treated-cell estimator (Algorithm 2), use -``method='twostep'``. +``method='local'``. 1. **Compute weights**: Distance-based unit and time weights computed once (distance to center of treated block, RMSE to average treated trajectory), @@ -145,15 +145,16 @@ For the paper's full per-treated-cell estimator (Algorithm 2), use The global method is **faster** (single optimization vs N_treated optimizations). Treatment effects are **heterogeneous** per-observation residuals; ATT is their mean. -``method='joint'`` is a deprecated alias for ``method='global'`` and will be -removed in a future version. +``method='twostep'`` is a deprecated alias for ``method='local'`` and will be +removed in v3.0. ``method='joint'`` is a deprecated alias for ``method='global'`` +and will be removed in v3.0. .. list-table:: :header-rows: 1 :widths: 20 40 40 * - Feature - - Two-Step (default) + - Local (default) - Global * - Treatment effect - Per-observation τ_{it} (per-obs models) @@ -168,7 +169,7 @@ removed in a future version. - Observation-specific - Global (center of treated block) -Use ``method='twostep'`` for observation-specific weight optimization. +Use ``method='local'`` for observation-specific weight optimization. Use ``method='global'`` for faster estimation with global weights. Example Usage @@ -228,9 +229,9 @@ Using the global method for faster estimation:: unit='unit_id', time='period') # Compare methods - trop_twostep = TROP(method='twostep', ...) # Default (per-observation) - results_twostep = trop_twostep.fit(data, ...) - print(f"Two-step ATT: {results_twostep.att:.3f}") + trop_local = TROP(method='local', ...) # Default (per-observation) + results_local = trop_local.fit(data, ...) + print(f"Local ATT: {results_local.att:.3f}") print(f"Global ATT: {results_global.att:.3f}") Examining factor structure:: diff --git a/docs/choosing_estimator.rst b/docs/choosing_estimator.rst index 3670f9a..8a13c85 100644 --- a/docs/choosing_estimator.rst +++ b/docs/choosing_estimator.rst @@ -405,7 +405,7 @@ exponential unit distance weights, and time decay weights with LOOCV tuning. .. note:: TROP is computationally intensive. Use ``method='global'`` for faster - estimation at the cost of some flexibility vs. ``method='twostep'``. + estimation at the cost of some flexibility vs. ``method='local'``. Bacon Decomposition ~~~~~~~~~~~~~~~~~~~ diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 54bfbac..de4f46a 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -1275,14 +1275,16 @@ Optimization (Equation 2): ``` (α̂, β̂, L̂) = argmin_{α,β,L} Σ_j Σ_s θ_s^{i,t} ω_j^{i,t} (1-W_js)(Y_js - α_j - β_s - L_js)² + λ_nn ||L||_* ``` -Solved via alternating minimization. For α, β (or μ, α, β, τ in joint): weighted least -squares (closed form). For L: proximal gradient with step size η = 1/(2·max(W)): +Solved via alternating minimization. For α, β: weighted least squares (closed form). +The global solver adds an intercept μ and solves for (μ, α, β, L) on control data only, +extracting τ_it post-hoc as residuals (see Global section below). +For L: proximal gradient with step size η = 1/(2·max(W)): ``` Gradient step: G = L + (W/max(W)) ⊙ (R - L) Proximal step: L = U × soft_threshold(Σ, η·λ_nn) × V' (SVD of G = UΣV') ``` -where R is the residual after removing fixed effects (and τ·D in joint mode). -Both the twostep and global solvers use FISTA/Nesterov acceleration for the +where R is the residual after removing fixed effects. +Both the local and global solvers use FISTA/Nesterov acceleration for the inner L update (O(1/k²) convergence rate, up to 20 inner iterations per outer alternating step). @@ -1372,12 +1374,13 @@ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² ### TROP Global Estimation Method -**Method**: `method="global"` in TROP estimator (`method="joint"` is a deprecated alias) +**Method**: `method="global"` in TROP estimator (`method="joint"` is a deprecated alias; +`method="twostep"` is a deprecated alias for `method="local"`) **Approach**: Computationally efficient adaptation using the (1-W) masking principle from Eq. 2. Fits a single global model on control data, then extracts treatment effects as post-hoc residuals. For the paper's full -per-treated-cell estimator (Algorithm 2), use `method='twostep'`. +per-treated-cell estimator (Algorithm 2), use `method='local'`. **Objective function** (Equation G1): ``` @@ -1400,7 +1403,7 @@ ATT = mean(τ̂_{it}) over all treated observations Treatment effects are **heterogeneous** per-observation values. ATT is their mean. -**Weight computation** (differs from twostep): +**Weight computation** (differs from local): - Time weights: δ_time(t) = exp(-λ_time × |t - center|) where center = T - treated_periods/2 - Unit weights: δ_unit(i) = exp(-λ_unit × RMSE(i, treated_avg)) where RMSE is computed over pre-treatment periods comparing to average treated trajectory @@ -1424,7 +1427,7 @@ Treatment effects are **heterogeneous** per-observation values. ATT is their mea 3. **Post-hoc**: Extract τ̂_{it} = Y_{it} - μ̂ - α̂_i - β̂_t - L̂_{it} for treated cells -**LOOCV parameter selection** (unified with twostep, Equation 5): +**LOOCV parameter selection** (unified with local, Equation 5): Following paper's Equation 5 and footnote 2: ``` Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² @@ -1441,14 +1444,14 @@ For global method, LOOCV works as follows: 3. Select λ combination that minimizes Q(λ) **Rust acceleration**: The LOOCV grid search is parallelized in Rust for 5-10x speedup. -- `loocv_grid_search_joint()` - Parallel LOOCV across all λ combinations -- `bootstrap_trop_variance_joint()` - Parallel bootstrap variance estimation +- `loocv_grid_search_global()` - Parallel LOOCV across all λ combinations +- `bootstrap_trop_variance_global()` - Parallel bootstrap variance estimation -**Key differences from twostep method**: +**Key differences from local method**: - Global weights (distance to treated block center) vs. per-observation weights - Single model fit per λ combination vs. N_treated fits - Treatment effects are post-hoc residuals from a single global model (global) - vs. post-hoc residuals from per-observation models (twostep) + vs. post-hoc residuals from per-observation models (local) - Both use (1-W) masking (control-only fitting) - Faster computation for large panels @@ -1457,14 +1460,14 @@ For global method, LOOCV works as follows: to receive treatment at the same time. A `ValueError` is raised if staggered adoption is detected (units first treated at different periods). Treatment timing is inferred once and held constant for bootstrap variance estimation. - For staggered adoption designs, use `method="twostep"`. + For staggered adoption designs, use `method="local"`. **Reference**: Adapted from reference implementation. See also Athey et al. (2025). **Edge Cases (treated NaN outcomes):** - **Partial NaN**: When some treated outcomes Y_{it} are NaN/missing: - `_extract_posthoc_tau()` (global) skips these cells; only finite τ̂ values are averaged - - Twostep loop skips NaN outcomes entirely (no model fit, no tau appended) + - Local loop skips NaN outcomes entirely (no model fit, no tau appended) - `n_treated_obs` in results reflects valid (finite) count, not total D==1 count - `df_trop = max(1, n_valid_treated - 1)` uses valid count - Warning issued when n_valid_treated < total treated count @@ -1475,13 +1478,15 @@ For global method, LOOCV works as follows: iterations succeed. `safe_inference()` propagates NaN downstream. **Requirements checklist:** -- [x] Same LOOCV framework as twostep (Equation 5) +- [x] Same LOOCV framework as local (Equation 5) - [x] Global weight computation using treated block center - [x] (1-W) masking for control-only fitting (per paper Eq. 2) - [x] Alternating minimization for nuclear norm penalty - [x] Returns ATT = mean of per-observation post-hoc τ̂_{it} - [x] Rust acceleration for LOOCV and bootstrap +- **Note:** `method="twostep"` renamed to `method="local"` and `method="joint"` renamed to `method="global"` to form a natural local/global pair. Both old names are deprecated aliases, removal planned for v3.0. + --- # Diagnostics & Sensitivity diff --git a/docs/troubleshooting.rst b/docs/troubleshooting.rst index 3fa925a..ade5c89 100644 --- a/docs/troubleshooting.rst +++ b/docs/troubleshooting.rst @@ -582,6 +582,26 @@ inaccurate with missing observations. Deprecation Warnings -------------------- +"method='twostep' is deprecated" +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +**Problem:** TROP emits a ``FutureWarning`` that ``method='twostep'`` is +deprecated. + +**Causes:** + +1. Code uses the old ``method='twostep'`` parameter name + +**Solutions:** + +.. code-block:: python + + # Old (deprecated) + trop = TROP(method='twostep') + + # New (use 'local' instead) + trop = TROP(method='local') + "method='joint' is deprecated" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/tutorials/10_trop.ipynb b/docs/tutorials/10_trop.ipynb index 8844660..5c8d9d5 100644 --- a/docs/tutorials/10_trop.ipynb +++ b/docs/tutorials/10_trop.ipynb @@ -598,14 +598,14 @@ }, { "cell_type": "code", - "source": "# Compare estimation methods\nprint(\"Estimation method comparison:\")\nprint(\"=\"*60)\n\nimport time\n\n# Two-step method (default)\nstart = time.time()\ntrop_twostep = TROP(\n method='twostep',\n lambda_time_grid=[0.0, 1.0],\n lambda_unit_grid=[0.0, 1.0], \n lambda_nn_grid=[0.0, 0.1],\n n_bootstrap=20,\n seed=42\n)\nresults_twostep = trop_twostep.fit(\n df,\n outcome='outcome',\n treatment='treated',\n unit='unit',\n time='period'\n)\ntwostep_time = time.time() - start\n\n# Global method\nstart = time.time()\ntrop_global = TROP(\n method='global',\n lambda_time_grid=[0.0, 1.0],\n lambda_unit_grid=[0.0, 1.0], \n lambda_nn_grid=[0.0, 0.1],\n n_bootstrap=20,\n seed=42\n)\nresults_global = trop_global.fit(\n df,\n outcome='outcome',\n treatment='treated',\n unit='unit',\n time='period'\n)\nglobal_time = time.time() - start\n\nprint(f\"\\n{'Method':<15} {'ATT':>10} {'SE':>10} {'Time (s)':>12}\")\nprint(\"-\"*60)\nprint(f\"{'Two-step':<15} {results_twostep.att:>10.4f} {results_twostep.se:>10.4f} {twostep_time:>12.2f}\")\nprint(f\"{'Global':<15} {results_global.att:>10.4f} {results_global.se:>10.4f} {global_time:>12.2f}\")\nprint(f\"\\nTrue ATT: {true_att}\")\nprint(f\"Two-step bias: {results_twostep.att - true_att:.4f}\")\nprint(f\"Global bias: {results_global.att - true_att:.4f}\")", + "source": "# Compare estimation methods\nprint(\"Estimation method comparison:\")\nprint(\"=\"*60)\n\nimport time\n\n# Local method (default)\nstart = time.time()\ntrop_local = TROP(\n method='local',\n lambda_time_grid=[0.0, 1.0],\n lambda_unit_grid=[0.0, 1.0], \n lambda_nn_grid=[0.0, 0.1],\n n_bootstrap=20,\n seed=42\n)\nresults_local = trop_local.fit(\n df,\n outcome='outcome',\n treatment='treated',\n unit='unit',\n time='period'\n)\nlocal_time = time.time() - start\n\n# Global method\nstart = time.time()\ntrop_global = TROP(\n method='global',\n lambda_time_grid=[0.0, 1.0],\n lambda_unit_grid=[0.0, 1.0], \n lambda_nn_grid=[0.0, 0.1],\n n_bootstrap=20,\n seed=42\n)\nresults_global = trop_global.fit(\n df,\n outcome='outcome',\n treatment='treated',\n unit='unit',\n time='period'\n)\nglobal_time = time.time() - start\n\nprint(f\"\\n{'Method':<15} {'ATT':>10} {'SE':>10} {'Time (s)':>12}\")\nprint(\"-\"*60)\nprint(f\"{'Local':<15} {results_local.att:>10.4f} {results_local.se:>10.4f} {local_time:>12.2f}\")\nprint(f\"{'Global':<15} {results_global.att:>10.4f} {results_global.se:>10.4f} {global_time:>12.2f}\")\nprint(f\"\\nTrue ATT: {true_att}\")\nprint(f\"Local bias: {results_local.att - true_att:.4f}\")\nprint(f\"Global bias: {results_global.att - true_att:.4f}\")", "metadata": {}, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", - "source": "## 10. Estimation Methods: Two-Step vs Global\n\nTROP supports two estimation methods via the `method` parameter:\n\n**Two-Step Method** (`method='twostep'`, default):\n- Follows Algorithm 2 from the paper\n- Computes observation-specific weights for each treated observation\n- Fits a model per treated observation, then averages the individual effects\n- More flexible, allows for heterogeneous treatment effects\n- Computationally intensive (N_treated optimizations)\n\n**Global Method** (`method='global'`):\n- Fits a single model on control data using (1-W) masked weights (per paper Eq. 2)\n- Extracts per-observation treatment effects as post-hoc residuals: τ_it = Y_it - μ - α_i - β_t - L_it\n- ATT = mean(τ_it) over treated observations\n- Faster (single optimization) with global weights\n\nNote: `method='joint'` is a deprecated alias for `method='global'`.", + "source": "## 10. Estimation Methods: Local vs Global\n\nTROP supports two estimation methods via the `method` parameter:\n\n**Local Method** (`method='local'`, default):\n- Follows Algorithm 2 from the paper\n- Computes observation-specific weights for each treated observation\n- Fits a model per treated observation, then averages the individual effects\n- More flexible, allows for heterogeneous treatment effects\n- Computationally intensive (N_treated optimizations)\n\n**Global Method** (`method='global'`):\n- Fits a single model on control data using (1-W) masked weights (per paper Eq. 2)\n- Extracts per-observation treatment effects as post-hoc residuals: τ_it = Y_it - μ - α_i - β_t - L_it\n- ATT = mean(τ_it) over treated observations\n- Faster (single optimization) with global weights\n\nNote: `method='twostep'` is a deprecated alias for `method='local'`, and `method='joint'` is a deprecated alias for `method='global'`. Both will be removed in v3.0.", "metadata": {} }, { @@ -638,7 +638,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": "## Summary\n\nKey takeaways for TROP:\n\n1. **Best use cases**: Factor confounding, unobserved time-varying confounders with interactive effects\n2. **Factor estimation**: Nuclear norm regularization with LOOCV for tuning\n3. **Three tuning parameters**: λ_time, λ_unit, λ_nn selected automatically via LOOCV\n4. **Unit weights**: Exponential distance-based weighting of control units, where distance is computed as RMS outcome difference on control periods excluding the target period\n5. **Time weights**: Exponential decay weighting of pre-treatment periods\n6. **Weights**: Importance weights controlling relative contribution of observations (higher = more relevant)\n7. **Estimation methods**:\n - `method='twostep'` (default): Per-observation estimation, allows heterogeneous effects\n - `method='global'`: Single model with (1-W) masking, post-hoc heterogeneous effects, faster\n\n**When to use TROP vs SDID**:\n- Use **SDID** when parallel trends is plausible and factors are not a concern\n- Use **TROP** when you suspect factor confounding (regional shocks, economic cycles, latent factors)\n- Running both provides a useful robustness check\n\n**When to use twostep vs global method**:\n- Use **twostep** (default) for maximum flexibility with per-observation weights\n- Use **global** for faster estimation with global weights\n\n**Reference**:\n- Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust Panel Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536" + "source": "## Summary\n\nKey takeaways for TROP:\n\n1. **Best use cases**: Factor confounding, unobserved time-varying confounders with interactive effects\n2. **Factor estimation**: Nuclear norm regularization with LOOCV for tuning\n3. **Three tuning parameters**: λ_time, λ_unit, λ_nn selected automatically via LOOCV\n4. **Unit weights**: Exponential distance-based weighting of control units, where distance is computed as RMS outcome difference on control periods excluding the target period\n5. **Time weights**: Exponential decay weighting of pre-treatment periods\n6. **Weights**: Importance weights controlling relative contribution of observations (higher = more relevant)\n7. **Estimation methods**:\n - `method='local'` (default): Per-observation estimation, allows heterogeneous effects\n - `method='global'`: Single model with (1-W) masking, post-hoc heterogeneous effects, faster\n\n**When to use TROP vs SDID**:\n- Use **SDID** when parallel trends is plausible and factors are not a concern\n- Use **TROP** when you suspect factor confounding (regional shocks, economic cycles, latent factors)\n- Running both provides a useful robustness check\n\n**When to use local vs global method**:\n- Use **local** (default) for maximum flexibility with per-observation weights\n- Use **global** for faster estimation with global weights\n\n**Reference**:\n- Athey, S., Imbens, G. W., Qu, Z., & Viviano, D. (2025). Triply Robust Panel Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536" }, { "cell_type": "code", diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 9507e1c..6dd2fdd 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -42,14 +42,14 @@ fn _rust_backend(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(linalg::solve_ols, m)?)?; m.add_function(wrap_pyfunction!(linalg::compute_robust_vcov, m)?)?; - // TROP estimator acceleration (twostep method) + // TROP estimator acceleration (local method) m.add_function(wrap_pyfunction!(trop::compute_unit_distance_matrix, m)?)?; m.add_function(wrap_pyfunction!(trop::loocv_grid_search, m)?)?; m.add_function(wrap_pyfunction!(trop::bootstrap_trop_variance, m)?)?; - // TROP estimator acceleration (joint method) - m.add_function(wrap_pyfunction!(trop::loocv_grid_search_joint, m)?)?; - m.add_function(wrap_pyfunction!(trop::bootstrap_trop_variance_joint, m)?)?; + // TROP estimator acceleration (global method) + m.add_function(wrap_pyfunction!(trop::loocv_grid_search_global, m)?)?; + m.add_function(wrap_pyfunction!(trop::bootstrap_trop_variance_global, m)?)?; // Diagnostics m.add_function(wrap_pyfunction!(rust_backend_info, m)?)?; diff --git a/rust/src/trop.rs b/rust/src/trop.rs index 997bf36..e26528e 100644 --- a/rust/src/trop.rs +++ b/rust/src/trop.rs @@ -1081,12 +1081,16 @@ pub fn bootstrap_trop_variance<'py>( } // ============================================================================ -// Joint method implementation +// Global method implementation +// +// Note: Only the #[pyfunction] exports were renamed (joint → global) to match +// the Python public API. The private Rust helpers below retain their original +// `*_joint*` names to keep the Rust-only rename scope minimal. // ============================================================================ -/// Compute global weights for joint method estimation. +/// Compute global weights for global method estimation. /// -/// Unlike twostep (which computes per-observation weights), joint uses global +/// Unlike local (which computes per-observation weights), global uses /// weights based on: /// - Time weights: distance to center of treated block /// - Unit weights: RMSE to average treated trajectory over pre-periods @@ -1196,7 +1200,7 @@ fn compute_joint_weights( delta } -/// Solve joint TWFE via weighted least squares (no low-rank, no tau). +/// Solve global TWFE via weighted least squares (no low-rank, no tau). /// /// Minimizes: min Σ δ_{it}(Y_{it} - μ - α_i - β_t)² /// @@ -1315,7 +1319,7 @@ fn solve_joint_no_lowrank( Some((mu, alpha, beta)) } -/// Solve joint TWFE + low-rank via alternating minimization (no tau). +/// Solve global TWFE + low-rank via alternating minimization (no tau). /// /// Minimizes: min Σ δ_{it}(Y_{it} - μ - α_i - β_t - L_{it})² + λ_nn||L||_* /// @@ -1422,12 +1426,12 @@ fn solve_joint_with_lowrank( Some((mu, alpha, beta, l)) } -/// Compute LOOCV score for joint method with specific parameter combination. +/// Compute LOOCV score for global method with specific parameter combination. /// /// Following paper's Equation 5: /// Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² /// -/// For joint method, we exclude each control observation, fit the joint model +/// For global method, we exclude each control observation, fit the global model /// on remaining data, and compute the pseudo-treatment effect at the excluded obs. /// /// # Returns @@ -1502,7 +1506,7 @@ fn loocv_score_joint( } } -/// Perform LOOCV grid search for joint method using parallel grid search. +/// Perform LOOCV grid search for global method using parallel grid search. /// /// Evaluates all combinations of (lambda_time, lambda_unit, lambda_nn) in parallel /// and returns the combination with lowest LOOCV score. @@ -1522,7 +1526,7 @@ fn loocv_score_joint( #[pyfunction] #[pyo3(signature = (y, d, control_mask, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, max_iter, tol))] #[allow(clippy::too_many_arguments)] -pub fn loocv_grid_search_joint<'py>( +pub fn loocv_grid_search_global<'py>( _py: Python<'py>, y: PyReadonlyArray2<'py, f64>, d: PyReadonlyArray2<'py, f64>, @@ -1630,7 +1634,7 @@ pub fn loocv_grid_search_joint<'py>( Ok((best_lt, best_lu, best_ln, best_score, n_valid, n_attempted, first_failed)) } -/// Compute bootstrap variance estimation for TROP joint method in parallel. +/// Compute bootstrap variance estimation for TROP global method in parallel. /// /// Performs unit-level block bootstrap, parallelizing across bootstrap iterations. /// Uses stratified sampling to preserve treated/control unit ratio. @@ -1651,7 +1655,7 @@ pub fn loocv_grid_search_joint<'py>( #[pyfunction] #[pyo3(signature = (y, d, lambda_time, lambda_unit, lambda_nn, n_bootstrap, max_iter, tol, seed))] #[allow(clippy::too_many_arguments)] -pub fn bootstrap_trop_variance_joint<'py>( +pub fn bootstrap_trop_variance_global<'py>( py: Python<'py>, y: PyReadonlyArray2<'py, f64>, d: PyReadonlyArray2<'py, f64>, @@ -1737,7 +1741,7 @@ pub fn bootstrap_trop_variance_joint<'py>( } } - // Compute weights and fit joint model + // Compute weights and fit global model let delta = compute_joint_weights( &y_boot.view(), &d_boot.view(), diff --git a/tests/test_rust_backend.py b/tests/test_rust_backend.py index 1921339..312347f 100644 --- a/tests/test_rust_backend.py +++ b/tests/test_rust_backend.py @@ -1151,12 +1151,12 @@ def test_trop_produces_valid_results(self): @pytest.mark.slow @pytest.mark.skipif(not HAS_RUST_BACKEND, reason="Rust backend not available") -class TestTROPJointRustBackend: - """Test suite for TROP joint method Rust backend functions.""" +class TestTROPGlobalRustBackend: + """Test suite for TROP global method Rust backend functions.""" - def test_loocv_grid_search_joint_returns_valid_result(self): - """Test loocv_grid_search_joint returns valid tuning parameters.""" - from diff_diff._rust_backend import loocv_grid_search_joint + def test_loocv_grid_search_global_returns_valid_result(self): + """Test loocv_grid_search_global returns valid tuning parameters.""" + from diff_diff._rust_backend import loocv_grid_search_global np.random.seed(42) n_periods, n_units = 10, 20 @@ -1173,7 +1173,7 @@ def test_loocv_grid_search_joint_returns_valid_result(self): lambda_unit_grid = np.array([0.0, 1.0]) lambda_nn_grid = np.array([0.0, 0.1]) - result = loocv_grid_search_joint( + result = loocv_grid_search_global( Y, D, control_mask, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, 100, 1e-6, @@ -1192,9 +1192,9 @@ def test_loocv_grid_search_joint_returns_valid_result(self): assert n_attempted > 0 assert best_score >= 0 or np.isinf(best_score) - def test_loocv_grid_search_joint_reproducible(self): - """Test loocv_grid_search_joint is deterministic (no subsampling).""" - from diff_diff._rust_backend import loocv_grid_search_joint + def test_loocv_grid_search_global_reproducible(self): + """Test loocv_grid_search_global is deterministic (no subsampling).""" + from diff_diff._rust_backend import loocv_grid_search_global np.random.seed(42) n_periods, n_units = 8, 15 @@ -1210,12 +1210,12 @@ def test_loocv_grid_search_joint_reproducible(self): lambda_unit_grid = np.array([0.0, 0.5]) lambda_nn_grid = np.array([0.0, 0.1]) - result1 = loocv_grid_search_joint( + result1 = loocv_grid_search_global( Y, D, control_mask, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, 50, 1e-6, ) - result2 = loocv_grid_search_joint( + result2 = loocv_grid_search_global( Y, D, control_mask, lambda_time_grid, lambda_unit_grid, lambda_nn_grid, 50, 1e-6, @@ -1224,9 +1224,9 @@ def test_loocv_grid_search_joint_reproducible(self): # Without subsampling, results should be deterministic assert result1[:4] == result2[:4] - def test_bootstrap_trop_variance_joint_shape(self): - """Test bootstrap_trop_variance_joint returns valid output.""" - from diff_diff._rust_backend import bootstrap_trop_variance_joint + def test_bootstrap_trop_variance_global_shape(self): + """Test bootstrap_trop_variance_global returns valid output.""" + from diff_diff._rust_backend import bootstrap_trop_variance_global np.random.seed(42) n_periods, n_units = 8, 15 @@ -1237,7 +1237,7 @@ def test_bootstrap_trop_variance_joint_shape(self): D = np.zeros((n_periods, n_units)) D[-n_post:, :n_treated] = 1.0 - estimates, se = bootstrap_trop_variance_joint( + estimates, se = bootstrap_trop_variance_global( Y, D, 0.5, 0.5, 0.1, 50, 50, 1e-6, 42 ) @@ -1246,9 +1246,9 @@ def test_bootstrap_trop_variance_joint_shape(self): assert isinstance(se, float) assert se >= 0 - def test_bootstrap_trop_variance_joint_reproducible(self): - """Test bootstrap_trop_variance_joint is reproducible.""" - from diff_diff._rust_backend import bootstrap_trop_variance_joint + def test_bootstrap_trop_variance_global_reproducible(self): + """Test bootstrap_trop_variance_global is reproducible.""" + from diff_diff._rust_backend import bootstrap_trop_variance_global np.random.seed(42) n_periods, n_units = 8, 15 @@ -1259,10 +1259,10 @@ def test_bootstrap_trop_variance_joint_reproducible(self): D = np.zeros((n_periods, n_units)) D[-n_post:, :n_treated] = 1.0 - est1, se1 = bootstrap_trop_variance_joint( + est1, se1 = bootstrap_trop_variance_global( Y, D, 0.5, 0.5, 0.1, 50, 50, 1e-6, 42 ) - est2, se2 = bootstrap_trop_variance_joint( + est2, se2 = bootstrap_trop_variance_global( Y, D, 0.5, 0.5, 0.1, 50, 50, 1e-6, 42 ) @@ -1272,11 +1272,11 @@ def test_bootstrap_trop_variance_joint_reproducible(self): @pytest.mark.slow @pytest.mark.skipif(not HAS_RUST_BACKEND, reason="Rust backend not available") -class TestTROPJointRustVsNumpy: - """Tests comparing TROP joint Rust and NumPy implementations.""" +class TestTROPGlobalRustVsNumpy: + """Tests comparing TROP global Rust and NumPy implementations.""" - def test_trop_joint_produces_valid_results(self): - """Test TROP joint with Rust backend produces valid results.""" + def test_trop_global_produces_valid_results(self): + """Test TROP global with Rust backend produces valid results.""" import pandas as pd from diff_diff import TROP @@ -1305,7 +1305,7 @@ def test_trop_joint_produces_valid_results(self): df = pd.DataFrame(data) trop = TROP( - method="joint", + method="global", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], @@ -1327,8 +1327,8 @@ def test_trop_joint_produces_valid_results(self): assert results.lambda_unit in [0.0, 1.0] assert results.lambda_nn in [0.0, 0.1] - def test_trop_joint_and_twostep_agree_in_direction(self): - """Test joint and twostep methods agree on treatment effect direction.""" + def test_trop_global_and_local_agree_in_direction(self): + """Test global and local methods agree on treatment effect direction.""" import pandas as pd from diff_diff import TROP @@ -1356,33 +1356,33 @@ def test_trop_joint_and_twostep_agree_in_direction(self): df = pd.DataFrame(data) - # Fit with joint method - trop_joint = TROP( - method="joint", + # Fit with global method + trop_global = TROP( + method="global", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], n_bootstrap=20, seed=42 ) - results_joint = trop_joint.fit(df, 'outcome', 'treated', 'unit', 'time') + results_global = trop_global.fit(df, 'outcome', 'treated', 'unit', 'time') - # Fit with twostep method - trop_twostep = TROP( - method="twostep", + # Fit with local method + trop_local = TROP( + method="local", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], n_bootstrap=20, seed=42 ) - results_twostep = trop_twostep.fit(df, 'outcome', 'treated', 'unit', 'time') + results_local = trop_local.fit(df, 'outcome', 'treated', 'unit', 'time') # Both should have same sign (both positive for true_effect=2.0) - assert np.sign(results_joint.att) == np.sign(results_twostep.att) + assert np.sign(results_global.att) == np.sign(results_local.att) - def test_trop_joint_handles_nan_outcomes(self): - """Test TROP joint method handles NaN outcome values gracefully.""" + def test_trop_global_handles_nan_outcomes(self): + """Test TROP global method handles NaN outcome values gracefully.""" import pandas as pd from diff_diff import TROP @@ -1423,7 +1423,7 @@ def test_trop_joint_handles_nan_outcomes(self): assert n_nan > 0, "Should have introduced some NaN values" trop = TROP( - method="joint", + method="global", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], @@ -1440,7 +1440,7 @@ def test_trop_joint_handles_nan_outcomes(self): # ATT should still be positive (true effect is positive) assert results.att > 0, f"ATT {results.att:.2f} should be positive" - def test_trop_joint_no_valid_pre_unit_gets_zero_weight(self): + def test_trop_global_no_valid_pre_unit_gets_zero_weight(self): """Test that units with no valid pre-period data get zero weight. When a control unit has all NaN values in the pre-treatment period, @@ -1488,9 +1488,9 @@ def test_trop_joint_no_valid_pre_unit_gets_zero_weight(self): unit_pre_data = df[(df['unit'] == control_unit_with_no_pre) & (df['time'] < (n_periods - n_post))] assert unit_pre_data['outcome'].isna().all(), "Control unit should have all NaN in pre-period" - # Fit with joint method - should handle gracefully + # Fit with global method - should handle gracefully trop = TROP( - method="joint", + method="global", lambda_time_grid=[0.5, 1.0], lambda_unit_grid=[0.5, 1.0], lambda_nn_grid=[0.0], @@ -1509,7 +1509,7 @@ def test_trop_joint_no_valid_pre_unit_gets_zero_weight(self): assert abs(results.att - true_effect) < 1.5, \ f"ATT {results.att:.2f} should be close to true effect {true_effect}" - def test_trop_joint_nan_exclusion_rust_python_parity(self): + def test_trop_global_nan_exclusion_rust_python_parity(self): """Test Rust and Python backends produce matching results with NaN data. This verifies that when data contains NaN values: @@ -1559,7 +1559,7 @@ def test_trop_joint_nan_exclusion_rust_python_parity(self): # Common TROP parameters trop_params = dict( - method="joint", + method="global", lambda_time_grid=[0.5, 1.0], lambda_unit_grid=[0.5, 1.0], lambda_nn_grid=[0.0], @@ -1578,8 +1578,8 @@ def test_trop_joint_nan_exclusion_rust_python_parity(self): trop_module = sys.modules['diff_diff.trop'] with patch.object(trop_module, 'HAS_RUST_BACKEND', False), \ - patch.object(trop_module, '_rust_loocv_grid_search_joint', None), \ - patch.object(trop_module, '_rust_bootstrap_trop_variance_joint', None): + patch.object(trop_module, '_rust_loocv_grid_search_global', None), \ + patch.object(trop_module, '_rust_bootstrap_trop_variance_global', None): trop_python = TROP(**trop_params) results_python = trop_python.fit(df.copy(), 'outcome', 'treated', 'unit', 'time') @@ -1599,7 +1599,7 @@ def test_trop_joint_nan_exclusion_rust_python_parity(self): assert results_rust.att > 0, f"Rust ATT {results_rust.att} should be positive" assert results_python.att > 0, f"Python ATT {results_python.att} should be positive" - def test_trop_joint_treated_pre_nan_rust_python_parity(self): + def test_trop_global_treated_pre_nan_rust_python_parity(self): """Test Rust/Python parity when treated units have pre-period NaN. When all treated units have NaN at a pre-period, average_treated[t] = NaN. @@ -1649,7 +1649,7 @@ def test_trop_joint_treated_pre_nan_rust_python_parity(self): # Common TROP parameters trop_params = dict( - method="joint", + method="global", lambda_time_grid=[1.0], lambda_unit_grid=[1.0], lambda_nn_grid=[0.0], @@ -1668,8 +1668,8 @@ def test_trop_joint_treated_pre_nan_rust_python_parity(self): trop_module = sys.modules['diff_diff.trop'] with patch.object(trop_module, 'HAS_RUST_BACKEND', False), \ - patch.object(trop_module, '_rust_loocv_grid_search_joint', None), \ - patch.object(trop_module, '_rust_bootstrap_trop_variance_joint', None): + patch.object(trop_module, '_rust_loocv_grid_search_global', None), \ + patch.object(trop_module, '_rust_bootstrap_trop_variance_global', None): trop_python = TROP(**trop_params) results_python = trop_python.fit(df.copy(), 'outcome', 'treated', 'unit', 'time') @@ -1684,7 +1684,7 @@ def test_trop_joint_treated_pre_nan_rust_python_parity(self): f"Rust ATT ({results_rust.att:.3f}) and Python ATT ({results_python.att:.3f}) " \ f"differ by {att_diff:.3f}, should be < 0.5" - def test_trop_joint_solver_parity_no_lowrank(self): + def test_trop_global_solver_parity_no_lowrank(self): """Test Rust/Python solver parity for no-lowrank path (lambda_nn >= 1e10). Both backends should produce matching (mu, alpha, beta) at atol=1e-6. @@ -1732,8 +1732,8 @@ def test_trop_joint_solver_parity_no_lowrank(self): # Python-only backend trop_module = sys.modules['diff_diff.trop'] with patch.object(trop_module, 'HAS_RUST_BACKEND', False), \ - patch.object(trop_module, '_rust_loocv_grid_search_joint', None), \ - patch.object(trop_module, '_rust_bootstrap_trop_variance_joint', None): + patch.object(trop_module, '_rust_loocv_grid_search_global', None), \ + patch.object(trop_module, '_rust_bootstrap_trop_variance_global', None): trop_python = TROP(**trop_params) results_python = trop_python.fit(df.copy(), 'outcome', 'treated', 'unit', 'time') @@ -1754,7 +1754,7 @@ def test_trop_joint_solver_parity_no_lowrank(self): assert abs(r_val - p_val) < 1e-6, \ f"Time effect mismatch for {key}: Rust={r_val:.8f}, Python={p_val:.8f}" - def test_trop_joint_solver_parity_with_lowrank(self): + def test_trop_global_solver_parity_with_lowrank(self): """Test Rust/Python solver parity for with-lowrank path (finite lambda_nn). Both backends should produce matching (mu, alpha, beta) at atol=1e-6. @@ -1803,8 +1803,8 @@ def test_trop_joint_solver_parity_with_lowrank(self): # Python-only backend trop_module = sys.modules['diff_diff.trop'] with patch.object(trop_module, 'HAS_RUST_BACKEND', False), \ - patch.object(trop_module, '_rust_loocv_grid_search_joint', None), \ - patch.object(trop_module, '_rust_bootstrap_trop_variance_joint', None): + patch.object(trop_module, '_rust_loocv_grid_search_global', None), \ + patch.object(trop_module, '_rust_bootstrap_trop_variance_global', None): trop_python = TROP(**trop_params) results_python = trop_python.fit(df.copy(), 'outcome', 'treated', 'unit', 'time') diff --git a/tests/test_trop.py b/tests/test_trop.py index 7deabd7..f8207f4 100644 --- a/tests/test_trop.py +++ b/tests/test_trop.py @@ -2601,7 +2601,7 @@ class TestTROPNuclearNormSolver: def test_proximal_step_size_correctness(self): """Verify L converges to prox_{λ/2}(R) for uniform weights.""" - trop_est = TROP(method="joint", n_bootstrap=2) + trop_est = TROP(method="global", n_bootstrap=2) # Small problem with known solution rng = np.random.default_rng(42) @@ -2631,7 +2631,7 @@ def test_lowrank_objective_decreases(self): delta = rng.uniform(0.5, 2.0, (6, 4)) lambda_nn = 0.3 - trop_est = TROP(method="joint", n_bootstrap=2) + trop_est = TROP(method="global", n_bootstrap=2) L = np.zeros_like(R) objectives = [] @@ -2655,14 +2655,14 @@ def test_lowrank_objective_decreases(self): f"Objective increased at step {k}: {objectives[k]} > {objectives[k-1]}" ) - def test_twostep_nonuniform_weights_objective(self): + def test_local_nonuniform_weights_objective(self): """Verify objective decreases with non-uniform weights (W_max < 1).""" rng = np.random.default_rng(123) R = rng.normal(0, 1, (6, 4)) W = rng.uniform(0.1, 0.8, (6, 4)) lambda_nn = 0.3 - trop_est = TROP(method="twostep", n_bootstrap=2) + trop_est = TROP(method="local", n_bootstrap=2) # Initial objective with L=0 L_init = np.zeros_like(R) @@ -2704,7 +2704,7 @@ def test_zero_weights_no_division_error(self): W = np.zeros((6, 4)) L_init = rng.normal(0, 1, (6, 4)) - trop_est = TROP(method="twostep", n_bootstrap=2) + trop_est = TROP(method="local", n_bootstrap=2) result = trop_est._weighted_nuclear_norm_solve( Y=Y, W=W, @@ -2719,18 +2719,20 @@ def test_zero_weights_no_division_error(self): @pytest.mark.slow -class TestTROPJointMethod: - """Tests for TROP method='joint'. - - The joint method estimates a single scalar treatment effect τ via - weighted least squares, as opposed to the twostep method which - computes per-observation effects. +class TestTROPGlobalMethod: + """Tests for TROP method='global'. + + The global method fits a single model on control data with global + weights, then extracts per-observation treatment effects as + residuals (τ_it = Y_it - μ - α_i - β_t - L_it). ATT is the mean + of these effects. The local method instead fits a separate model + per treated observation with observation-specific weights. """ - def test_joint_basic(self, simple_panel_data): - """Joint method runs and produces reasonable ATT.""" + def test_global_basic(self, simple_panel_data): + """Global method runs and produces reasonable ATT.""" trop_est = TROP( - method="joint", + method="global", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], @@ -2753,10 +2755,10 @@ def test_joint_basic(self, simple_panel_data): # ATT should be positive (true effect is 3.0) assert results.att > 0 - def test_joint_no_lowrank(self, simple_panel_data): - """Joint method with lambda_nn=inf (no low-rank).""" + def test_global_no_lowrank(self, simple_panel_data): + """Global method with lambda_nn=inf (no low-rank).""" trop_est = TROP( - method="joint", + method="global", lambda_time_grid=[0.0], lambda_unit_grid=[0.0], lambda_nn_grid=[float('inf')], # Disable low-rank @@ -2777,11 +2779,11 @@ def test_joint_no_lowrank(self, simple_panel_data): # Factor matrix should be all zeros assert np.allclose(results.factor_matrix, 0.0) - def test_joint_with_lowrank(self, factor_dgp_data, ci_params): - """Joint method with finite lambda_nn (with low-rank).""" + def test_global_with_lowrank(self, factor_dgp_data, ci_params): + """Global method with finite lambda_nn (with low-rank).""" n_boot = ci_params.bootstrap(20) trop_est = TROP( - method="joint", + method="global", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1, 1.0], @@ -2801,18 +2803,18 @@ def test_joint_with_lowrank(self, factor_dgp_data, ci_params): # Should produce non-zero factor matrix if low-rank is used # (depends on which lambda_nn is selected) - def test_joint_matches_direction(self, simple_panel_data): - """Joint method sign/magnitude roughly matches twostep.""" - # Fit with twostep - trop_twostep = TROP( - method="twostep", + def test_global_matches_direction(self, simple_panel_data): + """Global method sign/magnitude roughly matches local.""" + # Fit with local + trop_local = TROP( + method="local", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], n_bootstrap=10, seed=42, ) - results_twostep = trop_twostep.fit( + results_local = trop_local.fit( simple_panel_data, outcome="outcome", treatment="treated", @@ -2820,16 +2822,16 @@ def test_joint_matches_direction(self, simple_panel_data): time="period", ) - # Fit with joint - trop_joint = TROP( - method="joint", + # Fit with global + trop_global = TROP( + method="global", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], n_bootstrap=10, seed=42, ) - results_joint = trop_joint.fit( + results_global = trop_global.fit( simple_panel_data, outcome="outcome", treatment="treated", @@ -2838,11 +2840,11 @@ def test_joint_matches_direction(self, simple_panel_data): ) # Both should have positive ATT (true effect is 3.0) - assert results_twostep.att > 0 - assert results_joint.att > 0 + assert results_local.att > 0 + assert results_global.att > 0 # Signs should match - assert np.sign(results_twostep.att) == np.sign(results_joint.att) + assert np.sign(results_local.att) == np.sign(results_global.att) def test_method_parameter_validation(self): """Invalid method raises ValueError.""" @@ -2865,24 +2867,38 @@ def test_method_in_get_params_joint_deprecated(self): def test_method_in_set_params(self): """method parameter can be set via set_params().""" - trop_est = TROP(method="twostep") - assert trop_est.method == "twostep" + trop_est = TROP(method="local") + assert trop_est.method == "local" trop_est.set_params(method="global") assert trop_est.method == "global" def test_method_set_params_joint_deprecated(self): """'joint' alias maps to 'global' via set_params().""" - trop_est = TROP(method="twostep") + trop_est = TROP(method="local") with pytest.warns(FutureWarning, match="deprecated"): trop_est.set_params(method="joint") assert trop_est.method == "global" - def test_joint_bootstrap_variance(self, simple_panel_data, ci_params): - """Joint method bootstrap variance estimation works.""" + def test_method_in_get_params_twostep_deprecated(self): + """'twostep' alias maps to 'local' in get_params().""" + with pytest.warns(FutureWarning, match="deprecated"): + trop_est = TROP(method="twostep") + params = trop_est.get_params() + assert params["method"] == "local" + + def test_method_set_params_twostep_deprecated(self): + """'twostep' alias maps to 'local' via set_params().""" + trop_est = TROP(method="global") + with pytest.warns(FutureWarning, match="deprecated"): + trop_est.set_params(method="twostep") + assert trop_est.method == "local" + + def test_global_bootstrap_variance(self, simple_panel_data, ci_params): + """Global method bootstrap variance estimation works.""" n_boot = ci_params.bootstrap(20) trop_est = TROP( - method="joint", + method="global", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], @@ -2901,11 +2917,11 @@ def test_joint_bootstrap_variance(self, simple_panel_data, ci_params): assert results.n_bootstrap == n_boot assert results.bootstrap_distribution is not None - def test_joint_confidence_interval(self, simple_panel_data, ci_params): - """Joint method produces valid confidence intervals.""" + def test_global_confidence_interval(self, simple_panel_data, ci_params): + """Global method produces valid confidence intervals.""" n_boot = ci_params.bootstrap(30) trop_est = TROP( - method="joint", + method="global", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], @@ -2925,14 +2941,14 @@ def test_joint_confidence_interval(self, simple_panel_data, ci_params): assert lower < results.att < upper assert lower < upper - def test_joint_loocv_selects_from_grid(self, simple_panel_data): - """Joint method LOOCV selects tuning parameters from the grid.""" + def test_global_loocv_selects_from_grid(self, simple_panel_data): + """Global method LOOCV selects tuning parameters from the grid.""" grid_time = [0.0, 0.5, 1.0] grid_unit = [0.0, 0.5, 1.0] grid_nn = [0.0, 0.1] trop_est = TROP( - method="joint", + method="global", lambda_time_grid=grid_time, lambda_unit_grid=grid_unit, lambda_nn_grid=grid_nn, @@ -2954,10 +2970,10 @@ def test_joint_loocv_selects_from_grid(self, simple_panel_data): # LOOCV score should be computed assert np.isfinite(results.loocv_score) or np.isnan(results.loocv_score) - def test_joint_loocv_score_internal(self, simple_panel_data): - """Test the internal _loocv_score_joint method produces valid scores.""" + def test_global_loocv_score_internal(self, simple_panel_data): + """Test the internal _loocv_score_global method produces valid scores.""" trop_est = TROP( - method="joint", + method="global", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], @@ -2992,21 +3008,21 @@ def test_joint_loocv_score_internal(self, simple_panel_data): treated_periods = 3 # From fixture: n_post = 3 # Score should be finite - score = trop_est._loocv_score_joint( + score = trop_est._loocv_score_global( Y, D, control_obs, 0.0, 0.0, 0.0, treated_periods, n_units, n_periods ) assert np.isfinite(score) or np.isinf(score), "Score should be finite or inf" # Score with larger lambda_nn should still work - score2 = trop_est._loocv_score_joint( + score2 = trop_est._loocv_score_global( Y, D, control_obs, 1.0, 1.0, 0.1, treated_periods, n_units, n_periods ) assert np.isfinite(score2) or np.isinf(score2), "Score should be finite or inf" - def test_joint_handles_nan_outcomes(self, simple_panel_data): - """Joint method handles NaN outcome values gracefully.""" + def test_global_handles_nan_outcomes(self, simple_panel_data): + """Global method handles NaN outcome values gracefully.""" # Introduce NaN in some control observations data = simple_panel_data.copy() control_mask = data['treated'] == 0 @@ -3018,7 +3034,7 @@ def test_joint_handles_nan_outcomes(self, simple_panel_data): data.loc[nan_indices, 'outcome'] = np.nan trop_est = TROP( - method="joint", + method="global", lambda_time_grid=[0.0, 1.0], lambda_unit_grid=[0.0, 1.0], lambda_nn_grid=[0.0, 0.1], @@ -3039,8 +3055,8 @@ def test_joint_handles_nan_outcomes(self, simple_panel_data): # ATT should be positive (true effect is 3.0) assert results.att > 0, "ATT should be positive" - def test_joint_with_lowrank_handles_nan(self, simple_panel_data): - """Joint method with low-rank handles NaN values correctly.""" + def test_global_with_lowrank_handles_nan(self, simple_panel_data): + """Global method with low-rank handles NaN values correctly.""" # Introduce NaN in some control observations data = simple_panel_data.copy() control_mask = data['treated'] == 0 @@ -3052,7 +3068,7 @@ def test_joint_with_lowrank_handles_nan(self, simple_panel_data): data.loc[nan_indices, 'outcome'] = np.nan trop_est = TROP( - method="joint", + method="global", lambda_time_grid=[0.0], lambda_unit_grid=[0.0], lambda_nn_grid=[0.1], # Finite lambda_nn enables low-rank @@ -3071,7 +3087,7 @@ def test_joint_with_lowrank_handles_nan(self, simple_panel_data): assert np.isfinite(results.att), "ATT should be finite with NaN data" assert np.isfinite(results.se), "SE should be finite with NaN data" - def test_joint_nan_exclusion_behavior(self, simple_panel_data): + def test_global_nan_exclusion_behavior(self, simple_panel_data): """Verify NaN observations are truly excluded from estimation. This tests the PR #113 fix: NaN observations should not contribute @@ -3098,7 +3114,7 @@ def test_joint_nan_exclusion_behavior(self, simple_panel_data): # Fit on both versions with identical settings trop_nan = TROP( - method="joint", + method="global", lambda_time_grid=[1.0], lambda_unit_grid=[1.0], lambda_nn_grid=[0.0], # Disable low-rank for cleaner comparison @@ -3106,7 +3122,7 @@ def test_joint_nan_exclusion_behavior(self, simple_panel_data): seed=42, ) trop_dropped = TROP( - method="joint", + method="global", lambda_time_grid=[1.0], lambda_unit_grid=[1.0], lambda_nn_grid=[0.0], @@ -3136,7 +3152,7 @@ def test_joint_nan_exclusion_behavior(self, simple_panel_data): f"({results_dropped.att:.4f}) - true NaN exclusion" ) - def test_joint_unit_no_valid_pre_gets_zero_weight(self, simple_panel_data): + def test_global_unit_no_valid_pre_gets_zero_weight(self, simple_panel_data): """Verify units with no valid pre-period data get zero weight. This tests the PR #113 fix: units with no valid pre-period observations @@ -3159,7 +3175,7 @@ def test_joint_unit_no_valid_pre_gets_zero_weight(self, simple_panel_data): data.loc[mask, 'outcome'] = np.nan trop_est = TROP( - method="joint", + method="global", lambda_time_grid=[1.0], lambda_unit_grid=[1.0], # Non-zero lambda_unit to use distance weighting lambda_nn_grid=[0.0], @@ -3179,8 +3195,8 @@ def test_joint_unit_no_valid_pre_gets_zero_weight(self, simple_panel_data): assert np.isfinite(results.att), "ATT should be finite even with unit having no pre-period data" assert np.isfinite(results.se), "SE should be finite" - def test_joint_treated_pre_nan_handling(self, simple_panel_data): - """Verify joint method handles NaN in treated units during pre-periods. + def test_global_treated_pre_nan_handling(self, simple_panel_data): + """Verify global method handles NaN in treated units during pre-periods. When all treated units have NaN at a pre-period, average_treated[t] = NaN. This period should be excluded from unit distance calculation (both numerator @@ -3211,7 +3227,7 @@ def test_joint_treated_pre_nan_handling(self, simple_panel_data): assert n_nan == len(treated_units), f"Should have {len(treated_units)} NaN, got {n_nan}" trop_est = TROP( - method="joint", + method="global", lambda_time_grid=[1.0], lambda_unit_grid=[1.0], lambda_nn_grid=[0.0], @@ -3230,7 +3246,7 @@ def test_joint_treated_pre_nan_handling(self, simple_panel_data): assert np.isfinite(results.att), f"ATT should be finite, got {results.att}" assert np.isfinite(results.se), f"SE should be finite, got {results.se}" - def test_joint_rejects_staggered_adoption(self): + def test_global_rejects_staggered_adoption(self): """Global method raises ValueError for staggered adoption data. The global method assumes all treated units receive treatment at the @@ -3259,7 +3275,7 @@ def test_joint_rejects_staggered_adoption(self): trop.fit(df, 'outcome', 'treated', 'unit', 'time') def test_global_method_alias(self, simple_panel_data): - """method='global' works and produces same results as deprecated 'joint'.""" + """method='global' runs and produces a valid positive ATT.""" trop_est = TROP( method="global", lambda_time_grid=[0.0, 1.0], @@ -3310,7 +3326,7 @@ def test_global_uses_control_only_weights(self, simple_panel_data): treated_periods = np.sum(np.any(D == 1, axis=1)) - delta = trop_est._compute_joint_weights( + delta = trop_est._compute_global_weights( Y, D, 1.0, 1.0, int(treated_periods), n_units, n_periods ) @@ -3404,10 +3420,10 @@ def test_global_treated_outcome_does_not_affect_fit(self, simple_panel_data): ) # Compute weights and fit with original data - delta = trop_est._compute_joint_weights( + delta = trop_est._compute_global_weights( Y, D, 1.0, 1.0, treated_periods, n_units, n_periods ) - mu1, alpha1, beta1, L1 = trop_est._solve_joint_with_lowrank( + mu1, alpha1, beta1, L1 = trop_est._solve_global_with_lowrank( Y, delta, 0.1, 100, 1e-6 ) @@ -3416,10 +3432,10 @@ def test_global_treated_outcome_does_not_affect_fit(self, simple_panel_data): Y_perturbed[D == 1] += 1000.0 # Recompute (same weights since (1-W) zeroes treated cells) - delta2 = trop_est._compute_joint_weights( + delta2 = trop_est._compute_global_weights( Y_perturbed, D, 1.0, 1.0, treated_periods, n_units, n_periods ) - mu2, alpha2, beta2, L2 = trop_est._solve_joint_with_lowrank( + mu2, alpha2, beta2, L2 = trop_est._solve_global_with_lowrank( Y_perturbed, delta2, 0.1, 100, 1e-6 ) @@ -3479,8 +3495,8 @@ def test_global_n_treated_obs_partial_nan(self): f"Expected {total_treated - n_nan}, got {results.n_treated_obs}" assert np.isfinite(results.att) - def test_twostep_n_treated_obs_partial_nan(self): - """Twostep method: n_treated_obs reflects only finite outcomes.""" + def test_local_n_treated_obs_partial_nan(self): + """Local method: n_treated_obs reflects only finite outcomes.""" df = self._make_panel() treated_mask = (df['treated'] == 1) @@ -3492,7 +3508,7 @@ def test_twostep_n_treated_obs_partial_nan(self): total_treated = int(treated_mask.sum()) trop_est = TROP( - method="twostep", + method="local", lambda_time_grid=[1.0], lambda_unit_grid=[1.0], lambda_nn_grid=[np.inf], @@ -3507,8 +3523,8 @@ def test_twostep_n_treated_obs_partial_nan(self): f"Expected {total_treated - n_nan}, got {results.n_treated_obs}" assert np.isfinite(results.att) - def test_twostep_nan_treated_not_poison_att(self): - """Twostep: NaN treated outcomes don't poison ATT via np.mean.""" + def test_local_nan_treated_not_poison_att(self): + """Local: NaN treated outcomes don't poison ATT via np.mean.""" df = self._make_panel(effect=3.0) # Make ONE treated outcome NaN @@ -3517,7 +3533,7 @@ def test_twostep_nan_treated_not_poison_att(self): df.loc[first_treated_idx, 'outcome'] = np.nan trop_est = TROP( - method="twostep", + method="local", lambda_time_grid=[1.0], lambda_unit_grid=[1.0], lambda_nn_grid=[np.inf], @@ -3558,14 +3574,14 @@ def test_global_all_treated_nan_warns(self): assert results.n_treated_obs == 0 assert np.isnan(results.att) - def test_twostep_all_treated_nan_warns(self): - """Twostep method warns when all treated outcomes are NaN.""" + def test_local_all_treated_nan_warns(self): + """Local method warns when all treated outcomes are NaN.""" df = self._make_panel() df.loc[df['treated'] == 1, 'outcome'] = np.nan trop_est = TROP( - method="twostep", + method="local", lambda_time_grid=[1.0], lambda_unit_grid=[1.0], lambda_nn_grid=[np.inf], @@ -3602,15 +3618,15 @@ def test_global_bootstrap_zero_draws_returns_nan_se(self): ) # Disable Rust backend so Python fallback path is tested, - # then patch _fit_joint_with_fixed_lambda to always raise + # then patch _fit_global_with_fixed_lambda to always raise trop_module = sys.modules['diff_diff.trop'] with patch.object(trop_module, 'HAS_RUST_BACKEND', False), \ - patch.object(trop_module, '_rust_bootstrap_trop_variance_joint', None), \ - patch.object(TROP, '_fit_joint_with_fixed_lambda', + patch.object(trop_module, '_rust_bootstrap_trop_variance_global', None), \ + patch.object(TROP, '_fit_global_with_fixed_lambda', side_effect=ValueError("forced failure")): with warnings.catch_warnings(): warnings.simplefilter("ignore") - se, dist = trop_est._bootstrap_variance_joint( + se, dist = trop_est._bootstrap_variance_global( df, 'outcome', 'treated', 'unit', 'time', (1.0, 1.0, 1e10), 3, ) @@ -3618,14 +3634,14 @@ def test_global_bootstrap_zero_draws_returns_nan_se(self): assert np.isnan(se), f"SE should be NaN when 0 draws succeed, got {se}" assert len(dist) == 0 - def test_twostep_bootstrap_zero_draws_returns_nan_se(self): - """Twostep bootstrap with 0 successful draws returns NaN SE, not 0.0.""" + def test_local_bootstrap_zero_draws_returns_nan_se(self): + """Local bootstrap with 0 successful draws returns NaN SE, not 0.0.""" from unittest.mock import patch df = TestTROPNValidTreated._make_panel() trop_est = TROP( - method="twostep", + method="local", lambda_time_grid=[1.0], lambda_unit_grid=[1.0], lambda_nn_grid=[np.inf],