@@ -105,7 +105,9 @@ def fit(
105105 treatment : str = None ,
106106 time : str = None ,
107107 formula : str = None ,
108- covariates : list = None
108+ covariates : list = None ,
109+ fixed_effects : list = None ,
110+ absorb : list = None
109111 ) -> DiDResults :
110112 """
111113 Fit the Difference-in-Differences model.
@@ -124,7 +126,15 @@ def fit(
124126 R-style formula (e.g., "outcome ~ treated * post").
125127 If provided, overrides outcome, treatment, and time parameters.
126128 covariates : list, optional
127- List of covariate column names to include in the regression.
129+ List of covariate column names to include as linear controls.
130+ fixed_effects : list, optional
131+ List of categorical column names to include as fixed effects.
132+ Creates dummy variables for each category (drops first level).
133+ Use for low-dimensional fixed effects (e.g., industry, region).
134+ absorb : list, optional
135+ List of categorical column names for high-dimensional fixed effects.
136+ Uses within-transformation (demeaning) instead of dummy variables.
137+ More efficient for large numbers of categories (e.g., firm, individual).
128138
129139 Returns
130140 -------
@@ -135,6 +145,18 @@ def fit(
135145 ------
136146 ValueError
137147 If required parameters are missing or data validation fails.
148+
149+ Examples
150+ --------
151+ Using fixed effects (dummy variables):
152+
153+ >>> did.fit(data, outcome='sales', treatment='treated', time='post',
154+ ... fixed_effects=['state', 'industry'])
155+
156+ Using absorbed fixed effects (within-transformation):
157+
158+ >>> did.fit(data, outcome='sales', treatment='treated', time='post',
159+ ... absorb=['firm_id'])
138160 """
139161 # Parse formula if provided
140162 if formula is not None :
@@ -147,10 +169,35 @@ def fit(
147169 # Validate inputs
148170 self ._validate_data (data , outcome , treatment , time , covariates )
149171
172+ # Validate fixed effects and absorb columns
173+ if fixed_effects :
174+ for fe in fixed_effects :
175+ if fe not in data .columns :
176+ raise ValueError (f"Fixed effect column '{ fe } ' not found in data" )
177+ if absorb :
178+ for ab in absorb :
179+ if ab not in data .columns :
180+ raise ValueError (f"Absorb column '{ ab } ' not found in data" )
181+
182+ # Handle absorbed fixed effects (within-transformation)
183+ working_data = data .copy ()
184+ absorbed_vars = []
185+ n_absorbed_effects = 0
186+
187+ if absorb :
188+ # Apply within-transformation for each absorbed variable
189+ vars_to_demean = [outcome ] + (covariates or [])
190+ for ab_var in absorb :
191+ n_absorbed_effects += working_data [ab_var ].nunique () - 1
192+ for var in vars_to_demean :
193+ group_means = working_data .groupby (ab_var )[var ].transform ("mean" )
194+ working_data [var ] = working_data [var ] - group_means
195+ absorbed_vars .append (ab_var )
196+
150197 # Extract variables
151- y = data [outcome ].values .astype (float )
152- d = data [treatment ].values .astype (float )
153- t = data [time ].values .astype (float )
198+ y = working_data [outcome ].values .astype (float )
199+ d = working_data [treatment ].values .astype (float )
200+ t = working_data [time ].values .astype (float )
154201
155202 # Validate binary variables
156203 validate_binary (d , "treatment" )
@@ -166,9 +213,18 @@ def fit(
166213 # Add covariates if provided
167214 if covariates :
168215 for cov in covariates :
169- X = np .column_stack ([X , data [cov ].values .astype (float )])
216+ X = np .column_stack ([X , working_data [cov ].values .astype (float )])
170217 var_names .append (cov )
171218
219+ # Add fixed effects as dummy variables
220+ if fixed_effects :
221+ for fe in fixed_effects :
222+ # Create dummies, drop first category to avoid multicollinearity
223+ dummies = pd .get_dummies (data [fe ], prefix = fe , drop_first = True )
224+ for col in dummies .columns :
225+ X = np .column_stack ([X , dummies [col ].values .astype (float )])
226+ var_names .append (col )
227+
172228 # Fit OLS
173229 coefficients , residuals , fitted , r_squared = self ._fit_ols (X , y )
174230
@@ -190,8 +246,8 @@ def fit(
190246 att = coefficients [att_idx ]
191247 se = np .sqrt (vcov [att_idx , att_idx ])
192248
193- # Compute test statistics
194- df = len (y ) - X .shape [1 ]
249+ # Compute test statistics (adjust df for absorbed fixed effects)
250+ df = len (y ) - X .shape [1 ] - n_absorbed_effects
195251 t_stat = att / se
196252 p_value = compute_p_value (t_stat , df = df )
197253 conf_int = compute_confidence_interval (att , se , self .alpha , df = df )
0 commit comments