diff --git a/.gitignore b/.gitignore index 725375b4..93c3ecb4 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ blech.dir *.Rhistory .aider* .env +utils/glm/_temp/ diff --git a/params/_templates/glm_params.json b/params/_templates/glm_params.json new file mode 100644 index 00000000..c6519b88 --- /dev/null +++ b/params/_templates/glm_params.json @@ -0,0 +1,9 @@ +{ + "bin_size": 25, + "history_window": 250, + "n_basis_funcs": 8, + "time_lims": [1500, 4500], + "include_coupling": false, + "separate_tastes": true, + "separate_regions": true +} diff --git a/utils/glm/README.md b/utils/glm/README.md new file mode 100644 index 00000000..15fec944 --- /dev/null +++ b/utils/glm/README.md @@ -0,0 +1,116 @@ +# GLM-based Firing Rate Estimation + +This module fits Generalized Linear Models (GLMs) to electrophysiological spike data using the [nemos](https://nemos.readthedocs.io/) library. + +## Why Separate Scripts? + +The nemos library has dependency conflicts with blech_clust's conda environment (particularly around numpy/jax versions). To work around this, the GLM fitting runs in a **separate virtual environment** dedicated to nemos. 
+ +This module uses a two-environment architecture: + +``` +┌─────────────────────────────────────────────────────────────┐ +│ infer_glm_rates.py │ +│ (Main Orchestrator) │ +│ Can run from any Python environment │ +└─────────────────────┬───────────────────────────────────────┘ + │ + ┌───────────┴───────────┐ + ▼ ▼ +┌─────────────────────┐ ┌─────────────────────┐ +│ _glm_extract_data.py│ │ _glm_fit_models.py │ +│ │ │ │ +│ Runs in blech_clust │ │ Runs in nemos │ +│ conda environment │ │ virtual environment│ +│ │ │ │ +│ - Loads spike data │ │ - Fits GLM models │ +│ - Uses ephys_data │ │ - Uses nemos │ +│ - Saves to temp/ │ │ - Reads from temp/ │ +└─────────────────────┘ └─────────────────────┘ + │ │ + └───────────┬───────────┘ + ▼ + ┌───────────────┐ + │ _temp/ │ + │ (intermediate │ + │ files) │ + └───────────────┘ +``` + +## Files + +| File | Purpose | Environment | +|------|---------|-------------| +| `infer_glm_rates.py` | Main entry point. Detects environments, orchestrates data flow | Any | +| `_glm_extract_data.py` | Extracts spike data using ephys_data | blech_clust conda | +| `_glm_fit_models.py` | Fits GLM models, generates plots, writes HDF5 | nemos venv | +| `_temp/` | Temporary directory for intermediate numpy files (gitignored) | N/A | + +## Setup + +1. **blech_clust environment**: Should already exist if you're using blech_clust. + +2. **nemos environment**: Create a separate virtual environment: + ```bash + python -m venv ~/nemos_env + source ~/nemos_env/bin/activate + pip install nemos matplotlib pandas tables + ``` + +## Usage + +```bash +python utils/glm/infer_glm_rates.py <data_dir> [options] +``` + +The script will: +1. Auto-detect the blech_clust conda environment +2. Auto-detect or prompt for the nemos virtual environment +3. Extract data using blech_clust's ephys_data +4. Fit GLM models using nemos +5. 
Save results to HDF5 and generate plots + +### Options + +| Option | Default | Description | +|--------|---------|-------------| +| `--bin_size` | 25 | Bin size in ms | +| `--history_window` | 250 | History window in ms for autoregressive effects | +| `--n_basis_funcs` | 8 | Number of basis functions for history filter | +| `--time_lims` | [1500, 4500] | Time limits for analysis [start, end] in ms | +| `--include_coupling` | False | Include coupling between neurons | +| `--separate_tastes` | False | Fit separate models for each taste | +| `--separate_regions` | False | Fit separate models for each region | +| `--retrain` | False | Force retraining even if model exists | +| `--blech_clust_env` | blech_clust | Name of blech_clust conda environment | +| `--nemos_env` | (auto) | Path to nemos venv Python interpreter | + +### Example + +```bash +# Basic usage +python utils/glm/infer_glm_rates.py /path/to/data + +# With options +python utils/glm/infer_glm_rates.py /path/to/data \ + --separate_tastes \ + --separate_regions \ + --include_coupling \ + --nemos_env ~/nemos_env/bin/python +``` + +## Output + +- **HDF5**: Results saved under `/glm_output/regions/` in the data's HDF5 file +- **CSV**: `glm_output/bits_per_spike_summary.csv` with model performance metrics +- **Plots**: `glm_output/plots/` with distribution and individual neuron plots +- **Models**: `glm_output/artifacts/` with pickled GLM models + +## Model Details + +The GLM includes: +- **Spike history**: Raised cosine log basis functions capture autoregressive effects +- **Stimulus features** (optional): Convolved indicator for stimulus onset +- **Neural coupling** (optional): Cross-neuron dependencies within region + +Performance is measured using **bits per spike**, which quantifies information gain over a baseline homogeneous Poisson model. 
diff --git a/utils/glm/__init__.py b/utils/glm/__init__.py new file mode 100644 index 00000000..884e72ee --- /dev/null +++ b/utils/glm/__init__.py @@ -0,0 +1,9 @@ +""" +GLM-based firing rate estimation module. + +This module provides GLM fitting for electrophysiological spike data using nemos. +Due to dependency conflicts, nemos runs in a separate virtual environment. + +Usage: + python -m blech_clust.utils.glm.infer_glm_rates [options] +""" diff --git a/utils/glm/_glm_extract_data.py b/utils/glm/_glm_extract_data.py new file mode 100644 index 00000000..83ff3b1f --- /dev/null +++ b/utils/glm/_glm_extract_data.py @@ -0,0 +1,118 @@ +""" +Data extraction script for GLM fitting. + +This script runs in the blech_clust conda environment and extracts spike data +using ephys_data. The extracted data is saved as numpy files for the GLM +fitting script to load. + +This script is called by infer_glm_rates.py and should not be run directly. +""" + +import sys +import os +import json +import numpy as np + +def main(): + if len(sys.argv) != 2: + print("Usage: _glm_extract_data.py ") + sys.exit(1) + + params_path = sys.argv[1] + + # Load parameters + with open(params_path, 'r') as f: + params = json.load(f) + + data_dir = params['data_dir'] + temp_dir = params['temp_dir'] + time_lims = params['time_lims'] + bin_size = params['bin_size'] + + # Add blech_clust to path + script_path = os.path.abspath(__file__) + blech_clust_path = os.path.dirname(os.path.dirname(script_path)) + sys.path.insert(0, blech_clust_path) + + # Import blech_clust modules + from blech_clust.utils.blech_utils import imp_metadata, pipeline_graph_check + from blech_clust.utils.ephys_data import ephys_data + + print(f"Extracting data from: {data_dir}") + + # Pipeline check + this_pipeline_check = pipeline_graph_check(data_dir) + this_pipeline_check.check_previous(script_path) + this_pipeline_check.write_to_log(script_path, 'attempted') + + # Load data + data = ephys_data.ephys_data(data_dir) + data.get_spikes() 
+ data.get_region_units() + + # Build region mapping + region_dict = dict(zip(data.region_names, data.region_units)) + region_vec = np.zeros(len(np.concatenate(data.region_units)), dtype=object) + for region_name, unit_list in region_dict.items(): + region_vec[unit_list] = region_name + + n_tastes = len(data.spikes) + n_neurons = data.spikes[0].shape[1] + + print(f" Tastes: {n_tastes}") + print(f" Neurons: {n_neurons}") + print(f" Regions: {data.region_names}") + + # Process and save spike data for each taste + spike_data = {} + for taste_idx, taste_spikes in enumerate(data.spikes): + # Cut to time limits + taste_spikes = taste_spikes[..., time_lims[0]:time_lims[1]] + + # Bin spikes + n_trials, n_neurons_taste, n_time = taste_spikes.shape + n_bins = n_time // bin_size + trimmed = taste_spikes[..., :n_bins * bin_size] + binned = trimmed.reshape(n_trials, n_neurons_taste, n_bins, bin_size).sum(axis=-1) + + spike_data[f'taste_{taste_idx}'] = binned + print(f" Taste {taste_idx}: {binned.shape}") + + # Save extracted data + extracted_data = { + 'spike_data': spike_data, + 'region_names': list(data.region_names), + 'region_units': {name: list(units) for name, units in region_dict.items()}, + 'region_vec': region_vec.tolist(), + 'n_tastes': n_tastes, + 'n_neurons': n_neurons, + 'hdf5_path': data.hdf5_path, + } + + # Save as numpy file + output_file = os.path.join(temp_dir, 'extracted_data.npz') + np.savez(output_file, **{ + 'spike_data_keys': list(spike_data.keys()), + 'region_names': np.array(data.region_names, dtype=object), + 'region_vec': np.array(region_vec, dtype=object), + 'n_tastes': n_tastes, + 'n_neurons': n_neurons, + 'hdf5_path': data.hdf5_path, + }) + + # Save spike arrays separately (npz doesn't handle nested dicts well) + for key, arr in spike_data.items(): + np.save(os.path.join(temp_dir, f'{key}.npy'), arr) + + # Save region units as JSON + with open(os.path.join(temp_dir, 'region_units.json'), 'w') as f: + json.dump({name: list(map(int, units)) for name, 
units in region_dict.items()}, f) + + print(f"\nData saved to: {temp_dir}") + + # Write to pipeline log + this_pipeline_check.write_to_log(script_path, 'completed') + + +if __name__ == '__main__': + main() diff --git a/utils/glm/_glm_fit_models.py b/utils/glm/_glm_fit_models.py new file mode 100644 index 00000000..23dccccf --- /dev/null +++ b/utils/glm/_glm_fit_models.py @@ -0,0 +1,475 @@ +""" +GLM fitting script using nemos. + +This script runs in the nemos virtual environment and fits GLM models to +spike data extracted by _glm_extract_data.py. + +This script is called by infer_glm_rates.py and should not be run directly. +""" + +import sys +import os +import json +import numpy as np +import pickle +import pandas as pd +import matplotlib +matplotlib.use('Agg') # Non-interactive backend +import matplotlib.pyplot as plt + +# Import nemos +try: + import nemos as nmo +except ImportError: + print("ERROR: nemos not found. Install with: pip install nemos") + sys.exit(1) + + +############################################################ +# GLM fitting functions +############################################################ + +def compute_bits_per_spike(model, X, y, baseline_rate=None): + """ + Compute bits per spike metric for model comparison. + + Bits per spike measures how much better the model predicts spikes + compared to a baseline (homogeneous Poisson) model. 
+ """ + # Get model predictions + pred_rate = model.predict(X) + + # Compute log-likelihood under model + eps = 1e-10 + ll_model = np.sum(y * np.log(pred_rate + eps) - pred_rate) + + # Compute log-likelihood under baseline (homogeneous Poisson) + if baseline_rate is None: + baseline_rate = np.mean(y, axis=0, keepdims=True) + ll_baseline = np.sum(y * np.log(baseline_rate + eps) - baseline_rate) + + # Bits per spike = (LL_model - LL_baseline) / (n_spikes * log(2)) + n_spikes = np.sum(y) + if n_spikes > 0: + bits_per_spike = (ll_model - ll_baseline) / (n_spikes * np.log(2)) + else: + bits_per_spike = 0.0 + + return bits_per_spike + + +def fit_glm_single_neuron( + binned_spikes, + neuron_idx, + history_window_bins, + n_basis_funcs, + stim_bin=None, + other_neurons=None, + include_coupling=False, +): + """ + Fit GLM for a single neuron. + + Args: + binned_spikes: Binned spike counts (trials, neurons, time) + neuron_idx: Index of target neuron + history_window_bins: History window in bins + n_basis_funcs: Number of basis functions + stim_bin: Stimulus onset bin (optional) + other_neurons: Indices of other neurons for coupling (optional) + include_coupling: Whether to include coupling + + Returns: + model: Fitted GLM + X: Feature matrix + y: Target spike counts + feature_info: Dictionary with feature information + valid_mask: Boolean mask for valid samples + """ + n_trials, n_neurons, n_bins = binned_spikes.shape + + # Reshape to (samples, neurons) by concatenating trials + spikes_flat = binned_spikes.transpose(0, 2, 1).reshape(-1, n_neurons) + + # Target neuron spikes + y = spikes_flat[:, neuron_idx] + + # Build feature matrix + feature_list = [] + feature_info = {'names': [], 'slices': []} + current_idx = 0 + + # 1. 
Spike history basis for target neuron + history_basis = nmo.basis.RaisedCosineLogConv( + n_basis_funcs, + window_size=history_window_bins + ) + history_features = history_basis.compute_features( + spikes_flat[:, neuron_idx:neuron_idx+1] + ) + feature_list.append(history_features) + feature_info['names'].append('history') + feature_info['slices'].append(slice(current_idx, current_idx + n_basis_funcs)) + current_idx += n_basis_funcs + + # 2. Stimulus features (if provided) + if stim_bin is not None: + n_stim_basis = 5 + stim_window = min(20, history_window_bins) + + # Create stimulus indicator for each trial + stim_indicator = np.zeros((n_trials * n_bins,)) + for trial in range(n_trials): + trial_stim_idx = trial * n_bins + stim_bin + if 0 <= trial_stim_idx < len(stim_indicator): + stim_indicator[trial_stim_idx] = 1.0 + + stim_basis = nmo.basis.RaisedCosineLogConv(n_stim_basis, window_size=stim_window) + stim_features = stim_basis.compute_features(stim_indicator.reshape(-1, 1)) + feature_list.append(stim_features) + feature_info['names'].append('stimulus') + feature_info['slices'].append(slice(current_idx, current_idx + n_stim_basis)) + current_idx += n_stim_basis + + # 3. 
Coupling from other neurons (if requested) + if include_coupling and other_neurons is not None and len(other_neurons) > 0: + n_coupling_basis = max(2, n_basis_funcs // 2) + coupling_window = max(2, history_window_bins // 2) + coupling_basis = nmo.basis.RaisedCosineLogConv( + n_coupling_basis, + window_size=coupling_window + ) + for other_idx in other_neurons: + coupling_features = coupling_basis.compute_features( + spikes_flat[:, other_idx:other_idx+1] + ) + feature_list.append(coupling_features) + feature_info['names'].append(f'coupling_{other_idx}') + n_coupling_feats = coupling_features.shape[1] + feature_info['slices'].append(slice(current_idx, current_idx + n_coupling_feats)) + current_idx += n_coupling_feats + + # Concatenate all features + X = np.hstack(feature_list) + + # Handle NaN values from convolution edges + valid_mask = ~np.any(np.isnan(X), axis=1) + X_valid = X[valid_mask] + y_valid = y[valid_mask] + + # Fit GLM + model = nmo.glm.GLM( + regularizer=nmo.regularizer.Ridge() + ) + model.fit(X_valid, y_valid) + + return model, X_valid, y_valid, feature_info, valid_mask + + +def predict_firing_rates(model, X, valid_mask, original_shape): + """ + Generate firing rate predictions and reshape to original trial structure. 
+ """ + n_trials, n_bins = original_shape + + # Predict on valid samples + pred_valid = model.predict(X) + + # Reconstruct full array with NaN for invalid samples + pred_full = np.full(n_trials * n_bins, np.nan) + pred_full[valid_mask] = pred_valid + + # Reshape to (trials, time) + pred_rates = pred_full.reshape(n_trials, n_bins) + + return pred_rates + + +############################################################ +# Main +############################################################ + +def main(): + if len(sys.argv) != 2: + print("Usage: _glm_fit_models.py ") + sys.exit(1) + + params_path = sys.argv[1] + + # Load parameters + with open(params_path, 'r') as f: + params = json.load(f) + + data_dir = params['data_dir'] + temp_dir = params['temp_dir'] + output_path = params['output_path'] + bin_size = params['bin_size'] + history_window = params['history_window'] + n_basis_funcs = params['n_basis_funcs'] + time_lims = params['time_lims'] + include_coupling = params['include_coupling'] + separate_tastes = params['separate_tastes'] + separate_regions = params['separate_regions'] + retrain = params['retrain'] + + history_window_bins = history_window // bin_size + + # Stimulus time (typically at 2000ms) + stim_time = 2000 + stim_bin = (stim_time - time_lims[0]) // bin_size + + # Setup directories + artifacts_dir = os.path.join(output_path, 'artifacts') + plots_dir = os.path.join(output_path, 'plots') + os.makedirs(artifacts_dir, exist_ok=True) + os.makedirs(plots_dir, exist_ok=True) + + # Load extracted data + print("Loading extracted spike data...") + extracted = np.load(os.path.join(temp_dir, 'extracted_data.npz'), allow_pickle=True) + + spike_data_keys = extracted['spike_data_keys'] + region_names = list(extracted['region_names']) + region_vec = list(extracted['region_vec']) + n_tastes = int(extracted['n_tastes']) + n_neurons = int(extracted['n_neurons']) + hdf5_path = str(extracted['hdf5_path']) + + # Load region units + with open(os.path.join(temp_dir, 
'region_units.json'), 'r') as f: + region_dict = json.load(f) + + # Load spike arrays + spike_data = {} + for key in spike_data_keys: + spike_data[key] = np.load(os.path.join(temp_dir, f'{key}.npy')) + + print(f" Tastes: {n_tastes}") + print(f" Neurons: {n_neurons}") + print(f" Regions: {region_names}") + + # Save parameters + with open(os.path.join(artifacts_dir, 'params.json'), 'w') as f: + json.dump(params, f, indent=2) + + # Results storage + results = { + 'pred_firing': [], + 'binned_spikes': [], + 'bits_per_spike': [], + 'taste_idx': [], + 'region_name': [], + 'neuron_idx': [], + } + + # Process each taste + for taste_idx in range(n_tastes): + print(f"\n=== Processing taste {taste_idx} ===") + + binned = spike_data[f'taste_{taste_idx}'] + n_trials, n_neurons_taste, n_bins = binned.shape + + print(f" Trials: {n_trials}, Neurons: {n_neurons_taste}, Bins: {n_bins}") + + # Determine which neurons to process + if separate_regions: + regions_to_process = [(name, units) for name, units in region_dict.items() + if name.lower() != 'none'] + else: + regions_to_process = [('all', list(range(n_neurons_taste)))] + + for region_name, region_units in regions_to_process: + print(f" Region: {region_name} ({len(region_units)} neurons)") + + for neuron_idx in region_units: + # Determine coupling neurons + if include_coupling: + other_neurons = [i for i in region_units if i != neuron_idx] + else: + other_neurons = None + + # Model save path + model_name = f'taste_{taste_idx}_region_{region_name}_neuron_{neuron_idx}' + model_path = os.path.join(artifacts_dir, f'{model_name}.pkl') + + if os.path.exists(model_path) and not retrain: + print(f" Loading existing model for neuron {neuron_idx}") + with open(model_path, 'rb') as f: + saved = pickle.load(f) + model = saved['model'] + X_valid = saved['X'] + y_valid = saved['y'] + feature_info = saved['feature_info'] + valid_mask = saved['valid_mask'] + else: + print(f" Fitting GLM for neuron {neuron_idx}") + model, X_valid, y_valid, 
feature_info, valid_mask = fit_glm_single_neuron( + binned, + neuron_idx, + history_window_bins, + n_basis_funcs, + stim_bin=stim_bin, + other_neurons=other_neurons, + include_coupling=include_coupling, + ) + + # # Save model + # with open(model_path, 'wb') as f: + # pickle.dump({ + # 'model': model, + # 'X': X_valid, + # 'y': y_valid, + # 'feature_info': feature_info, + # 'valid_mask': valid_mask, + # }, f) + + # Predict firing rates + pred_rates = predict_firing_rates( + model, X_valid, valid_mask, (n_trials, n_bins) + ) + + # Compute bits per spike + bps = compute_bits_per_spike(model, X_valid, y_valid) + + # Store results + results['pred_firing'].append(pred_rates) + results['binned_spikes'].append(binned[:, neuron_idx, :]) + results['bits_per_spike'].append(bps) + results['taste_idx'].append(taste_idx) + results['region_name'].append(region_name) + results['neuron_idx'].append(neuron_idx) + + ############################################################ + # Generate plots + ############################################################ + + print("\n=== Generating plots ===") + + # Bits per spike summary + bps_df = pd.DataFrame({ + 'taste': results['taste_idx'], + 'region': results['region_name'], + 'neuron': results['neuron_idx'], + 'bits_per_spike': results['bits_per_spike'], + }) + + # Save summary + bps_df.to_csv(os.path.join(output_path, 'bits_per_spike_summary.csv'), index=False) + + # # Plot bits per spike distribution + # fig, ax = plt.subplots(figsize=(10, 6)) + # for region in bps_df['region'].unique(): + # region_bps = bps_df[bps_df['region'] == region]['bits_per_spike'] + # ax.hist(region_bps, alpha=0.5, label=region, bins=20) + # ax.set_xlabel('Bits per Spike') + # ax.set_ylabel('Count') + # ax.set_title('GLM Model Performance: Bits per Spike') + # ax.legend() + # fig.savefig(os.path.join(plots_dir, 'bits_per_spike_distribution.png'), dpi=150) + # plt.close(fig) + # + # Plot mean bits per spike by taste and region + fig, ax = plt.subplots(figsize=(10, 
6)) + bps_summary = bps_df.groupby(['taste', 'region'])['bits_per_spike'].mean().unstack() + bps_summary.plot(kind='bar', ax=ax) + ax.set_xlabel('Taste') + ax.set_ylabel('Mean Bits per Spike') + ax.set_title('GLM Performance by Taste and Region') + ax.legend(title='Region') + fig.savefig(os.path.join(plots_dir, 'bits_per_spike_by_taste_region.png'), dpi=150) + plt.close(fig) + + # Plot example neuron predictions + ind_plot_dir = os.path.join(plots_dir, 'individual_neurons') + os.makedirs(ind_plot_dir, exist_ok=True) + + print("Plotting individual neuron predictions...") + for i in range(min(10, len(results['pred_firing']))): + taste_idx = results['taste_idx'][i] + region = results['region_name'][i] + neuron_idx = results['neuron_idx'][i] + pred = results['pred_firing'][i] + binned = results['binned_spikes'][i] + bps = results['bits_per_spike'][i] + + fig, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True) + + # Raster of binned spikes + axes[0].imshow(binned, aspect='auto', cmap='Greys', interpolation='none') + axes[0].set_ylabel('Trial') + axes[0].set_title(f'Binned Spikes - Taste {taste_idx}, {region}, Neuron {neuron_idx}') + + # Predicted rates + axes[1].imshow(pred, aspect='auto', cmap='viridis', interpolation='none') + axes[1].set_ylabel('Trial') + axes[1].set_title(f'GLM Predicted Rates (bits/spike: {bps:.3f})') + + # Mean comparison + time_bins = np.arange(pred.shape[1]) * bin_size + time_lims[0] + axes[2].plot(time_bins, np.nanmean(binned, axis=0), 'k-', label='Binned', alpha=0.7) + axes[2].plot(time_bins, np.nanmean(pred, axis=0), 'r-', label='GLM Pred', alpha=0.7) + axes[2].axvline(stim_time, color='b', linestyle='--', label='Stimulus') + axes[2].set_xlabel('Time (ms)') + axes[2].set_ylabel('Firing Rate') + axes[2].legend() + axes[2].set_title('Mean Firing Rate Comparison') + + plt.tight_layout() + fig.savefig(os.path.join(ind_plot_dir, f'taste_{taste_idx}_{region}_neuron_{neuron_idx}.png'), dpi=150) + plt.close(fig) + + 
############################################################ + # Write results to HDF5 + ############################################################ + + print("\n=== Writing results to HDF5 ===") + + # Import tables here to write HDF5 + import tables + + with tables.open_file(hdf5_path, 'r+') as hf5: + # Remove existing GLM output if present + if '/glm_output' in hf5: + hf5.remove_node('/glm_output', recursive=True) + + # Create GLM output group + hf5.create_group('/', 'glm_output', 'GLM-based firing rate estimates') + glm_grp = hf5.get_node('/glm_output') + + # Store parameters + hf5.create_array(glm_grp, 'bin_size', np.array([bin_size])) + hf5.create_array(glm_grp, 'history_window', np.array([history_window])) + hf5.create_array(glm_grp, 'time_lims', np.array(time_lims)) + + # Create regions group + hf5.create_group('/glm_output', 'regions', 'Region-specific GLM output') + regions_grp = hf5.get_node('/glm_output/regions') + + # Organize results by taste and region + for i in range(len(results['pred_firing'])): + taste_idx = results['taste_idx'][i] + region = results['region_name'][i] + neuron_idx = results['neuron_idx'][i] + + group_name = f'taste_{taste_idx}_region_{region}_neuron_{neuron_idx}' + + neuron_grp = hf5.create_group(regions_grp, group_name, + f'Taste {taste_idx}, Region {region}, Neuron {neuron_idx}') + + hf5.create_array(neuron_grp, 'pred_firing', results['pred_firing'][i]) + hf5.create_array(neuron_grp, 'binned_spikes', results['binned_spikes'][i]) + hf5.create_array(neuron_grp, 'bits_per_spike', np.array([results['bits_per_spike'][i]])) + + print(f"\nResults saved to {hdf5_path}") + print(f"Plots saved to {plots_dir}") + print(f"Summary saved to {os.path.join(output_path, 'bits_per_spike_summary.csv')}") + + # Print summary statistics + print("\n=== Summary ===") + print(f"Total neurons processed: {len(results['pred_firing'])}") + print(f"Mean bits per spike: {np.mean(results['bits_per_spike']):.4f}") + print(f"Std bits per spike: 
{np.std(results['bits_per_spike']):.4f}") + + +if __name__ == '__main__': + main() diff --git a/utils/glm/infer_glm_rates.py b/utils/glm/infer_glm_rates.py new file mode 100644 index 00000000..fc00ba24 --- /dev/null +++ b/utils/glm/infer_glm_rates.py @@ -0,0 +1,265 @@ +""" +GLM-based firing rate estimation using nemos. + +This script orchestrates GLM fitting across two Python environments: +1. blech_clust conda environment: Extracts spike data using ephys_data +2. nemos virtual environment: Fits GLM models + +The script automatically detects environments and saves intermediate data files +for cross-environment communication. + +Usage: + python infer_glm_rates.py [options] + +Options: + --bin_size: Bin size in ms for spike binning (default: 25) + --history_window: History window in ms for autoregressive effects (default: 250) + --n_basis_funcs: Number of basis functions for history filter (default: 8) + --time_lims: Time limits for analysis [start, end] in ms (default: [1500, 4500]) + --include_coupling: Include coupling between neurons (default: False) + --separate_tastes: Fit separate models for each taste (default: False) + --separate_regions: Fit separate models for each region (default: False) + --retrain: Force retraining even if model exists (default: False) + --blech_clust_env: Name of blech_clust conda environment (default: blech_clust) + --nemos_env: Path to nemos venv Python interpreter (default: searches common locations) +""" + +import argparse +import os +import sys +import subprocess +import json +import shutil + +############################################################ +# Argument parsing +############################################################ + +def parse_args(test_mode=False): + if test_mode: + print('====================') + print('Running in test mode') + print('====================') + # data_dir = '/home/abuzarmahmood/projects/blech_clust/pipeline_testing/test_data_handling/test_data/KM45_5tastes_210620_113227_new' + # data_dir = 
'/media/storage/abu_resorted/bla_gc/AM11_4Tastes_191030_114043_copy' + data_dir = '/media/storage/abu_resorted/bla_gc/AM35_4Tastes_201228_124547' + args = argparse.Namespace( + data_dir=data_dir, + bin_size=25, + history_window=250, + n_basis_funcs=8, + time_lims=[1500, 4500], + include_coupling=True, + separate_tastes=True, + separate_regions=True, + retrain=False, + blech_clust_env='blech_clust', + nemos_env='/home/abuzarmahmood/Desktop/blech_clust/nemos_venv/bin/python', + ) + return args, test_mode + else: + parser = argparse.ArgumentParser( + description='Infer firing rates using GLM (nemos)') + parser.add_argument('data_dir', help='Path to data directory') + parser.add_argument('--bin_size', type=int, default=25, + help='Bin size in ms for spike binning (default: %(default)s)') + parser.add_argument('--history_window', type=int, default=250, + help='History window in ms for autoregressive effects (default: %(default)s)') + parser.add_argument('--n_basis_funcs', type=int, default=8, + help='Number of basis functions for history filter (default: %(default)s)') + parser.add_argument('--time_lims', type=int, nargs=2, default=[1500, 4500], + help='Time limits for analysis [start, end] in ms (default: %(default)s)') + parser.add_argument('--include_coupling', action='store_true', # opt-in flag: coupling stays off (False) unless the flag is passed + help='Include coupling between neurons (default: %(default)s)') + parser.add_argument('--separate_tastes', action='store_true', + help='Fit separate models for each taste (default: %(default)s)') + parser.add_argument('--separate_regions', action='store_true', + help='Fit separate models for each region (default: %(default)s)') + parser.add_argument('--retrain', action='store_true', + help='Force retraining of model (default: %(default)s)') + parser.add_argument('--blech_clust_env', type=str, default='blech_clust', + help='Name of blech_clust conda environment (default: %(default)s)') + parser.add_argument('--nemos_env', type=str, default=None, + help='Path to nemos venv Python 
interpreter') + return parser.parse_args(), test_mode + + +############################################################ +# Environment detection +############################################################ + +def find_conda_env(env_name): + """Find conda environment by name and return its Python path.""" + try: + result = subprocess.run( + # ['conda', 'run', '-n', env_name, 'which', 'python'], + ['conda', 'env', 'list'], + capture_output=True, text=True, timeout=30 + ) + # Parse output to find env path; the path is the last column (an active env has a '*' marker before it) + for line in result.stdout.splitlines(): + if line.startswith(env_name + ' '): + env_path = line.split()[-1] + python_path = os.path.join(env_path, 'bin', 'python') + return python_path + # No matching row: fall through to the common install locations below + # (the raw `conda env list` text must never be returned as an interpreter path) + except (subprocess.TimeoutExpired, FileNotFoundError): + pass + + # Try common conda locations + home = os.path.expanduser('~') + common_paths = [ + f'{home}/anaconda3/envs/{env_name}/bin/python', + f'{home}/miniconda3/envs/{env_name}/bin/python', + f'{home}/miniforge3/envs/{env_name}/bin/python', + f'/opt/conda/envs/{env_name}/bin/python', + ] + for path in common_paths: + if os.path.exists(path): + return path + return None + + +def find_nemos_env(): + """Find nemos virtual environment Python interpreter.""" + home = os.path.expanduser('~') + common_paths = [ + f'{home}/nemos_env/bin/python', + f'{home}/venvs/nemos/bin/python', + f'{home}/.venvs/nemos/bin/python', + f'{home}/envs/nemos/bin/python', + f'{home}/.local/share/virtualenvs/nemos/bin/python', + ] + for path in common_paths: + if os.path.exists(path): + # Verify nemos is installed + try: + result = subprocess.run( + [path, '-c', 'import nemos'], + capture_output=True, timeout=10 + ) + if result.returncode == 0: + return path + except subprocess.TimeoutExpired: + pass + return None + + +def prompt_for_env(env_type, default_name=None): + """Prompt user for environment path.""" + print(f"\n{env_type} environment not found automatically.") + if default_name: + 
print(f"Expected conda environment name: {default_name}") + + response = input(f"Enter path to {env_type} Python interpreter (or 'q' to quit): ").strip() + if response.lower() == 'q': + sys.exit(1) + if os.path.exists(response): + return response + else: + print(f"Path not found: {response}") + return prompt_for_env(env_type, default_name) + + +############################################################ +# Main orchestration +############################################################ + +def main(): + args, test_mode = parse_args() # keep the flag: it selects script_dir just below + # args, test_mode = parse_args(test_mode=True) + + if not test_mode: + script_dir = os.path.dirname(os.path.abspath(__file__)) + else: + script_dir = '/home/abuzarmahmood/Desktop/blech_clust/utils/glm' + data_dir = os.path.abspath(args.data_dir) + + # Setup output directories + output_path = os.path.join(data_dir, 'glm_output') + # Use temp dir within the glm module directory + temp_dir = os.path.join(script_dir, '_temp') + os.makedirs(temp_dir, exist_ok=True) + os.makedirs(output_path, exist_ok=True) + + print("=" * 60) + print("GLM-based Firing Rate Estimation") + print("=" * 60) + print(f"Data directory: {data_dir}") + + # Find blech_clust environment + print("\n[1/4] Locating blech_clust environment...") + blech_python = find_conda_env(args.blech_clust_env) + if blech_python is None: + blech_python = prompt_for_env('blech_clust', args.blech_clust_env) + print(f" Found: {blech_python}") + + # Find nemos environment + print("\n[2/4] Locating nemos environment...") + if args.nemos_env: + nemos_python = args.nemos_env + else: + nemos_python = find_nemos_env() + if nemos_python is None: + nemos_python = prompt_for_env('nemos') + print(f" Found: {nemos_python}") + + # Save parameters for sub-scripts + params = { + 'data_dir': data_dir, + 'bin_size': args.bin_size, + 'history_window': args.history_window, + 'n_basis_funcs': args.n_basis_funcs, + 'time_lims': args.time_lims, + 'include_coupling': args.include_coupling, + 'separate_tastes': 
args.separate_tastes, + 'separate_regions': args.separate_regions, + 'retrain': args.retrain, + 'temp_dir': temp_dir, + 'output_path': output_path, + } + params_path = os.path.join(temp_dir, 'glm_params.json') + with open(params_path, 'w') as f: + json.dump(params, f, indent=2) + + # Step 1: Extract data using blech_clust environment + print("\n[3/4] Extracting spike data (blech_clust environment)...") + extract_script = os.path.join(script_dir, '_glm_extract_data.py') + + result = subprocess.run( + [blech_python, extract_script, params_path], + capture_output=True, text=True + ) + if result.returncode != 0: + print("ERROR: Data extraction failed") + print(result.stderr) + sys.exit(1) + print(result.stdout) + + # Step 2: Fit GLM using nemos environment + print("\n[4/4] Fitting GLM models (nemos environment)...") + fit_script = os.path.join(script_dir, '_glm_fit_models.py') + + result = subprocess.run( + [nemos_python, fit_script, params_path], + capture_output=True, text=True + ) + print(result.stdout) # show the child's progress output even when it failed + if result.returncode != 0: + print("ERROR: GLM fitting failed") + print(result.stderr) + sys.exit(1) # abort like the extraction step; never clean up and report success after a failed fit + + # Cleanup temp files + print("\nCleaning up temporary files...") + shutil.rmtree(temp_dir) + + print("\n" + "=" * 60) + print("GLM fitting complete!") + print(f"Results saved to: {output_path}") + print("=" * 60) + + +if __name__ == '__main__': + main()