From b866312a825ee1f3cfa75288e7ef64200dfecb8c Mon Sep 17 00:00:00 2001 From: Abuzar Mahmood Date: Sun, 18 Jan 2026 20:26:16 +0000 Subject: [PATCH 1/5] Add GLM-based firing rate estimation script Implements issue #708: GLM-based firing rate inference using nemos library. Features: - Spike history modeling with raised cosine basis functions - Optional stimulus timing features - Optional coupling between simultaneously recorded neurons - Bits per spike metric for model comparison vs binned rates - Separate processing by taste and/or region Co-authored-by: Ona --- params/_templates/glm_params.json | 9 + requirements/requirements-optional.txt | 1 + utils/infer_glm_rates.py | 629 +++++++++++++++++++++++++ 3 files changed, 639 insertions(+) create mode 100644 params/_templates/glm_params.json create mode 100644 utils/infer_glm_rates.py diff --git a/params/_templates/glm_params.json b/params/_templates/glm_params.json new file mode 100644 index 00000000..c6519b88 --- /dev/null +++ b/params/_templates/glm_params.json @@ -0,0 +1,9 @@ +{ + "bin_size": 25, + "history_window": 250, + "n_basis_funcs": 8, + "time_lims": [1500, 4500], + "include_coupling": false, + "separate_tastes": true, + "separate_regions": true +} diff --git a/requirements/requirements-optional.txt b/requirements/requirements-optional.txt index f17103f0..cff4cb21 100644 --- a/requirements/requirements-optional.txt +++ b/requirements/requirements-optional.txt @@ -1 +1,2 @@ pymc +nemos diff --git a/utils/infer_glm_rates.py b/utils/infer_glm_rates.py new file mode 100644 index 00000000..61a0a23d --- /dev/null +++ b/utils/infer_glm_rates.py @@ -0,0 +1,629 @@ +""" +GLM-based firing rate estimation using nemos. + +This module fits Generalized Linear Models (GLMs) to electrophysiological spike data +to infer firing rates. The GLM can incorporate: +- Spike history (autoregressive effects) +- Stimulus timing +- Coupling between simultaneously recorded neurons + +Model comparison is provided via bits per spike metric against binned rates. + +Usage: + python infer_glm_rates.py [options] + +Options: + --bin_size: Bin size in ms for spike binning (default: 25) + --history_window: History window in ms for autoregressive effects (default: 250) + --n_basis_funcs: Number of basis functions for history filter (default: 8) + --time_lims: Time limits for analysis [start, end] in ms (default: [1500, 4500]) + --include_coupling: Include coupling between neurons (default: False) + --separate_tastes: Fit separate models for each taste (default: False) + --separate_regions: Fit separate models for each region (default: False) + --retrain: Force retraining even if model exists (default: False) +""" + +import argparse +import os + +test_mode = False +if test_mode: + print('====================') + print('Running in test mode') + print('====================') + data_dir = '/home/abuzarmahmood/projects/blech_clust/pipeline_testing/test_data_handling/test_data/KM45_5tastes_210620_113227_new' + script_path = '/home/abuzarmahmood/projects/blech_clust/utils/infer_glm_rates.py' + blech_clust_path = os.path.dirname(os.path.dirname(script_path)) + args = argparse.Namespace( + data_dir=data_dir, + bin_size=25, + history_window=250, + n_basis_funcs=8, + time_lims=[1500, 4500], + include_coupling=False, + separate_tastes=True, + separate_regions=True, + retrain=False, + ) +else: + parser = argparse.ArgumentParser( + description='Infer firing rates using GLM (nemos)') + parser.add_argument('data_dir', help='Path to data directory') + parser.add_argument('--bin_size', type=int, default=25, + help='Bin size in ms for spike binning (default: %(default)s)') + parser.add_argument('--history_window', type=int, default=250, + help='History window in ms for autoregressive effects (default: %(default)s)') + parser.add_argument('--n_basis_funcs', type=int, default=8, + help='Number of basis functions for history filter (default: %(default)s)') + parser.add_argument('--time_lims', type=int, nargs=2, default=[1500, 4500], + help='Time limits for analysis [start, end] in ms (default: %(default)s)') + parser.add_argument('--include_coupling', action='store_true', + help='Include coupling between neurons (default: %(default)s)') + parser.add_argument('--separate_tastes', action='store_true', + help='Fit separate models for each taste (default: %(default)s)') + parser.add_argument('--separate_regions', action='store_true', + help='Fit separate models for each region (default: %(default)s)') + parser.add_argument('--retrain', action='store_true', + help='Force retraining of model (default: %(default)s)') + + args = parser.parse_args() + data_dir = args.data_dir + script_path = os.path.abspath(__file__) + blech_clust_path = os.path.dirname(os.path.dirname(script_path)) + +############################################################ +############################################################ + +import tables +import matplotlib.pyplot as plt +import numpy as np +import sys +from pprint import pprint +import json +from itertools import product +import pandas as pd +import pickle + +sys.path.append(blech_clust_path) + +from blech_clust.utils.blech_utils import imp_metadata, pipeline_graph_check +from blech_clust.utils.ephys_data import visualize as vz +from blech_clust.utils.ephys_data import ephys_data + +# Import nemos for GLM fitting +try: + import nemos as nmo +except ImportError: + raise ImportError( + 'nemos is required for GLM fitting. Install with: pip install nemos' + ) + +############################################################ +############################################################ + + +def compute_bits_per_spike(model, X, y, baseline_rate=None): + """ + Compute bits per spike metric for model comparison. + + Bits per spike measures how much better the model predicts spikes + compared to a baseline (homogeneous Poisson) model. + + Args: + model: Fitted nemos GLM model + X: Feature matrix + y: Spike counts + baseline_rate: Baseline firing rate (if None, uses mean rate) + + Returns: + bits_per_spike: Information gain in bits per spike + """ + # Get model predictions + pred_rate = model.predict(X) + + # Compute log-likelihood under model + # For Poisson: LL = y * log(rate) - rate - log(y!) + # We ignore the log(y!) term as it cancels in comparison + eps = 1e-10 + ll_model = np.sum(y * np.log(pred_rate + eps) - pred_rate) + + # Compute log-likelihood under baseline (homogeneous Poisson) + if baseline_rate is None: + baseline_rate = np.mean(y, axis=0, keepdims=True) + ll_baseline = np.sum(y * np.log(baseline_rate + eps) - baseline_rate) + + # Bits per spike = (LL_model - LL_baseline) / (n_spikes * log(2)) + n_spikes = np.sum(y) + if n_spikes > 0: + bits_per_spike = (ll_model - ll_baseline) / (n_spikes * np.log(2)) + else: + bits_per_spike = 0.0 + + return bits_per_spike + + +def bin_spikes(spike_data, bin_size): + """ + Bin spike data. + + Args: + spike_data: Array of shape (trials, neurons, time) + bin_size: Bin size in samples + + Returns: + binned: Array of shape (trials, neurons, n_bins) + """ + n_trials, n_neurons, n_time = spike_data.shape + n_bins = n_time // bin_size + trimmed = spike_data[..., :n_bins * bin_size] + binned = trimmed.reshape(n_trials, n_neurons, n_bins, bin_size).sum(axis=-1) + return binned + + +def create_stimulus_feature(n_samples, stim_bin, n_basis=5, window_size=10): + """ + Create stimulus feature using raised cosine basis. + + Args: + n_samples: Number of time samples + stim_bin: Bin index of stimulus onset + n_basis: Number of basis functions + window_size: Window size for basis + + Returns: + stim_features: Stimulus feature matrix + """ + # Create stimulus indicator + stim = np.zeros(n_samples) + if 0 <= stim_bin < n_samples: + stim[stim_bin] = 1.0 + + # Convolve with raised cosine basis + basis = nmo.basis.RaisedCosineLogConv(n_basis, window_size=window_size) + stim_features = basis.compute_features(stim.reshape(-1, 1)) + + return stim_features + + +def fit_glm_single_neuron( + binned_spikes, + neuron_idx, + history_window_bins, + n_basis_funcs, + stim_bin=None, + other_neurons=None, + include_coupling=False, +): + """ + Fit GLM for a single neuron. + + Args: + binned_spikes: Binned spike counts (trials, neurons, time) + neuron_idx: Index of target neuron + history_window_bins: History window in bins + n_basis_funcs: Number of basis functions + stim_bin: Stimulus onset bin (optional) + other_neurons: Indices of other neurons for coupling (optional) + include_coupling: Whether to include coupling + + Returns: + model: Fitted GLM + X: Feature matrix + y: Target spike counts + feature_info: Dictionary with feature information + """ + n_trials, n_neurons, n_bins = binned_spikes.shape + + # Reshape to (samples, neurons) by concatenating trials + spikes_flat = binned_spikes.transpose(0, 2, 1).reshape(-1, n_neurons) + + # Target neuron spikes + y = spikes_flat[:, neuron_idx] + + # Build feature matrix + feature_list = [] + feature_info = {'names': [], 'slices': []} + current_idx = 0 + + # 1. Spike history basis for target neuron + history_basis = nmo.basis.RaisedCosineLogConv( + n_basis_funcs, + window_size=history_window_bins + ) + history_features = history_basis.compute_features( + spikes_flat[:, neuron_idx:neuron_idx+1] + ) + feature_list.append(history_features) + feature_info['names'].append('history') + feature_info['slices'].append(slice(current_idx, current_idx + n_basis_funcs)) + current_idx += n_basis_funcs + + # 2. Stimulus features (if provided) + if stim_bin is not None: + n_stim_basis = 5 + stim_window = min(20, history_window_bins) + + # Create stimulus indicator for each trial + stim_indicator = np.zeros((n_trials * n_bins,)) + for trial in range(n_trials): + trial_stim_idx = trial * n_bins + stim_bin + if 0 <= trial_stim_idx < len(stim_indicator): + stim_indicator[trial_stim_idx] = 1.0 + + stim_basis = nmo.basis.RaisedCosineLogConv(n_stim_basis, window_size=stim_window) + stim_features = stim_basis.compute_features(stim_indicator.reshape(-1, 1)) + feature_list.append(stim_features) + feature_info['names'].append('stimulus') + feature_info['slices'].append(slice(current_idx, current_idx + n_stim_basis)) + current_idx += n_stim_basis + + # 3. Coupling from other neurons (if requested) + if include_coupling and other_neurons is not None and len(other_neurons) > 0: + coupling_basis = nmo.basis.RaisedCosineLogConv( + n_basis_funcs // 2, # Fewer basis functions for coupling + window_size=history_window_bins // 2 + ) + for other_idx in other_neurons: + coupling_features = coupling_basis.compute_features( + spikes_flat[:, other_idx:other_idx+1] + ) + feature_list.append(coupling_features) + feature_info['names'].append(f'coupling_{other_idx}') + n_coupling_feats = coupling_features.shape[1] + feature_info['slices'].append(slice(current_idx, current_idx + n_coupling_feats)) + current_idx += n_coupling_feats + + # Concatenate all features + X = np.hstack(feature_list) + + # Handle NaN values from convolution edges + valid_mask = ~np.any(np.isnan(X), axis=1) + X_valid = X[valid_mask] + y_valid = y[valid_mask] + + # Fit GLM + model = nmo.glm.GLM( + regularizer=nmo.regularizer.Ridge(regularizer_strength=0.01) + ) + model.fit(X_valid, y_valid) + + return model, X_valid, y_valid, feature_info, valid_mask + + +def predict_firing_rates(model, X, valid_mask, original_shape): + """ + Generate firing rate predictions and reshape to original trial structure. + + Args: + model: Fitted GLM + X: Feature matrix (valid samples only) + valid_mask: Boolean mask for valid samples + original_shape: Original shape (n_trials, n_bins) + + Returns: + pred_rates: Predicted firing rates (n_trials, n_bins) + """ + n_trials, n_bins = original_shape + + # Predict on valid samples + pred_valid = model.predict(X) + + # Reconstruct full array with NaN for invalid samples + pred_full = np.full(n_trials * n_bins, np.nan) + pred_full[valid_mask] = pred_valid + + # Reshape to (trials, time) + pred_rates = pred_full.reshape(n_trials, n_bins) + + return pred_rates + + +############################################################ +############################################################ + +if not test_mode: + metadata_handler = imp_metadata([[], args.data_dir]) + this_pipeline_check = pipeline_graph_check(args.data_dir) + this_pipeline_check.check_previous(script_path) + this_pipeline_check.write_to_log(script_path, 'attempted') + +# Setup output directories +output_path = os.path.join(data_dir, 'glm_output') +artifacts_dir = os.path.join(output_path, 'artifacts') +plots_dir = os.path.join(output_path, 'plots') + +for dir_path in [output_path, artifacts_dir, plots_dir]: + if not os.path.exists(dir_path): + os.makedirs(dir_path) + +print(f'Processing data from {data_dir}') +print(f'Parameters:') +print(f' bin_size: {args.bin_size} ms') +print(f' history_window: {args.history_window} ms') +print(f' n_basis_funcs: {args.n_basis_funcs}') +print(f' time_lims: {args.time_lims}') +print(f' include_coupling: {args.include_coupling}') +print(f' separate_tastes: {args.separate_tastes}') +print(f' separate_regions: {args.separate_regions}') + +# Save parameters +params_dict = { + 'bin_size': args.bin_size, + 'history_window': args.history_window, + 'n_basis_funcs': args.n_basis_funcs, + 'time_lims': args.time_lims, + 'include_coupling': args.include_coupling, + 'separate_tastes': args.separate_tastes, + 'separate_regions': args.separate_regions, +} +with open(os.path.join(artifacts_dir, 'params.json'), 'w') as f: + json.dump(params_dict, f, indent=4) + +############################################################ +# Load data +############################################################ + +basename = os.path.basename(data_dir) +data = ephys_data.ephys_data(data_dir) +data.get_spikes() +data.get_region_units() + +# Build region mapping +region_dict = dict(zip(data.region_names, data.region_units)) +region_vec = np.zeros(len(np.concatenate(data.region_units)), dtype=object) +for region_name, unit_list in region_dict.items(): + region_vec[unit_list] = region_name + +n_tastes = len(data.spikes) +n_neurons = data.spikes[0].shape[1] + +print(f'Loaded {n_tastes} tastes, {n_neurons} neurons') +print(f'Regions: {data.region_names}') + +############################################################ +# Process data +############################################################ + +bin_size = args.bin_size +history_window = args.history_window +history_window_bins = history_window // bin_size +n_basis_funcs = args.n_basis_funcs +time_lims = args.time_lims + +# Stimulus time (typically at 2000ms) +stim_time = 2000 +stim_bin = (stim_time - time_lims[0]) // bin_size + +# Determine processing groups +group_by_list = [] +if args.separate_tastes: + group_by_list.append('taste') +if args.separate_regions: + group_by_list.append('region') + +# Results storage +results = { + 'pred_firing': [], + 'binned_spikes': [], + 'bits_per_spike': [], + 'models': [], + 'taste_idx': [], + 'region_name': [], + 'neuron_idx': [], +} + +# Process each taste +for taste_idx, taste_spikes in enumerate(data.spikes): + print(f'\n=== Processing taste {taste_idx} ===') + + # Cut to time limits + taste_spikes = taste_spikes[..., time_lims[0]:time_lims[1]] + + # Bin spikes + binned = bin_spikes(taste_spikes, bin_size) + n_trials, n_neurons_taste, n_bins = binned.shape + + print(f' Trials: {n_trials}, Neurons: {n_neurons_taste}, Bins: {n_bins}') + + # Determine which neurons to process + if args.separate_regions: + regions_to_process = [(name, units) for name, units in region_dict.items() + if name.lower() != 'none'] + else: + regions_to_process = [('all', list(range(n_neurons_taste)))] + + for region_name, region_units in regions_to_process: + print(f' Region: {region_name} ({len(region_units)} neurons)') + + for neuron_idx in region_units: + # Determine coupling neurons + if args.include_coupling: + other_neurons = [i for i in region_units if i != neuron_idx] + else: + other_neurons = None + + # Model save path + model_name = f'taste_{taste_idx}_region_{region_name}_neuron_{neuron_idx}' + model_path = os.path.join(artifacts_dir, f'{model_name}.pkl') + + if os.path.exists(model_path) and not args.retrain: + print(f' Loading existing model for neuron {neuron_idx}') + with open(model_path, 'rb') as f: + saved = pickle.load(f) + model = saved['model'] + X_valid = saved['X'] + y_valid = saved['y'] + feature_info = saved['feature_info'] + valid_mask = saved['valid_mask'] + else: + print(f' Fitting GLM for neuron {neuron_idx}') + model, X_valid, y_valid, feature_info, valid_mask = fit_glm_single_neuron( + binned, + neuron_idx, + history_window_bins, + n_basis_funcs, + stim_bin=stim_bin, + other_neurons=other_neurons, + include_coupling=args.include_coupling, + ) + + # Save model + with open(model_path, 'wb') as f: + pickle.dump({ + 'model': model, + 'X': X_valid, + 'y': y_valid, + 'feature_info': feature_info, + 'valid_mask': valid_mask, + }, f) + + # Predict firing rates + pred_rates = predict_firing_rates( + model, X_valid, valid_mask, (n_trials, n_bins) + ) + + # Compute bits per spike + bps = compute_bits_per_spike(model, X_valid, y_valid) + + # Store results + results['pred_firing'].append(pred_rates) + results['binned_spikes'].append(binned[:, neuron_idx, :]) + results['bits_per_spike'].append(bps) + results['models'].append(model) + results['taste_idx'].append(taste_idx) + results['region_name'].append(region_name) + results['neuron_idx'].append(neuron_idx) + +############################################################ +# Generate plots +############################################################ + +print('\n=== Generating plots ===') + +# Bits per spike summary +bps_df = pd.DataFrame({ + 'taste': results['taste_idx'], + 'region': results['region_name'], + 'neuron': results['neuron_idx'], + 'bits_per_spike': results['bits_per_spike'], +}) + +# Save summary +bps_df.to_csv(os.path.join(output_path, 'bits_per_spike_summary.csv'), index=False) + +# Plot bits per spike distribution +fig, ax = plt.subplots(figsize=(10, 6)) +for region in bps_df['region'].unique(): + region_bps = bps_df[bps_df['region'] == region]['bits_per_spike'] + ax.hist(region_bps, alpha=0.5, label=region, bins=20) +ax.set_xlabel('Bits per Spike') +ax.set_ylabel('Count') +ax.set_title('GLM Model Performance: Bits per Spike') +ax.legend() +fig.savefig(os.path.join(plots_dir, 'bits_per_spike_distribution.png'), dpi=150) +plt.close(fig) + +# Plot mean bits per spike by taste and region +fig, ax = plt.subplots(figsize=(10, 6)) +bps_summary = bps_df.groupby(['taste', 'region'])['bits_per_spike'].mean().unstack() +bps_summary.plot(kind='bar', ax=ax) +ax.set_xlabel('Taste') +ax.set_ylabel('Mean Bits per Spike') +ax.set_title('GLM Performance by Taste and Region') +ax.legend(title='Region') +fig.savefig(os.path.join(plots_dir, 'bits_per_spike_by_taste_region.png'), dpi=150) +plt.close(fig) + +# Plot example neuron predictions +ind_plot_dir = os.path.join(plots_dir, 'individual_neurons') +if not os.path.exists(ind_plot_dir): + os.makedirs(ind_plot_dir) + +print('Plotting individual neuron predictions...') +for i in range(min(10, len(results['pred_firing']))): # Plot first 10 neurons + taste_idx = results['taste_idx'][i] + region = results['region_name'][i] + neuron_idx = results['neuron_idx'][i] + pred = results['pred_firing'][i] + binned = results['binned_spikes'][i] + bps = results['bits_per_spike'][i] + + fig, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True) + + # Raster of binned spikes + axes[0].imshow(binned, aspect='auto', cmap='Greys', interpolation='none') + axes[0].set_ylabel('Trial') + axes[0].set_title(f'Binned Spikes - Taste {taste_idx}, {region}, Neuron {neuron_idx}') + + # Predicted rates + axes[1].imshow(pred, aspect='auto', cmap='viridis', interpolation='none') + axes[1].set_ylabel('Trial') + axes[1].set_title(f'GLM Predicted Rates (bits/spike: {bps:.3f})') + + # Mean comparison + time_bins = np.arange(pred.shape[1]) * bin_size + time_lims[0] + axes[2].plot(time_bins, np.nanmean(binned, axis=0), 'k-', label='Binned', alpha=0.7) + axes[2].plot(time_bins, np.nanmean(pred, axis=0), 'r-', label='GLM Pred', alpha=0.7) + axes[2].axvline(stim_time, color='b', linestyle='--', label='Stimulus') + axes[2].set_xlabel('Time (ms)') + axes[2].set_ylabel('Firing Rate') + axes[2].legend() + axes[2].set_title('Mean Firing Rate Comparison') + + plt.tight_layout() + fig.savefig(os.path.join(ind_plot_dir, f'taste_{taste_idx}_{region}_neuron_{neuron_idx}.png'), dpi=150) + plt.close(fig) + +############################################################ +# Write results to HDF5 +############################################################ + +print('\n=== Writing results to HDF5 ===') + +hdf5_path = data.hdf5_path +with tables.open_file(hdf5_path, 'r+') as hf5: + # Remove existing GLM output if present + if '/glm_output' in hf5: + hf5.remove_node('/glm_output', recursive=True) + + # Create GLM output group + hf5.create_group('/', 'glm_output', 'GLM-based firing rate estimates') + glm_grp = hf5.get_node('/glm_output') + + # Store parameters + hf5.create_array(glm_grp, 'bin_size', np.array([bin_size])) + hf5.create_array(glm_grp, 'history_window', np.array([history_window])) + hf5.create_array(glm_grp, 'time_lims', np.array(time_lims)) + + # Create regions group + hf5.create_group('/glm_output', 'regions', 'Region-specific GLM output') + regions_grp = hf5.get_node('/glm_output/regions') + + # Organize results by taste and region + for i in range(len(results['pred_firing'])): + taste_idx = results['taste_idx'][i] + region = results['region_name'][i] + neuron_idx = results['neuron_idx'][i] + + group_name = f'taste_{taste_idx}_region_{region}_neuron_{neuron_idx}' + + neuron_grp = hf5.create_group(regions_grp, group_name, + f'Taste {taste_idx}, Region {region}, Neuron {neuron_idx}') + + hf5.create_array(neuron_grp, 'pred_firing', results['pred_firing'][i]) + hf5.create_array(neuron_grp, 'binned_spikes', results['binned_spikes'][i]) + hf5.create_array(neuron_grp, 'bits_per_spike', np.array([results['bits_per_spike'][i]])) + +print(f'\nResults saved to {hdf5_path}') +print(f'Plots saved to {plots_dir}') +print(f'Summary saved to {os.path.join(output_path, "bits_per_spike_summary.csv")}') + +# Print summary statistics +print('\n=== Summary ===') +print(f'Total neurons processed: {len(results["pred_firing"])}') +print(f'Mean bits per spike: {np.mean(results["bits_per_spike"]):.4f}') +print(f'Std bits per spike: {np.std(results["bits_per_spike"]):.4f}') + +# Write successful execution to log +if not test_mode: + this_pipeline_check.write_to_log(script_path, 'completed') From 22130193c96f7259279193caf6bff13a419d0cf2 Mon Sep 17 00:00:00 2001 From: Abuzar Mahmood Date: Mon, 19 Jan 2026 17:22:20 +0000 Subject: [PATCH 2/5] Refactor GLM script for separate nemos environment Move GLM scripts to utils/glm/ subdirectory with two-environment architecture: - infer_glm_rates.py: Main orchestrator that detects environments - _glm_extract_data.py: Runs in blech_clust conda env for data extraction - _glm_fit_models.py: Runs in nemos venv for GLM fitting Data is passed between environments via temporary numpy files. Remove nemos from requirements-optional.txt since it runs in separate venv. Co-authored-by: Ona --- .gitignore | 1 + requirements/requirements-optional.txt | 1 - utils/glm/__init__.py | 9 + utils/glm/_glm_extract_data.py | 118 +++++ utils/glm/_glm_fit_models.py | 475 +++++++++++++++++++ utils/glm/infer_glm_rates.py | 232 +++++++++ utils/infer_glm_rates.py | 629 ------------------------- 7 files changed, 835 insertions(+), 630 deletions(-) create mode 100644 utils/glm/__init__.py create mode 100644 utils/glm/_glm_extract_data.py create mode 100644 utils/glm/_glm_fit_models.py create mode 100644 utils/glm/infer_glm_rates.py delete mode 100644 utils/infer_glm_rates.py diff --git a/.gitignore b/.gitignore index 725375b4..93c3ecb4 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ blech.dir *.Rhistory .aider* .env +utils/glm/_temp/ diff --git a/requirements/requirements-optional.txt b/requirements/requirements-optional.txt index cff4cb21..f17103f0 100644 --- a/requirements/requirements-optional.txt +++ b/requirements/requirements-optional.txt @@ -1,2 +1 @@ pymc -nemos diff --git a/utils/glm/__init__.py b/utils/glm/__init__.py new file mode 100644 index 00000000..884e72ee --- /dev/null +++ b/utils/glm/__init__.py @@ -0,0 +1,9 @@ +""" +GLM-based firing rate estimation module. + +This module provides GLM fitting for electrophysiological spike data using nemos. +Due to dependency conflicts, nemos runs in a separate virtual environment. + +Usage: + python -m blech_clust.utils.glm.infer_glm_rates [options] +""" diff --git a/utils/glm/_glm_extract_data.py b/utils/glm/_glm_extract_data.py new file mode 100644 index 00000000..83ff3b1f --- /dev/null +++ b/utils/glm/_glm_extract_data.py @@ -0,0 +1,118 @@ +""" +Data extraction script for GLM fitting. + +This script runs in the blech_clust conda environment and extracts spike data +using ephys_data. The extracted data is saved as numpy files for the GLM +fitting script to load. + +This script is called by infer_glm_rates.py and should not be run directly. +""" + +import sys +import os +import json +import numpy as np + +def main(): + if len(sys.argv) != 2: + print("Usage: _glm_extract_data.py ") + sys.exit(1) + + params_path = sys.argv[1] + + # Load parameters + with open(params_path, 'r') as f: + params = json.load(f) + + data_dir = params['data_dir'] + temp_dir = params['temp_dir'] + time_lims = params['time_lims'] + bin_size = params['bin_size'] + + # Add blech_clust to path + script_path = os.path.abspath(__file__) + blech_clust_path = os.path.dirname(os.path.dirname(script_path)) + sys.path.insert(0, blech_clust_path) + + # Import blech_clust modules + from blech_clust.utils.blech_utils import imp_metadata, pipeline_graph_check + from blech_clust.utils.ephys_data import ephys_data + + print(f"Extracting data from: {data_dir}") + + # Pipeline check + this_pipeline_check = pipeline_graph_check(data_dir) + this_pipeline_check.check_previous(script_path) + this_pipeline_check.write_to_log(script_path, 'attempted') + + # Load data + data = ephys_data.ephys_data(data_dir) + data.get_spikes() + data.get_region_units() + + # Build region mapping + region_dict = dict(zip(data.region_names, data.region_units)) + region_vec = np.zeros(len(np.concatenate(data.region_units)), dtype=object) + for region_name, unit_list in region_dict.items(): + region_vec[unit_list] = region_name + + n_tastes = len(data.spikes) + n_neurons = data.spikes[0].shape[1] + + print(f" Tastes: {n_tastes}") + print(f" Neurons: {n_neurons}") + print(f" Regions: {data.region_names}") + + # Process and save spike data for each taste + spike_data = {} + for taste_idx, taste_spikes in enumerate(data.spikes): + # Cut to time limits + taste_spikes = taste_spikes[..., time_lims[0]:time_lims[1]] + + # Bin spikes + n_trials, n_neurons_taste, n_time = taste_spikes.shape + n_bins = n_time // bin_size + trimmed = taste_spikes[..., :n_bins * bin_size] + binned = trimmed.reshape(n_trials, n_neurons_taste, n_bins, bin_size).sum(axis=-1) + + spike_data[f'taste_{taste_idx}'] = binned + print(f" Taste {taste_idx}: {binned.shape}") + + # Save extracted data + extracted_data = { + 'spike_data': spike_data, + 'region_names': list(data.region_names), + 'region_units': {name: list(units) for name, units in region_dict.items()}, + 'region_vec': region_vec.tolist(), + 'n_tastes': n_tastes, + 'n_neurons': n_neurons, + 'hdf5_path': data.hdf5_path, + } + + # Save as numpy file + output_file = os.path.join(temp_dir, 'extracted_data.npz') + np.savez(output_file, **{ + 'spike_data_keys': list(spike_data.keys()), + 'region_names': np.array(data.region_names, dtype=object), + 'region_vec': np.array(region_vec, dtype=object), + 'n_tastes': n_tastes, + 'n_neurons': n_neurons, + 'hdf5_path': data.hdf5_path, + }) + + # Save spike arrays separately (npz doesn't handle nested dicts well) + for key, arr in spike_data.items(): + np.save(os.path.join(temp_dir, f'{key}.npy'), arr) + + # Save region units as JSON + with open(os.path.join(temp_dir, 'region_units.json'), 'w') as f: + json.dump({name: list(map(int, units)) for name, units in region_dict.items()}, f) + + print(f"\nData saved to: {temp_dir}") + + # Write to pipeline log + this_pipeline_check.write_to_log(script_path, 'completed') + + +if __name__ == '__main__': + main() diff --git a/utils/glm/_glm_fit_models.py b/utils/glm/_glm_fit_models.py new file mode 100644 index 00000000..2a214d00 --- /dev/null +++ b/utils/glm/_glm_fit_models.py @@ -0,0 +1,475 @@ +""" +GLM fitting script using nemos. + +This script runs in the nemos virtual environment and fits GLM models to +spike data extracted by _glm_extract_data.py. + +This script is called by infer_glm_rates.py and should not be run directly. +""" + +import sys +import os +import json +import numpy as np +import pickle +import pandas as pd +import matplotlib +matplotlib.use('Agg') # Non-interactive backend +import matplotlib.pyplot as plt + +# Import nemos +try: + import nemos as nmo +except ImportError: + print("ERROR: nemos not found. Install with: pip install nemos") + sys.exit(1) + + +############################################################ +# GLM fitting functions +############################################################ + +def compute_bits_per_spike(model, X, y, baseline_rate=None): + """ + Compute bits per spike metric for model comparison. + + Bits per spike measures how much better the model predicts spikes + compared to a baseline (homogeneous Poisson) model. + """ + # Get model predictions + pred_rate = model.predict(X) + + # Compute log-likelihood under model + eps = 1e-10 + ll_model = np.sum(y * np.log(pred_rate + eps) - pred_rate) + + # Compute log-likelihood under baseline (homogeneous Poisson) + if baseline_rate is None: + baseline_rate = np.mean(y, axis=0, keepdims=True) + ll_baseline = np.sum(y * np.log(baseline_rate + eps) - baseline_rate) + + # Bits per spike = (LL_model - LL_baseline) / (n_spikes * log(2)) + n_spikes = np.sum(y) + if n_spikes > 0: + bits_per_spike = (ll_model - ll_baseline) / (n_spikes * np.log(2)) + else: + bits_per_spike = 0.0 + + return bits_per_spike + + +def fit_glm_single_neuron( + binned_spikes, + neuron_idx, + history_window_bins, + n_basis_funcs, + stim_bin=None, + other_neurons=None, + include_coupling=False, +): + """ + Fit GLM for a single neuron. + + Args: + binned_spikes: Binned spike counts (trials, neurons, time) + neuron_idx: Index of target neuron + history_window_bins: History window in bins + n_basis_funcs: Number of basis functions + stim_bin: Stimulus onset bin (optional) + other_neurons: Indices of other neurons for coupling (optional) + include_coupling: Whether to include coupling + + Returns: + model: Fitted GLM + X: Feature matrix + y: Target spike counts + feature_info: Dictionary with feature information + valid_mask: Boolean mask for valid samples + """ + n_trials, n_neurons, n_bins = binned_spikes.shape + + # Reshape to (samples, neurons) by concatenating trials + spikes_flat = binned_spikes.transpose(0, 2, 1).reshape(-1, n_neurons) + + # Target neuron spikes + y = spikes_flat[:, neuron_idx] + + # Build feature matrix + feature_list = [] + feature_info = {'names': [], 'slices': []} + current_idx = 0 + + # 1. Spike history basis for target neuron + history_basis = nmo.basis.RaisedCosineLogConv( + n_basis_funcs, + window_size=history_window_bins + ) + history_features = history_basis.compute_features( + spikes_flat[:, neuron_idx:neuron_idx+1] + ) + feature_list.append(history_features) + feature_info['names'].append('history') + feature_info['slices'].append(slice(current_idx, current_idx + n_basis_funcs)) + current_idx += n_basis_funcs + + # 2. Stimulus features (if provided) + if stim_bin is not None: + n_stim_basis = 5 + stim_window = min(20, history_window_bins) + + # Create stimulus indicator for each trial + stim_indicator = np.zeros((n_trials * n_bins,)) + for trial in range(n_trials): + trial_stim_idx = trial * n_bins + stim_bin + if 0 <= trial_stim_idx < len(stim_indicator): + stim_indicator[trial_stim_idx] = 1.0 + + stim_basis = nmo.basis.RaisedCosineLogConv(n_stim_basis, window_size=stim_window) + stim_features = stim_basis.compute_features(stim_indicator.reshape(-1, 1)) + feature_list.append(stim_features) + feature_info['names'].append('stimulus') + feature_info['slices'].append(slice(current_idx, current_idx + n_stim_basis)) + current_idx += n_stim_basis + + # 3. Coupling from other neurons (if requested) + if include_coupling and other_neurons is not None and len(other_neurons) > 0: + n_coupling_basis = max(2, n_basis_funcs // 2) + coupling_window = max(2, history_window_bins // 2) + coupling_basis = nmo.basis.RaisedCosineLogConv( + n_coupling_basis, + window_size=coupling_window + ) + for other_idx in other_neurons: + coupling_features = coupling_basis.compute_features( + spikes_flat[:, other_idx:other_idx+1] + ) + feature_list.append(coupling_features) + feature_info['names'].append(f'coupling_{other_idx}') + n_coupling_feats = coupling_features.shape[1] + feature_info['slices'].append(slice(current_idx, current_idx + n_coupling_feats)) + current_idx += n_coupling_feats + + # Concatenate all features + X = np.hstack(feature_list) + + # Handle NaN values from convolution edges + valid_mask = ~np.any(np.isnan(X), axis=1) + X_valid = X[valid_mask] + y_valid = y[valid_mask] + + # Fit GLM + model = nmo.glm.GLM( + regularizer=nmo.regularizer.Ridge(regularizer_strength=0.01) + ) + model.fit(X_valid, y_valid) + + return model, X_valid, y_valid, feature_info, valid_mask + + +def predict_firing_rates(model, X, valid_mask, original_shape): + """ + Generate firing rate predictions and reshape to original trial structure. + """ + n_trials, n_bins = original_shape + + # Predict on valid samples + pred_valid = model.predict(X) + + # Reconstruct full array with NaN for invalid samples + pred_full = np.full(n_trials * n_bins, np.nan) + pred_full[valid_mask] = pred_valid + + # Reshape to (trials, time) + pred_rates = pred_full.reshape(n_trials, n_bins) + + return pred_rates + + +############################################################ +# Main +############################################################ + +def main(): + if len(sys.argv) != 2: + print("Usage: _glm_fit_models.py ") + sys.exit(1) + + params_path = sys.argv[1] + + # Load parameters + with open(params_path, 'r') as f: + params = json.load(f) + + data_dir = params['data_dir'] + temp_dir = params['temp_dir'] + output_path = params['output_path'] + bin_size = params['bin_size'] + history_window = params['history_window'] + n_basis_funcs = params['n_basis_funcs'] + time_lims = params['time_lims'] + include_coupling = params['include_coupling'] + separate_tastes = params['separate_tastes'] + separate_regions = params['separate_regions'] + retrain = params['retrain'] + + history_window_bins = history_window // bin_size + + # Stimulus time (typically at 2000ms) + stim_time = 2000 + stim_bin = (stim_time - time_lims[0]) // bin_size + + # Setup directories + artifacts_dir = os.path.join(output_path, 'artifacts') + plots_dir = os.path.join(output_path, 'plots') + os.makedirs(artifacts_dir, exist_ok=True) + os.makedirs(plots_dir, exist_ok=True) + + # Load extracted data + print("Loading extracted spike data...") + extracted = np.load(os.path.join(temp_dir, 'extracted_data.npz'), allow_pickle=True) + + spike_data_keys = extracted['spike_data_keys'] + region_names = list(extracted['region_names']) + region_vec = list(extracted['region_vec']) + n_tastes = int(extracted['n_tastes']) + n_neurons = int(extracted['n_neurons']) + hdf5_path = str(extracted['hdf5_path']) + + # Load region units + with open(os.path.join(temp_dir, 'region_units.json'), 'r') as f: + region_dict = json.load(f) + + # Load spike arrays + spike_data = {} + for key in spike_data_keys: + spike_data[key] = np.load(os.path.join(temp_dir, f'{key}.npy')) + + print(f" Tastes: {n_tastes}") + print(f" Neurons: {n_neurons}") + print(f" Regions: {region_names}") + + # Save parameters + with open(os.path.join(artifacts_dir, 'params.json'), 'w') as f: + json.dump(params, f, indent=2) + + # Results storage + results = { + 'pred_firing': [], + 'binned_spikes': [], + 'bits_per_spike': [], + 'taste_idx': [], + 'region_name': [], + 'neuron_idx': [], + } + + # Process each taste + for taste_idx in range(n_tastes): + print(f"\n=== Processing taste {taste_idx} ===") + + binned = spike_data[f'taste_{taste_idx}'] + n_trials, n_neurons_taste, n_bins = binned.shape + + print(f" Trials: {n_trials}, Neurons: {n_neurons_taste}, Bins: {n_bins}") + + # Determine which neurons to process + if separate_regions: + regions_to_process = [(name, units) for name, units in region_dict.items() + if name.lower() != 'none'] + else: + regions_to_process = [('all', list(range(n_neurons_taste)))] + + for region_name, region_units in regions_to_process: + print(f" Region: {region_name} ({len(region_units)} neurons)") + + for neuron_idx in region_units: + # Determine coupling neurons + if include_coupling: + other_neurons = [i for i in region_units if i != neuron_idx] + else: + other_neurons = None + + # Model save path + model_name = f'taste_{taste_idx}_region_{region_name}_neuron_{neuron_idx}' + model_path = os.path.join(artifacts_dir, f'{model_name}.pkl') + + if os.path.exists(model_path) and not retrain: + print(f" Loading existing model for neuron {neuron_idx}") + with open(model_path, 'rb') as f: + saved = pickle.load(f) + model = saved['model'] + X_valid = saved['X'] + y_valid = saved['y'] + feature_info = saved['feature_info'] + valid_mask = saved['valid_mask'] + else: + print(f" Fitting GLM for neuron {neuron_idx}") + model, X_valid, y_valid, feature_info, valid_mask = fit_glm_single_neuron( + binned, + neuron_idx, + history_window_bins, + n_basis_funcs, + stim_bin=stim_bin, + other_neurons=other_neurons, + include_coupling=include_coupling, + ) + + # Save model + with open(model_path, 'wb') as f: + pickle.dump({ + 'model': model, + 'X': X_valid, + 'y': y_valid, + 'feature_info': feature_info, + 'valid_mask': valid_mask, + }, f) + + # Predict firing rates + pred_rates = predict_firing_rates( + model, X_valid, valid_mask, (n_trials, n_bins) + ) + + # Compute bits per spike + bps = compute_bits_per_spike(model, X_valid, y_valid) + + # Store results + results['pred_firing'].append(pred_rates) + results['binned_spikes'].append(binned[:, neuron_idx, :]) + results['bits_per_spike'].append(bps) + results['taste_idx'].append(taste_idx) + results['region_name'].append(region_name) + results['neuron_idx'].append(neuron_idx) + + ############################################################ + # Generate plots + ############################################################ + + print("\n=== Generating plots ===") + + # Bits per spike summary + bps_df = pd.DataFrame({ + 'taste': results['taste_idx'], + 'region': results['region_name'], + 'neuron': results['neuron_idx'], + 'bits_per_spike': results['bits_per_spike'], + }) + + # Save summary + bps_df.to_csv(os.path.join(output_path, 'bits_per_spike_summary.csv'), index=False) + + # Plot bits per spike distribution + fig, ax = plt.subplots(figsize=(10, 6)) + for region in bps_df['region'].unique(): + region_bps = bps_df[bps_df['region'] == region]['bits_per_spike'] + ax.hist(region_bps, alpha=0.5, label=region, bins=20) + ax.set_xlabel('Bits per Spike') + ax.set_ylabel('Count') + ax.set_title('GLM Model Performance: Bits per Spike') + ax.legend() + fig.savefig(os.path.join(plots_dir, 'bits_per_spike_distribution.png'), dpi=150) + plt.close(fig) + + # Plot mean bits per spike by taste and region + fig, ax = plt.subplots(figsize=(10, 6)) + bps_summary = bps_df.groupby(['taste', 'region'])['bits_per_spike'].mean().unstack() + bps_summary.plot(kind='bar', ax=ax) + ax.set_xlabel('Taste') + ax.set_ylabel('Mean Bits per Spike') + ax.set_title('GLM Performance by Taste and Region') + ax.legend(title='Region') + fig.savefig(os.path.join(plots_dir, 'bits_per_spike_by_taste_region.png'), dpi=150) + plt.close(fig) + + # Plot example neuron predictions + ind_plot_dir = os.path.join(plots_dir, 'individual_neurons') + os.makedirs(ind_plot_dir, exist_ok=True) + + print("Plotting individual neuron predictions...") + for i in range(min(10, len(results['pred_firing']))): + taste_idx = results['taste_idx'][i] + region = results['region_name'][i] + neuron_idx = results['neuron_idx'][i] + pred = results['pred_firing'][i] + binned = results['binned_spikes'][i] + bps = results['bits_per_spike'][i] + + fig, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True) + + # Raster of binned spikes + axes[0].imshow(binned, aspect='auto', cmap='Greys', interpolation='none') + axes[0].set_ylabel('Trial') + axes[0].set_title(f'Binned Spikes - Taste {taste_idx}, {region}, Neuron {neuron_idx}') + + # Predicted rates + axes[1].imshow(pred, aspect='auto', cmap='viridis', interpolation='none') + axes[1].set_ylabel('Trial') + axes[1].set_title(f'GLM Predicted Rates (bits/spike: {bps:.3f})') + + # Mean comparison + time_bins = np.arange(pred.shape[1]) * bin_size + time_lims[0] + axes[2].plot(time_bins, np.nanmean(binned, axis=0), 'k-', label='Binned', alpha=0.7) + axes[2].plot(time_bins, np.nanmean(pred, axis=0), 'r-', label='GLM Pred', alpha=0.7) + axes[2].axvline(stim_time, color='b', linestyle='--', label='Stimulus') + axes[2].set_xlabel('Time (ms)') + axes[2].set_ylabel('Firing Rate') + axes[2].legend() + axes[2].set_title('Mean Firing Rate Comparison') + + plt.tight_layout() + fig.savefig(os.path.join(ind_plot_dir, f'taste_{taste_idx}_{region}_neuron_{neuron_idx}.png'), dpi=150) + plt.close(fig) + + ############################################################ + # Write results to HDF5 + ############################################################ + + print("\n=== Writing results to HDF5 ===") + + # Import tables here to write HDF5 + import tables + + with tables.open_file(hdf5_path, 'r+') as hf5: + # Remove existing GLM output if present + if '/glm_output' in hf5: + hf5.remove_node('/glm_output', recursive=True) + + # Create GLM output group + hf5.create_group('/', 'glm_output', 'GLM-based firing rate estimates') + glm_grp = hf5.get_node('/glm_output') + + # Store parameters + hf5.create_array(glm_grp, 'bin_size', np.array([bin_size])) + hf5.create_array(glm_grp, 'history_window', np.array([history_window])) + hf5.create_array(glm_grp, 'time_lims', np.array(time_lims)) + + # Create regions group + hf5.create_group('/glm_output', 'regions', 'Region-specific GLM output') + regions_grp = hf5.get_node('/glm_output/regions') + + # Organize results by taste and region + for i in range(len(results['pred_firing'])): + taste_idx = results['taste_idx'][i] + region = results['region_name'][i] + neuron_idx = results['neuron_idx'][i] + + group_name = f'taste_{taste_idx}_region_{region}_neuron_{neuron_idx}' + + neuron_grp = hf5.create_group(regions_grp, group_name, + f'Taste {taste_idx}, Region {region}, Neuron {neuron_idx}') + + hf5.create_array(neuron_grp, 'pred_firing', results['pred_firing'][i]) + hf5.create_array(neuron_grp, 'binned_spikes', results['binned_spikes'][i]) + hf5.create_array(neuron_grp, 'bits_per_spike', np.array([results['bits_per_spike'][i]])) + + print(f"\nResults saved to {hdf5_path}") + print(f"Plots saved to {plots_dir}") + print(f"Summary saved to {os.path.join(output_path, 'bits_per_spike_summary.csv')}") + + # Print summary statistics + print("\n=== Summary ===") + print(f"Total neurons processed: {len(results['pred_firing'])}") + print(f"Mean bits per spike: {np.mean(results['bits_per_spike']):.4f}") + print(f"Std bits per spike: {np.std(results['bits_per_spike']):.4f}") + + +if __name__ == '__main__': + main() diff --git a/utils/glm/infer_glm_rates.py b/utils/glm/infer_glm_rates.py new file mode 100644 index 00000000..aa19c240 --- /dev/null +++ b/utils/glm/infer_glm_rates.py @@ -0,0 +1,232 @@ +""" +GLM-based firing rate estimation using nemos. + +This script orchestrates GLM fitting across two Python environments: +1. blech_clust conda environment: Extracts spike data using ephys_data +2. nemos virtual environment: Fits GLM models + +The script automatically detects environments and saves intermediate data files +for cross-environment communication. + +Usage: + python infer_glm_rates.py [options] + +Options: + --bin_size: Bin size in ms for spike binning (default: 25) + --history_window: History window in ms for autoregressive effects (default: 250) + --n_basis_funcs: Number of basis functions for history filter (default: 8) + --time_lims: Time limits for analysis [start, end] in ms (default: [1500, 4500]) + --include_coupling: Include coupling between neurons (default: False) + --separate_tastes: Fit separate models for each taste (default: False) + --separate_regions: Fit separate models for each region (default: False) + --retrain: Force retraining even if model exists (default: False) + --blech_clust_env: Name of blech_clust conda environment (default: blech_clust) + --nemos_env: Path to nemos venv Python interpreter (default: searches common locations) +""" + +import argparse +import os +import sys +import subprocess +import json +import shutil + +############################################################ +# Argument parsing +############################################################ + +def parse_args(): + parser = argparse.ArgumentParser( + description='Infer firing rates using GLM (nemos)') + parser.add_argument('data_dir', help='Path to data directory') + parser.add_argument('--bin_size', type=int, default=25, + help='Bin size in ms for spike binning (default: %(default)s)') + parser.add_argument('--history_window', type=int, default=250, + help='History window in ms for autoregressive effects (default: %(default)s)') + parser.add_argument('--n_basis_funcs', type=int, default=8, + help='Number of basis functions for history filter (default: %(default)s)') + parser.add_argument('--time_lims', type=int, nargs=2, default=[1500, 4500], + help='Time limits for analysis [start, end] in ms (default: %(default)s)') + parser.add_argument('--include_coupling', action='store_true', + help='Include coupling between neurons (default: %(default)s)') + parser.add_argument('--separate_tastes', action='store_true', + help='Fit separate models for each taste (default: %(default)s)') + parser.add_argument('--separate_regions', action='store_true', + help='Fit separate models for each region (default: %(default)s)') + parser.add_argument('--retrain', action='store_true', + help='Force retraining of model (default: %(default)s)') + parser.add_argument('--blech_clust_env', type=str, default='blech_clust', + help='Name of blech_clust conda environment (default: %(default)s)') + parser.add_argument('--nemos_env', type=str, default=None, + help='Path to nemos venv Python interpreter') + return parser.parse_args() + + +############################################################ +# Environment detection +############################################################ + +def find_conda_env(env_name): + """Find conda environment by name and return its Python path.""" + try: + result = subprocess.run( + ['conda', 'run', '-n', env_name, 'which', 'python'], + capture_output=True, text=True, timeout=30 + ) + if result.returncode == 0: + return result.stdout.strip() + except (subprocess.TimeoutExpired, FileNotFoundError): + pass + + # Try common conda locations + home = os.path.expanduser('~') + common_paths = [ + f'{home}/anaconda3/envs/{env_name}/bin/python', + f'{home}/miniconda3/envs/{env_name}/bin/python', + f'{home}/miniforge3/envs/{env_name}/bin/python', + f'/opt/conda/envs/{env_name}/bin/python', + ] + for path in common_paths: + if os.path.exists(path): + return path + return None + + +def find_nemos_env(): + """Find nemos virtual environment Python interpreter.""" + home = os.path.expanduser('~') + common_paths = [ + f'{home}/nemos_env/bin/python', + f'{home}/venvs/nemos/bin/python', + f'{home}/.venvs/nemos/bin/python', + f'{home}/envs/nemos/bin/python', + f'{home}/.local/share/virtualenvs/nemos/bin/python', + ] + for path in common_paths: + if os.path.exists(path): + # Verify nemos is installed + try: + result = subprocess.run( + [path, '-c', 'import nemos'], + capture_output=True, timeout=10 + ) + if result.returncode == 0: + return path + except subprocess.TimeoutExpired: + pass + return None + + +def prompt_for_env(env_type, default_name=None): + """Prompt user for environment path.""" + print(f"\n{env_type} environment not found automatically.") + if default_name: + print(f"Expected conda environment name: {default_name}") + + response = input(f"Enter path to {env_type} Python interpreter (or 'q' to quit): ").strip() + if response.lower() == 'q': + sys.exit(1) + if os.path.exists(response): + return response + else: + print(f"Path not found: {response}") + return prompt_for_env(env_type, default_name) + + +############################################################ +# Main orchestration +############################################################ + +def main(): + args = parse_args() + + script_dir = os.path.dirname(os.path.abspath(__file__)) + data_dir = os.path.abspath(args.data_dir) + + # Setup output directories + output_path = os.path.join(data_dir, 'glm_output') + # Use temp dir within the glm module directory + temp_dir = os.path.join(script_dir, '_temp') + os.makedirs(temp_dir, exist_ok=True) + os.makedirs(output_path, exist_ok=True) + + print("=" * 60) + print("GLM-based Firing Rate Estimation") + print("=" * 60) + print(f"Data directory: {data_dir}") + + # Find blech_clust environment + print("\n[1/4] Locating blech_clust environment...") + blech_python = find_conda_env(args.blech_clust_env) + if blech_python is None: + blech_python = prompt_for_env('blech_clust', args.blech_clust_env) + print(f" Found: {blech_python}") + + # Find nemos environment + print("\n[2/4] Locating nemos environment...") + if args.nemos_env: + nemos_python = args.nemos_env + else: + nemos_python = find_nemos_env() + if nemos_python is None: + nemos_python = prompt_for_env('nemos') + print(f" Found: {nemos_python}") + + # Save parameters for sub-scripts + params = { + 'data_dir': data_dir, + 'bin_size': args.bin_size, + 'history_window': args.history_window, + 'n_basis_funcs': args.n_basis_funcs, + 'time_lims': args.time_lims, + 'include_coupling': args.include_coupling, + 'separate_tastes': args.separate_tastes, + 'separate_regions': args.separate_regions, + 'retrain': args.retrain, + 'temp_dir': temp_dir, + 'output_path': output_path, + } + params_path = os.path.join(temp_dir, 'glm_params.json') + with open(params_path, 'w') as f: + json.dump(params, f, indent=2) + + # Step 1: Extract data using blech_clust environment + print("\n[3/4] Extracting spike data (blech_clust environment)...") + extract_script = os.path.join(script_dir, '_glm_extract_data.py') + + result = subprocess.run( + [blech_python, extract_script, params_path], + capture_output=True, text=True + ) + if result.returncode != 0: + print("ERROR: Data extraction failed") + print(result.stderr) + sys.exit(1) + print(result.stdout) + + # Step 2: Fit GLM using nemos environment + print("\n[4/4] Fitting GLM models (nemos environment)...") + fit_script = os.path.join(script_dir, '_glm_fit_models.py') + + result = subprocess.run( + [nemos_python, fit_script, params_path], + capture_output=True, text=True + ) + if result.returncode != 0: + print("ERROR: GLM fitting failed") + print(result.stderr) + sys.exit(1) + print(result.stdout) + + # Cleanup temp files + print("\nCleaning up temporary files...") + shutil.rmtree(temp_dir) + + print("\n" + "=" * 60) + print("GLM fitting complete!") + print(f"Results saved to: {output_path}") + print("=" * 60) + + +if __name__ == '__main__': + main() diff --git a/utils/infer_glm_rates.py b/utils/infer_glm_rates.py deleted file mode 100644 index 61a0a23d..00000000 --- a/utils/infer_glm_rates.py +++ /dev/null @@ -1,629 +0,0 @@ -""" -GLM-based firing rate estimation using nemos. - -This module fits Generalized Linear Models (GLMs) to electrophysiological spike data -to infer firing rates. The GLM can incorporate: -- Spike history (autoregressive effects) -- Stimulus timing -- Coupling between simultaneously recorded neurons - -Model comparison is provided via bits per spike metric against binned rates. - -Usage: - python infer_glm_rates.py [options] - -Options: - --bin_size: Bin size in ms for spike binning (default: 25) - --history_window: History window in ms for autoregressive effects (default: 250) - --n_basis_funcs: Number of basis functions for history filter (default: 8) - --time_lims: Time limits for analysis [start, end] in ms (default: [1500, 4500]) - --include_coupling: Include coupling between neurons (default: False) - --separate_tastes: Fit separate models for each taste (default: False) - --separate_regions: Fit separate models for each region (default: False) - --retrain: Force retraining even if model exists (default: False) -""" - -import argparse -import os - -test_mode = False -if test_mode: - print('====================') - print('Running in test mode') - print('====================') - data_dir = '/home/abuzarmahmood/projects/blech_clust/pipeline_testing/test_data_handling/test_data/KM45_5tastes_210620_113227_new' - script_path = '/home/abuzarmahmood/projects/blech_clust/utils/infer_glm_rates.py' - blech_clust_path = os.path.dirname(os.path.dirname(script_path)) - args = argparse.Namespace( - data_dir=data_dir, - bin_size=25, - history_window=250, - n_basis_funcs=8, - time_lims=[1500, 4500], - include_coupling=False, - separate_tastes=True, - separate_regions=True, - retrain=False, - ) -else: - parser = argparse.ArgumentParser( - description='Infer firing rates using GLM (nemos)') - parser.add_argument('data_dir', help='Path to data directory') - parser.add_argument('--bin_size', type=int, default=25, - help='Bin size in ms for spike binning (default: %(default)s)') - parser.add_argument('--history_window', type=int, default=250, - help='History window in ms for autoregressive effects (default: %(default)s)') - parser.add_argument('--n_basis_funcs', type=int, default=8, - help='Number of basis functions for history filter (default: %(default)s)') - parser.add_argument('--time_lims', type=int, nargs=2, default=[1500, 4500], - help='Time limits for analysis [start, end] in ms (default: %(default)s)') - parser.add_argument('--include_coupling', action='store_true', - help='Include coupling between neurons (default: %(default)s)') - parser.add_argument('--separate_tastes', action='store_true', - help='Fit separate models for each taste (default: %(default)s)') - parser.add_argument('--separate_regions', action='store_true', - help='Fit separate models for each region (default: %(default)s)') - parser.add_argument('--retrain', action='store_true', - help='Force retraining of model (default: %(default)s)') - - args = parser.parse_args() - data_dir = args.data_dir - script_path = os.path.abspath(__file__) - blech_clust_path = os.path.dirname(os.path.dirname(script_path)) - -############################################################ -############################################################ - -import tables -import matplotlib.pyplot as plt -import numpy as np -import sys -from pprint import pprint -import json -from itertools import product -import pandas as pd -import pickle - -sys.path.append(blech_clust_path) - -from blech_clust.utils.blech_utils import imp_metadata, pipeline_graph_check -from blech_clust.utils.ephys_data import visualize as vz -from blech_clust.utils.ephys_data import ephys_data - -# Import nemos for GLM fitting -try: - import nemos as nmo -except ImportError: - raise ImportError( - 'nemos is required for GLM fitting. Install with: pip install nemos' - ) - -############################################################ -############################################################ - - -def compute_bits_per_spike(model, X, y, baseline_rate=None): - """ - Compute bits per spike metric for model comparison. - - Bits per spike measures how much better the model predicts spikes - compared to a baseline (homogeneous Poisson) model. - - Args: - model: Fitted nemos GLM model - X: Feature matrix - y: Spike counts - baseline_rate: Baseline firing rate (if None, uses mean rate) - - Returns: - bits_per_spike: Information gain in bits per spike - """ - # Get model predictions - pred_rate = model.predict(X) - - # Compute log-likelihood under model - # For Poisson: LL = y * log(rate) - rate - log(y!) - # We ignore the log(y!) term as it cancels in comparison - eps = 1e-10 - ll_model = np.sum(y * np.log(pred_rate + eps) - pred_rate) - - # Compute log-likelihood under baseline (homogeneous Poisson) - if baseline_rate is None: - baseline_rate = np.mean(y, axis=0, keepdims=True) - ll_baseline = np.sum(y * np.log(baseline_rate + eps) - baseline_rate) - - # Bits per spike = (LL_model - LL_baseline) / (n_spikes * log(2)) - n_spikes = np.sum(y) - if n_spikes > 0: - bits_per_spike = (ll_model - ll_baseline) / (n_spikes * np.log(2)) - else: - bits_per_spike = 0.0 - - return bits_per_spike - - -def bin_spikes(spike_data, bin_size): - """ - Bin spike data. - - Args: - spike_data: Array of shape (trials, neurons, time) - bin_size: Bin size in samples - - Returns: - binned: Array of shape (trials, neurons, n_bins) - """ - n_trials, n_neurons, n_time = spike_data.shape - n_bins = n_time // bin_size - trimmed = spike_data[..., :n_bins * bin_size] - binned = trimmed.reshape(n_trials, n_neurons, n_bins, bin_size).sum(axis=-1) - return binned - - -def create_stimulus_feature(n_samples, stim_bin, n_basis=5, window_size=10): - """ - Create stimulus feature using raised cosine basis. - - Args: - n_samples: Number of time samples - stim_bin: Bin index of stimulus onset - n_basis: Number of basis functions - window_size: Window size for basis - - Returns: - stim_features: Stimulus feature matrix - """ - # Create stimulus indicator - stim = np.zeros(n_samples) - if 0 <= stim_bin < n_samples: - stim[stim_bin] = 1.0 - - # Convolve with raised cosine basis - basis = nmo.basis.RaisedCosineLogConv(n_basis, window_size=window_size) - stim_features = basis.compute_features(stim.reshape(-1, 1)) - - return stim_features - - -def fit_glm_single_neuron( - binned_spikes, - neuron_idx, - history_window_bins, - n_basis_funcs, - stim_bin=None, - other_neurons=None, - include_coupling=False, -): - """ - Fit GLM for a single neuron. - - Args: - binned_spikes: Binned spike counts (trials, neurons, time) - neuron_idx: Index of target neuron - history_window_bins: History window in bins - n_basis_funcs: Number of basis functions - stim_bin: Stimulus onset bin (optional) - other_neurons: Indices of other neurons for coupling (optional) - include_coupling: Whether to include coupling - - Returns: - model: Fitted GLM - X: Feature matrix - y: Target spike counts - feature_info: Dictionary with feature information - """ - n_trials, n_neurons, n_bins = binned_spikes.shape - - # Reshape to (samples, neurons) by concatenating trials - spikes_flat = binned_spikes.transpose(0, 2, 1).reshape(-1, n_neurons) - - # Target neuron spikes - y = spikes_flat[:, neuron_idx] - - # Build feature matrix - feature_list = [] - feature_info = {'names': [], 'slices': []} - current_idx = 0 - - # 1. Spike history basis for target neuron - history_basis = nmo.basis.RaisedCosineLogConv( - n_basis_funcs, - window_size=history_window_bins - ) - history_features = history_basis.compute_features( - spikes_flat[:, neuron_idx:neuron_idx+1] - ) - feature_list.append(history_features) - feature_info['names'].append('history') - feature_info['slices'].append(slice(current_idx, current_idx + n_basis_funcs)) - current_idx += n_basis_funcs - - # 2. Stimulus features (if provided) - if stim_bin is not None: - n_stim_basis = 5 - stim_window = min(20, history_window_bins) - - # Create stimulus indicator for each trial - stim_indicator = np.zeros((n_trials * n_bins,)) - for trial in range(n_trials): - trial_stim_idx = trial * n_bins + stim_bin - if 0 <= trial_stim_idx < len(stim_indicator): - stim_indicator[trial_stim_idx] = 1.0 - - stim_basis = nmo.basis.RaisedCosineLogConv(n_stim_basis, window_size=stim_window) - stim_features = stim_basis.compute_features(stim_indicator.reshape(-1, 1)) - feature_list.append(stim_features) - feature_info['names'].append('stimulus') - feature_info['slices'].append(slice(current_idx, current_idx + n_stim_basis)) - current_idx += n_stim_basis - - # 3. Coupling from other neurons (if requested) - if include_coupling and other_neurons is not None and len(other_neurons) > 0: - coupling_basis = nmo.basis.RaisedCosineLogConv( - n_basis_funcs // 2, # Fewer basis functions for coupling - window_size=history_window_bins // 2 - ) - for other_idx in other_neurons: - coupling_features = coupling_basis.compute_features( - spikes_flat[:, other_idx:other_idx+1] - ) - feature_list.append(coupling_features) - feature_info['names'].append(f'coupling_{other_idx}') - n_coupling_feats = coupling_features.shape[1] - feature_info['slices'].append(slice(current_idx, current_idx + n_coupling_feats)) - current_idx += n_coupling_feats - - # Concatenate all features - X = np.hstack(feature_list) - - # Handle NaN values from convolution edges - valid_mask = ~np.any(np.isnan(X), axis=1) - X_valid = X[valid_mask] - y_valid = y[valid_mask] - - # Fit GLM - model = nmo.glm.GLM( - regularizer=nmo.regularizer.Ridge(regularizer_strength=0.01) - ) - model.fit(X_valid, y_valid) - - return model, X_valid, y_valid, feature_info, valid_mask - - -def predict_firing_rates(model, X, valid_mask, original_shape): - """ - Generate firing rate predictions and reshape to original trial structure. - - Args: - model: Fitted GLM - X: Feature matrix (valid samples only) - valid_mask: Boolean mask for valid samples - original_shape: Original shape (n_trials, n_bins) - - Returns: - pred_rates: Predicted firing rates (n_trials, n_bins) - """ - n_trials, n_bins = original_shape - - # Predict on valid samples - pred_valid = model.predict(X) - - # Reconstruct full array with NaN for invalid samples - pred_full = np.full(n_trials * n_bins, np.nan) - pred_full[valid_mask] = pred_valid - - # Reshape to (trials, time) - pred_rates = pred_full.reshape(n_trials, n_bins) - - return pred_rates - - -############################################################ -############################################################ - -if not test_mode: - metadata_handler = imp_metadata([[], args.data_dir]) - this_pipeline_check = pipeline_graph_check(args.data_dir) - this_pipeline_check.check_previous(script_path) - this_pipeline_check.write_to_log(script_path, 'attempted') - -# Setup output directories -output_path = os.path.join(data_dir, 'glm_output') -artifacts_dir = os.path.join(output_path, 'artifacts') -plots_dir = os.path.join(output_path, 'plots') - -for dir_path in [output_path, artifacts_dir, plots_dir]: - if not os.path.exists(dir_path): - os.makedirs(dir_path) - -print(f'Processing data from {data_dir}') -print(f'Parameters:') -print(f' bin_size: {args.bin_size} ms') -print(f' history_window: {args.history_window} ms') -print(f' n_basis_funcs: {args.n_basis_funcs}') -print(f' time_lims: {args.time_lims}') -print(f' include_coupling: {args.include_coupling}') -print(f' separate_tastes: {args.separate_tastes}') -print(f' separate_regions: {args.separate_regions}') - -# Save parameters -params_dict = { - 'bin_size': args.bin_size, - 'history_window': args.history_window, - 'n_basis_funcs': args.n_basis_funcs, - 'time_lims': args.time_lims, - 'include_coupling': args.include_coupling, - 'separate_tastes': args.separate_tastes, - 'separate_regions': args.separate_regions, -} -with open(os.path.join(artifacts_dir, 'params.json'), 'w') as f: - json.dump(params_dict, f, indent=4) - -############################################################ -# Load data -############################################################ - -basename = os.path.basename(data_dir) -data = ephys_data.ephys_data(data_dir) -data.get_spikes() -data.get_region_units() - -# Build region mapping -region_dict = dict(zip(data.region_names, data.region_units)) -region_vec = np.zeros(len(np.concatenate(data.region_units)), dtype=object) -for region_name, unit_list in region_dict.items(): - region_vec[unit_list] = region_name - -n_tastes = len(data.spikes) -n_neurons = data.spikes[0].shape[1] - -print(f'Loaded {n_tastes} tastes, {n_neurons} neurons') -print(f'Regions: {data.region_names}') - -############################################################ -# Process data -############################################################ - -bin_size = args.bin_size -history_window = args.history_window -history_window_bins = history_window // bin_size -n_basis_funcs = args.n_basis_funcs -time_lims = args.time_lims - -# Stimulus time (typically at 2000ms) -stim_time = 2000 -stim_bin = (stim_time - time_lims[0]) // bin_size - -# Determine processing groups -group_by_list = [] -if args.separate_tastes: - group_by_list.append('taste') -if args.separate_regions: - group_by_list.append('region') - -# Results storage -results = { - 'pred_firing': [], - 'binned_spikes': [], - 'bits_per_spike': [], - 'models': [], - 'taste_idx': [], - 'region_name': [], - 'neuron_idx': [], -} - -# Process each taste -for taste_idx, taste_spikes in enumerate(data.spikes): - print(f'\n=== Processing taste {taste_idx} ===') - - # Cut to time limits - taste_spikes = taste_spikes[..., time_lims[0]:time_lims[1]] - - # Bin spikes - binned = bin_spikes(taste_spikes, bin_size) - n_trials, n_neurons_taste, n_bins = binned.shape - - print(f' Trials: {n_trials}, Neurons: {n_neurons_taste}, Bins: {n_bins}') - - # Determine which neurons to process - if args.separate_regions: - regions_to_process = [(name, units) for name, units in region_dict.items() - if name.lower() != 'none'] - else: - regions_to_process = [('all', list(range(n_neurons_taste)))] - - for region_name, region_units in regions_to_process: - print(f' Region: {region_name} ({len(region_units)} neurons)') - - for neuron_idx in region_units: - # Determine coupling neurons - if args.include_coupling: - other_neurons = [i for i in region_units if i != neuron_idx] - else: - other_neurons = None - - # Model save path - model_name = f'taste_{taste_idx}_region_{region_name}_neuron_{neuron_idx}' - model_path = os.path.join(artifacts_dir, f'{model_name}.pkl') - - if os.path.exists(model_path) and not args.retrain: - print(f' Loading existing model for neuron {neuron_idx}') - with open(model_path, 'rb') as f: - saved = pickle.load(f) - model = saved['model'] - X_valid = saved['X'] - y_valid = saved['y'] - feature_info = saved['feature_info'] - valid_mask = saved['valid_mask'] - else: - print(f' Fitting GLM for neuron {neuron_idx}') - model, X_valid, y_valid, feature_info, valid_mask = fit_glm_single_neuron( - binned, - neuron_idx, - history_window_bins, - n_basis_funcs, - stim_bin=stim_bin, - other_neurons=other_neurons, - include_coupling=args.include_coupling, - ) - - # Save model - with open(model_path, 'wb') as f: - pickle.dump({ - 'model': model, - 'X': X_valid, - 'y': y_valid, - 'feature_info': feature_info, - 'valid_mask': valid_mask, - }, f) - - # Predict firing rates - pred_rates = predict_firing_rates( - model, X_valid, valid_mask, (n_trials, n_bins) - ) - - # Compute bits per spike - bps = compute_bits_per_spike(model, X_valid, y_valid) - - # Store results - results['pred_firing'].append(pred_rates) - results['binned_spikes'].append(binned[:, neuron_idx, :]) - results['bits_per_spike'].append(bps) - results['models'].append(model) - results['taste_idx'].append(taste_idx) - results['region_name'].append(region_name) - results['neuron_idx'].append(neuron_idx) - -############################################################ -# Generate plots -############################################################ - -print('\n=== Generating plots ===') - -# Bits per spike summary -bps_df = pd.DataFrame({ - 'taste': results['taste_idx'], - 'region': results['region_name'], - 'neuron': results['neuron_idx'], - 'bits_per_spike': results['bits_per_spike'], -}) - -# Save summary -bps_df.to_csv(os.path.join(output_path, 'bits_per_spike_summary.csv'), index=False) - -# Plot bits per spike distribution -fig, ax = plt.subplots(figsize=(10, 6)) -for region in bps_df['region'].unique(): - region_bps = bps_df[bps_df['region'] == region]['bits_per_spike'] - ax.hist(region_bps, alpha=0.5, label=region, bins=20) -ax.set_xlabel('Bits per Spike') -ax.set_ylabel('Count') -ax.set_title('GLM Model Performance: Bits per Spike') -ax.legend() -fig.savefig(os.path.join(plots_dir, 'bits_per_spike_distribution.png'), dpi=150) -plt.close(fig) - -# Plot mean bits per spike by taste and region -fig, ax = plt.subplots(figsize=(10, 6)) -bps_summary = bps_df.groupby(['taste', 'region'])['bits_per_spike'].mean().unstack() -bps_summary.plot(kind='bar', ax=ax) -ax.set_xlabel('Taste') -ax.set_ylabel('Mean Bits per Spike') -ax.set_title('GLM Performance by Taste and Region') -ax.legend(title='Region') -fig.savefig(os.path.join(plots_dir, 'bits_per_spike_by_taste_region.png'), dpi=150) -plt.close(fig) - -# Plot example neuron predictions -ind_plot_dir = os.path.join(plots_dir, 'individual_neurons') -if not os.path.exists(ind_plot_dir): - os.makedirs(ind_plot_dir) - -print('Plotting individual neuron predictions...') -for i in range(min(10, len(results['pred_firing']))): # Plot first 10 neurons - taste_idx = results['taste_idx'][i] - region = results['region_name'][i] - neuron_idx = results['neuron_idx'][i] - pred = results['pred_firing'][i] - binned = results['binned_spikes'][i] - bps = results['bits_per_spike'][i] - - fig, axes = plt.subplots(3, 1, figsize=(12, 8), sharex=True) - - # Raster of binned spikes - axes[0].imshow(binned, aspect='auto', cmap='Greys', interpolation='none') - axes[0].set_ylabel('Trial') - axes[0].set_title(f'Binned Spikes - Taste {taste_idx}, {region}, Neuron {neuron_idx}') - - # Predicted rates - axes[1].imshow(pred, aspect='auto', cmap='viridis', interpolation='none') - axes[1].set_ylabel('Trial') - axes[1].set_title(f'GLM Predicted Rates (bits/spike: {bps:.3f})') - - # Mean comparison - time_bins = np.arange(pred.shape[1]) * bin_size + time_lims[0] - axes[2].plot(time_bins, np.nanmean(binned, axis=0), 'k-', label='Binned', alpha=0.7) - axes[2].plot(time_bins, np.nanmean(pred, axis=0), 'r-', label='GLM Pred', alpha=0.7) - axes[2].axvline(stim_time, color='b', linestyle='--', label='Stimulus') - axes[2].set_xlabel('Time (ms)') - axes[2].set_ylabel('Firing Rate') - axes[2].legend() - axes[2].set_title('Mean Firing Rate Comparison') - - plt.tight_layout() - fig.savefig(os.path.join(ind_plot_dir, f'taste_{taste_idx}_{region}_neuron_{neuron_idx}.png'), dpi=150) - plt.close(fig) - -############################################################ -# Write results to HDF5 -############################################################ - -print('\n=== Writing results to HDF5 ===') - -hdf5_path = data.hdf5_path -with tables.open_file(hdf5_path, 'r+') as hf5: - # Remove existing GLM output if present - if '/glm_output' in hf5: - hf5.remove_node('/glm_output', recursive=True) - - # Create GLM output group - hf5.create_group('/', 'glm_output', 'GLM-based firing rate estimates') - glm_grp = hf5.get_node('/glm_output') - - # Store parameters - hf5.create_array(glm_grp, 'bin_size', np.array([bin_size])) - hf5.create_array(glm_grp, 'history_window', np.array([history_window])) - hf5.create_array(glm_grp, 'time_lims', np.array(time_lims)) - - # Create regions group - hf5.create_group('/glm_output', 'regions', 'Region-specific GLM output') - regions_grp = hf5.get_node('/glm_output/regions') - - # Organize results by taste and region - for i in range(len(results['pred_firing'])): - taste_idx = results['taste_idx'][i] - region = results['region_name'][i] - neuron_idx = results['neuron_idx'][i] - - group_name = f'taste_{taste_idx}_region_{region}_neuron_{neuron_idx}' - - neuron_grp = hf5.create_group(regions_grp, group_name, - f'Taste {taste_idx}, Region {region}, Neuron {neuron_idx}') - - hf5.create_array(neuron_grp, 'pred_firing', results['pred_firing'][i]) - hf5.create_array(neuron_grp, 'binned_spikes', results['binned_spikes'][i]) - hf5.create_array(neuron_grp, 'bits_per_spike', np.array([results['bits_per_spike'][i]])) - -print(f'\nResults saved to {hdf5_path}') -print(f'Plots saved to {plots_dir}') -print(f'Summary saved to {os.path.join(output_path, "bits_per_spike_summary.csv")}') - -# Print summary statistics -print('\n=== Summary ===') -print(f'Total neurons processed: {len(results["pred_firing"])}') -print(f'Mean bits per spike: {np.mean(results["bits_per_spike"]):.4f}') -print(f'Std bits per spike: {np.std(results["bits_per_spike"]):.4f}') - -# Write successful execution to log -if not test_mode: - this_pipeline_check.write_to_log(script_path, 'completed') From d2b18200c40c73f709af711a9b59aeb6ebf6a006 Mon Sep 17 00:00:00 2001 From: Abuzar Mahmood Date: Mon, 19 Jan 2026 17:22:54 +0000 Subject: [PATCH 3/5] Add README for GLM module explaining architecture Co-authored-by: Ona --- utils/glm/README.md | 116 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 utils/glm/README.md diff --git a/utils/glm/README.md b/utils/glm/README.md new file mode 100644 index 00000000..15fec944 --- /dev/null +++ b/utils/glm/README.md @@ -0,0 +1,116 @@ +# GLM-based Firing Rate Estimation + +This module fits Generalized Linear Models (GLMs) to electrophysiological spike data using the [nemos](https://nemos.readthedocs.io/) library. + +## Why Separate Scripts? + +The nemos library has dependency conflicts with blech_clust's conda environment (particularly around numpy/jax versions). To work around this, the GLM fitting runs in a **separate virtual environment** dedicated to nemos. + +This module uses a two-environment architecture: + +``` +┌─────────────────────────────────────────────────────────────┐ +│ infer_glm_rates.py │ +│ (Main Orchestrator) │ +│ Can run from any Python environment │ +└─────────────────────┬───────────────────────────────────────┘ + │ + ┌───────────┴───────────┐ + ▼ ▼ +┌─────────────────────┐ ┌─────────────────────┐ +│ _glm_extract_data.py│ │ _glm_fit_models.py │ +│ │ │ │ +│ Runs in blech_clust │ │ Runs in nemos │ +│ conda environment │ │ virtual environment│ +│ │ │ │ +│ - Loads spike data │ │ - Fits GLM models │ +│ - Uses ephys_data │ │ - Uses nemos │ +│ - Saves to temp/ │ │ - Reads from temp/ │ +└─────────────────────┘ └─────────────────────┘ + │ │ + └───────────┬───────────┘ + ▼ + ┌───────────────┐ + │ _temp/ │ + │ (intermediate │ + │ files) │ + └───────────────┘ +``` + +## Files + +| File | Purpose | Environment | +|------|---------|-------------| +| `infer_glm_rates.py` | Main entry point. Detects environments, orchestrates data flow | Any | +| `_glm_extract_data.py` | Extracts spike data using ephys_data | blech_clust conda | +| `_glm_fit_models.py` | Fits GLM models, generates plots, writes HDF5 | nemos venv | +| `_temp/` | Temporary directory for intermediate numpy files (gitignored) | N/A | + +## Setup + +1. **blech_clust environment**: Should already exist if you're using blech_clust. + +2. **nemos environment**: Create a separate virtual environment: + ```bash + python -m venv ~/nemos_env + source ~/nemos_env/bin/activate + pip install nemos matplotlib pandas tables + ``` + +## Usage + +```bash +python utils/glm/infer_glm_rates.py [options] +``` + +The script will: +1. Auto-detect the blech_clust conda environment +2. Auto-detect or prompt for the nemos virtual environment +3. Extract data using blech_clust's ephys_data +4. Fit GLM models using nemos +5. Save results to HDF5 and generate plots + +### Options + +| Option | Default | Description | +|--------|---------|-------------| +| `--bin_size` | 25 | Bin size in ms | +| `--history_window` | 250 | History window in ms for autoregressive effects | +| `--n_basis_funcs` | 8 | Number of basis functions for history filter | +| `--time_lims` | [1500, 4500] | Time limits for analysis [start, end] in ms | +| `--include_coupling` | False | Include coupling between neurons | +| `--separate_tastes` | False | Fit separate models for each taste | +| `--separate_regions` | False | Fit separate models for each region | +| `--retrain` | False | Force retraining even if model exists | +| `--blech_clust_env` | blech_clust | Name of blech_clust conda environment | +| `--nemos_env` | (auto) | Path to nemos venv Python interpreter | + +### Example + +```bash +# Basic usage +python utils/glm/infer_glm_rates.py /path/to/data + +# With options +python utils/glm/infer_glm_rates.py /path/to/data \ + --separate_tastes \ + --separate_regions \ + --include_coupling \ + --nemos_env ~/nemos_env/bin/python +``` + +## Output + +- **HDF5**: Results saved under `/glm_output/regions/` in the data's HDF5 file +- **CSV**: `glm_output/bits_per_spike_summary.csv` with model performance metrics +- **Plots**: `glm_output/plots/` with distribution and individual neuron plots +- **Models**: `glm_output/artifacts/` with pickled GLM models + +## Model Details + +The GLM includes: +- **Spike history**: Raised cosine log basis functions capture autoregressive effects +- **Stimulus features** (optional): Convolved indicator for stimulus onset +- **Neural coupling** (optional): Cross-neuron dependencies within region + +Performance is measured using **bits per spike**, which quantifies information gain over a baseline homogeneous Poisson model. From e1b61ba7214c74c54f59ceabf12775af8490035c Mon Sep 17 00:00:00 2001 From: Abuzar Mahmood Date: Mon, 19 Jan 2026 13:14:37 -0500 Subject: [PATCH 4/5] Refactor(glm): simplify model fitting and comment out non-essential operations - Omitted regularizer strength parameter to default to `Ridge`. - Temporarily commented out model saving and bits per spike distribution plotting to streamline current execution. These tweaks are intended to improve performance for now. --- utils/glm/_glm_fit_models.py | 44 ++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/utils/glm/_glm_fit_models.py b/utils/glm/_glm_fit_models.py index 2a214d00..23dccccf 100644 --- a/utils/glm/_glm_fit_models.py +++ b/utils/glm/_glm_fit_models.py @@ -159,7 +159,7 @@ def fit_glm_single_neuron( # Fit GLM model = nmo.glm.GLM( - regularizer=nmo.regularizer.Ridge(regularizer_strength=0.01) + regularizer=nmo.regularizer.Ridge() ) model.fit(X_valid, y_valid) @@ -313,15 +313,15 @@ def main(): include_coupling=include_coupling, ) - # Save model - with open(model_path, 'wb') as f: - pickle.dump({ - 'model': model, - 'X': X_valid, - 'y': y_valid, - 'feature_info': feature_info, - 'valid_mask': valid_mask, - }, f) + # # Save model + # with open(model_path, 'wb') as f: + # pickle.dump({ + # 'model': model, + # 'X': X_valid, + # 'y': y_valid, + # 'feature_info': feature_info, + # 'valid_mask': valid_mask, + # }, f) # Predict firing rates pred_rates = predict_firing_rates( @@ -356,18 +356,18 @@ def main(): # Save summary bps_df.to_csv(os.path.join(output_path, 'bits_per_spike_summary.csv'), index=False) - # Plot bits per spike distribution - fig, ax = plt.subplots(figsize=(10, 6)) - for region in bps_df['region'].unique(): - region_bps = bps_df[bps_df['region'] == region]['bits_per_spike'] - ax.hist(region_bps, alpha=0.5, label=region, bins=20) - ax.set_xlabel('Bits per Spike') - ax.set_ylabel('Count') - ax.set_title('GLM Model Performance: Bits per Spike') - ax.legend() - fig.savefig(os.path.join(plots_dir, 'bits_per_spike_distribution.png'), dpi=150) - plt.close(fig) - + # # Plot bits per spike distribution + # fig, ax = plt.subplots(figsize=(10, 6)) + # for region in bps_df['region'].unique(): + # region_bps = bps_df[bps_df['region'] == region]['bits_per_spike'] + # ax.hist(region_bps, alpha=0.5, label=region, bins=20) + # ax.set_xlabel('Bits per Spike') + # ax.set_ylabel('Count') + # ax.set_title('GLM Model Performance: Bits per Spike') + # ax.legend() + # fig.savefig(os.path.join(plots_dir, 'bits_per_spike_distribution.png'), dpi=150) + # plt.close(fig) + # # Plot mean bits per spike by taste and region fig, ax = plt.subplots(figsize=(10, 6)) bps_summary = bps_df.groupby(['taste', 'region'])['bits_per_spike'].mean().unstack() From fef34803b693042321e3be81324dd95ed5772f80 Mon Sep 17 00:00:00 2001 From: Abuzar Mahmood Date: Mon, 19 Jan 2026 13:15:14 -0500 Subject: [PATCH 5/5] feat(glm): add test mode to `parse_args` function for streamlined testing - Introduced a `test_mode` parameter to enable running tests without actual command-line arguments, facilitating easier code testing. - Enhanced `find_conda_env` to efficiently parse the conda environment list, providing the Python path for better environment management. - Improved the script's feedback with inline comments and print statements when operating in test mode. - Adjusted the default behavior of `sys.exit` to allow for graceful test termination and better debugging. --- utils/glm/infer_glm_rates.py | 91 ++++++++++++++++++++++++------------ 1 file changed, 62 insertions(+), 29 deletions(-) diff --git a/utils/glm/infer_glm_rates.py b/utils/glm/infer_glm_rates.py index aa19c240..fc00ba24 100644 --- a/utils/glm/infer_glm_rates.py +++ b/utils/glm/infer_glm_rates.py @@ -35,31 +35,53 @@ # Argument parsing ############################################################ -def parse_args(): - parser = argparse.ArgumentParser( - description='Infer firing rates using GLM (nemos)') - parser.add_argument('data_dir', help='Path to data directory') - parser.add_argument('--bin_size', type=int, default=25, - help='Bin size in ms for spike binning (default: %(default)s)') - parser.add_argument('--history_window', type=int, default=250, - help='History window in ms for autoregressive effects (default: %(default)s)') - parser.add_argument('--n_basis_funcs', type=int, default=8, - help='Number of basis functions for history filter (default: %(default)s)') - parser.add_argument('--time_lims', type=int, nargs=2, default=[1500, 4500], - help='Time limits for analysis [start, end] in ms (default: %(default)s)') - parser.add_argument('--include_coupling', action='store_true', - help='Include coupling between neurons (default: %(default)s)') - parser.add_argument('--separate_tastes', action='store_true', - help='Fit separate models for each taste (default: %(default)s)') - parser.add_argument('--separate_regions', action='store_true', - help='Fit separate models for each region (default: %(default)s)') - parser.add_argument('--retrain', action='store_true', - help='Force retraining of model (default: %(default)s)') - parser.add_argument('--blech_clust_env', type=str, default='blech_clust', - help='Name of blech_clust conda environment (default: %(default)s)') - parser.add_argument('--nemos_env', type=str, default=None, - help='Path to nemos venv Python interpreter') - return parser.parse_args() +def parse_args(test_mode=False): + if test_mode: + print('====================') + print('Running in test mode') + print('====================') + # data_dir = '/home/abuzarmahmood/projects/blech_clust/pipeline_testing/test_data_handling/test_data/KM45_5tastes_210620_113227_new' + # data_dir = '/media/storage/abu_resorted/bla_gc/AM11_4Tastes_191030_114043_copy' + data_dir = '/media/storage/abu_resorted/bla_gc/AM35_4Tastes_201228_124547' + args = argparse.Namespace( + data_dir=data_dir, + bin_size=25, + history_window=250, + n_basis_funcs=8, + time_lims=[1500, 4500], + include_coupling=True, + separate_tastes=True, + separate_regions=True, + retrain=False, + blech_clust_env='blech_clust', + nemos_env='/home/abuzarmahmood/Desktop/blech_clust/nemos_venv/bin/python', + ) + return args, test_mode + else: + parser = argparse.ArgumentParser( + description='Infer firing rates using GLM (nemos)') + parser.add_argument('data_dir', help='Path to data directory') + parser.add_argument('--bin_size', type=int, default=25, + help='Bin size in ms for spike binning (default: %(default)s)') + parser.add_argument('--history_window', type=int, default=250, + help='History window in ms for autoregressive effects (default: %(default)s)') + parser.add_argument('--n_basis_funcs', type=int, default=8, + help='Number of basis functions for history filter (default: %(default)s)') + parser.add_argument('--time_lims', type=int, nargs=2, default=[1500, 4500], + help='Time limits for analysis [start, end] in ms (default: %(default)s)') + parser.add_argument('--include_coupling', action='store_false', + help='Include coupling between neurons (default: %(default)s)') + parser.add_argument('--separate_tastes', action='store_true', + help='Fit separate models for each taste (default: %(default)s)') + parser.add_argument('--separate_regions', action='store_true', + help='Fit separate models for each region (default: %(default)s)') + parser.add_argument('--retrain', action='store_true', + help='Force retraining of model (default: %(default)s)') + parser.add_argument('--blech_clust_env', type=str, default='blech_clust', + help='Name of blech_clust conda environment (default: %(default)s)') + parser.add_argument('--nemos_env', type=str, default=None, + help='Path to nemos venv Python interpreter') + return parser.parse_args(), test_mode ############################################################ @@ -70,9 +92,16 @@ def find_conda_env(env_name): """Find conda environment by name and return its Python path.""" try: result = subprocess.run( - ['conda', 'run', '-n', env_name, 'which', 'python'], + # ['conda', 'run', '-n', env_name, 'which', 'python'], + ['conda', 'env', 'list'], capture_output=True, text=True, timeout=30 ) + # Parse output to find env path + for line in result.stdout.splitlines(): + if line.startswith(env_name + ' '): + env_path = line.split()[1] + python_path = os.path.join(env_path, 'bin', 'python') + return python_path if result.returncode == 0: return result.stdout.strip() except (subprocess.TimeoutExpired, FileNotFoundError): @@ -138,9 +167,13 @@ def prompt_for_env(env_type, default_name=None): ############################################################ def main(): - args = parse_args() + args, _ = parse_args() + # args, test_mode = parse_args(test_mode=True) - script_dir = os.path.dirname(os.path.abspath(__file__)) + if not test_mode: + script_dir = os.path.dirname(os.path.abspath(__file__)) + else: + script_dir = '/home/abuzarmahmood/Desktop/blech_clust/utils/glm' data_dir = os.path.abspath(args.data_dir) # Setup output directories @@ -215,7 +248,7 @@ def main(): if result.returncode != 0: print("ERROR: GLM fitting failed") print(result.stderr) - sys.exit(1) + # sys.exit(1) print(result.stdout) # Cleanup temp files