Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ blech.dir
*.Rhistory
.aider*
.env
utils/glm/_temp/
9 changes: 9 additions & 0 deletions params/_templates/glm_params.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"bin_size": 25,
"history_window": 250,
"n_basis_funcs": 8,
"time_lims": [1500, 4500],
"include_coupling": false,
"separate_tastes": true,
"separate_regions": true
}
116 changes: 116 additions & 0 deletions utils/glm/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# GLM-based Firing Rate Estimation

This module fits Generalized Linear Models (GLMs) to electrophysiological spike data using the [nemos](https://nemos.readthedocs.io/) library.

## Why Separate Scripts?

The nemos library has dependency conflicts with blech_clust's conda environment (particularly around numpy/jax versions). To work around this, the GLM fitting runs in a **separate virtual environment** dedicated to nemos.

This module uses a two-environment architecture:

```
┌─────────────────────────────────────────────────────────────┐
│ infer_glm_rates.py │
│ (Main Orchestrator) │
│ Can run from any Python environment │
└─────────────────────┬───────────────────────────────────────┘
┌───────────┴───────────┐
▼ ▼
┌─────────────────────┐ ┌─────────────────────┐
│ _glm_extract_data.py│ │ _glm_fit_models.py │
│ │ │ │
│ Runs in blech_clust │ │ Runs in nemos │
│ conda environment │ │ virtual environment│
│ │ │ │
│ - Loads spike data │ │ - Fits GLM models │
│ - Uses ephys_data │ │ - Uses nemos │
│ - Saves to temp/ │ │ - Reads from temp/ │
└─────────────────────┘ └─────────────────────┘
│ │
└───────────┬───────────┘
┌───────────────┐
│ _temp/ │
│ (intermediate │
│ files) │
└───────────────┘
```

## Files

| File | Purpose | Environment |
|------|---------|-------------|
| `infer_glm_rates.py` | Main entry point. Detects environments, orchestrates data flow | Any |
| `_glm_extract_data.py` | Extracts spike data using ephys_data | blech_clust conda |
| `_glm_fit_models.py` | Fits GLM models, generates plots, writes HDF5 | nemos venv |
| `_temp/` | Temporary directory for intermediate numpy files (gitignored) | N/A |

## Setup

1. **blech_clust environment**: Should already exist if you're using blech_clust.

2. **nemos environment**: Create a separate virtual environment:
```bash
python -m venv ~/nemos_env
source ~/nemos_env/bin/activate
pip install nemos matplotlib pandas tables
```

## Usage

```bash
python utils/glm/infer_glm_rates.py <data_dir> [options]
```

The script will:
1. Auto-detect the blech_clust conda environment
2. Auto-detect or prompt for the nemos virtual environment
3. Extract data using blech_clust's ephys_data
4. Fit GLM models using nemos
5. Save results to HDF5 and generate plots

### Options

| Option | Default | Description |
|--------|---------|-------------|
| `--bin_size` | 25 | Bin size in ms |
| `--history_window` | 250 | History window in ms for autoregressive effects |
| `--n_basis_funcs` | 8 | Number of basis functions for history filter |
| `--time_lims` | [1500, 4500] | Time limits for analysis [start, end] in ms |
| `--include_coupling` | False | Include coupling between neurons |
| `--separate_tastes` | False | Fit separate models for each taste |
| `--separate_regions` | False | Fit separate models for each region |
| `--retrain` | False | Force retraining even if model exists |
| `--blech_clust_env` | blech_clust | Name of blech_clust conda environment |
| `--nemos_env` | (auto) | Path to nemos venv Python interpreter |

### Example

```bash
# Basic usage
python utils/glm/infer_glm_rates.py /path/to/data

# With options
python utils/glm/infer_glm_rates.py /path/to/data \
--separate_tastes \
--separate_regions \
--include_coupling \
--nemos_env ~/nemos_env/bin/python
```

## Output

- **HDF5**: Results saved under `/glm_output/regions/` in the data's HDF5 file
- **CSV**: `glm_output/bits_per_spike_summary.csv` with model performance metrics
- **Plots**: `glm_output/plots/` with distribution and individual neuron plots
- **Models**: `glm_output/artifacts/` with pickled GLM models

## Model Details

The GLM includes:
- **Spike history**: Raised cosine log basis functions capture autoregressive effects
- **Stimulus features** (optional): Convolved indicator for stimulus onset
- **Neural coupling** (optional): Cross-neuron dependencies within region

Performance is measured using **bits per spike**, which quantifies information gain over a baseline homogeneous Poisson model.
9 changes: 9 additions & 0 deletions utils/glm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""
GLM-based firing rate estimation module.

This module provides GLM fitting for electrophysiological spike data using nemos.
Due to dependency conflicts, nemos runs in a separate virtual environment.

Usage:
python -m blech_clust.utils.glm.infer_glm_rates <data_dir> [options]
"""
118 changes: 118 additions & 0 deletions utils/glm/_glm_extract_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""
Data extraction script for GLM fitting.

This script runs in the blech_clust conda environment and extracts spike data
using ephys_data. The extracted data is saved as numpy files for the GLM
fitting script to load.

This script is called by infer_glm_rates.py and should not be run directly.
"""

import sys
import os
import json
import numpy as np

def main():
if len(sys.argv) != 2:
print("Usage: _glm_extract_data.py <params_path>")
sys.exit(1)

params_path = sys.argv[1]

# Load parameters
with open(params_path, 'r') as f:
params = json.load(f)

data_dir = params['data_dir']
temp_dir = params['temp_dir']
time_lims = params['time_lims']
bin_size = params['bin_size']

# Add blech_clust to path
script_path = os.path.abspath(__file__)
blech_clust_path = os.path.dirname(os.path.dirname(script_path))
sys.path.insert(0, blech_clust_path)

# Import blech_clust modules
from blech_clust.utils.blech_utils import imp_metadata, pipeline_graph_check
from blech_clust.utils.ephys_data import ephys_data

print(f"Extracting data from: {data_dir}")

# Pipeline check
this_pipeline_check = pipeline_graph_check(data_dir)
this_pipeline_check.check_previous(script_path)
this_pipeline_check.write_to_log(script_path, 'attempted')

# Load data
data = ephys_data.ephys_data(data_dir)
data.get_spikes()
data.get_region_units()

# Build region mapping
region_dict = dict(zip(data.region_names, data.region_units))
region_vec = np.zeros(len(np.concatenate(data.region_units)), dtype=object)
for region_name, unit_list in region_dict.items():
region_vec[unit_list] = region_name

n_tastes = len(data.spikes)
n_neurons = data.spikes[0].shape[1]

print(f" Tastes: {n_tastes}")
print(f" Neurons: {n_neurons}")
print(f" Regions: {data.region_names}")

# Process and save spike data for each taste
spike_data = {}
for taste_idx, taste_spikes in enumerate(data.spikes):
# Cut to time limits
taste_spikes = taste_spikes[..., time_lims[0]:time_lims[1]]

# Bin spikes
n_trials, n_neurons_taste, n_time = taste_spikes.shape
n_bins = n_time // bin_size
trimmed = taste_spikes[..., :n_bins * bin_size]
binned = trimmed.reshape(n_trials, n_neurons_taste, n_bins, bin_size).sum(axis=-1)

spike_data[f'taste_{taste_idx}'] = binned
print(f" Taste {taste_idx}: {binned.shape}")

# Save extracted data
extracted_data = {
'spike_data': spike_data,
'region_names': list(data.region_names),
'region_units': {name: list(units) for name, units in region_dict.items()},
'region_vec': region_vec.tolist(),
'n_tastes': n_tastes,
'n_neurons': n_neurons,
'hdf5_path': data.hdf5_path,
}

# Save as numpy file
output_file = os.path.join(temp_dir, 'extracted_data.npz')
np.savez(output_file, **{
'spike_data_keys': list(spike_data.keys()),
'region_names': np.array(data.region_names, dtype=object),
'region_vec': np.array(region_vec, dtype=object),
'n_tastes': n_tastes,
'n_neurons': n_neurons,
'hdf5_path': data.hdf5_path,
})

# Save spike arrays separately (npz doesn't handle nested dicts well)
for key, arr in spike_data.items():
np.save(os.path.join(temp_dir, f'{key}.npy'), arr)

# Save region units as JSON
with open(os.path.join(temp_dir, 'region_units.json'), 'w') as f:
json.dump({name: list(map(int, units)) for name, units in region_dict.items()}, f)

print(f"\nData saved to: {temp_dir}")

# Write to pipeline log
this_pipeline_check.write_to_log(script_path, 'completed')


if __name__ == '__main__':
main()
Loading
Loading