Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 180 additions & 0 deletions tests/test_ephys_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,3 +412,183 @@ def test_baks_rate_with_sparse_spikes(self):
# The firing rate should be higher around spike times
# This is a basic sanity check for BAKS functionality
assert np.max(firing_rate) > 0


class TestShuffledPalatability:
    """Tests for ephys_data.calc_shuffled_palatability

    Every test runs against a small synthetic dataset (3 tastes, 3 neurons,
    10 time bins, unequal trial counts) attached to a real ephys_data
    instance, so no recording files are needed.
    """

    def setup_method(self):
        """Create an ephys_data instance in a temp dir and attach synthetic data"""
        # Seed the global NumPy RNG so that the synthetic data (and any
        # np.random-based shuffling in the method under test) is reproducible
        # when a failure needs to be investigated
        np.random.seed(0)

        # Create a temporary directory for testing
        self.temp_dir = tempfile.mkdtemp()

        # Create an empty placeholder file with an .h5 name in the temp directory
        # NOTE(review): presumably the ephys_data constructor looks for an
        # HDF5 file in data_dir; the file is never read as HDF5 here
        hdf5_path = os.path.join(self.temp_dir, 'test.h5')
        with open(hdf5_path, 'w') as f:
            f.write('')

        self.data = ephys_data(data_dir=self.temp_dir)

        # One row per taste; pal_ranks holds the palatability rankings
        self.data.pal_df = pd.DataFrame({
            'dig_ins': ['taste1', 'taste2', 'taste3'],
            'dig_in_nums': [1, 2, 3],
            'taste_names': ['sucrose', 'quinine', 'water'],
            'pal_ranks': [5, 1, 3],
        })

        # Firing per taste, each array is trials x neurons x time_bins.
        # Trial counts differ on purpose (5, 3 and 4 trials).
        self.data.firing_list = [
            np.random.rand(5, 3, 10) * 10,
            np.random.rand(3, 3, 10) * 8,
            np.random.rand(4, 3, 10) * 6,
        ]

        # Stand-ins for the outputs of calc_palatability (neurons x time_bins)
        self.data.pal_rho_array = np.random.rand(3, 10) * 0.8
        self.data.pal_p_array = np.random.rand(3, 10) * 0.1

    def teardown_method(self):
        """Clean up temporary directory"""
        import shutil
        if hasattr(self, 'temp_dir') and os.path.exists(self.temp_dir):
            shutil.rmtree(self.temp_dir)

    @staticmethod
    def _bind_to_mock(**attrs):
        """Return a Mock holding only `attrs`, with the real method bound to it

        Attributes that are never set do not show up in dir() of the Mock,
        which is what the prerequisite checks of calc_shuffled_palatability
        inspect.
        """
        from unittest.mock import Mock
        stub = Mock()
        for attr_name, attr_value in attrs.items():
            setattr(stub, attr_name, attr_value)
        stub.calc_shuffled_palatability = \
            ephys_data.calc_shuffled_palatability.__get__(stub)
        return stub

    def test_calc_shuffled_palatability_basic(self):
        """Test basic shuffled palatability calculation"""
        # Small number of shuffles keeps the test fast
        self.data.calc_shuffled_palatability(
            n_shuffles=10, confidence_level=0.95)

        # Check that all required attributes were created
        for attr_name in (
                'pal_shuffled_mean',
                'pal_shuffled_std',
                'pal_shuffled_ci_lower',
                'pal_shuffled_ci_upper',
                'pal_shuffled_all',
                'pal_significance',
        ):
            assert hasattr(self.data, attr_name)

        # Per-entry summaries share the shape of the actual correlations
        rho_shape = self.data.pal_rho_array.shape
        assert self.data.pal_shuffled_mean.shape == rho_shape
        assert self.data.pal_shuffled_std.shape == rho_shape
        assert self.data.pal_shuffled_ci_lower.shape == rho_shape
        assert self.data.pal_shuffled_ci_upper.shape == rho_shape
        assert self.data.pal_significance.shape == rho_shape

        # Check shuffled correlations array shape
        expected_shape = (10, 3, 10)  # n_shuffles x neurons x time_bins
        assert self.data.pal_shuffled_all.shape == expected_shape

        # Shuffled correlations are absolute values, hence non-negative
        assert np.all(self.data.pal_shuffled_mean >= 0)
        assert np.all(self.data.pal_shuffled_std >= 0)

        # Check confidence interval logic
        assert np.all(self.data.pal_shuffled_ci_lower <=
                      self.data.pal_shuffled_ci_upper)

    def test_calc_shuffled_palatability_without_prerequisites(self):
        """Test that error is raised when prerequisites are not met"""
        # pal_rho_array is intentionally left out
        data_no_pre = self._bind_to_mock(
            pal_df=self.data.pal_df,
            firing_list=self.data.firing_list,
        )

        with pytest.raises(
                Exception,
                match=r"calc_palatability\(\) must be called before "
                      r"calc_shuffled_palatability\(\)"):
            data_no_pre.calc_shuffled_palatability(n_shuffles=5)

    def test_calc_shuffled_palatability_missing_data(self):
        """Test that error is raised when required data is missing"""
        # pal_df and firing_list are intentionally left out
        data_missing = self._bind_to_mock(
            pal_rho_array=self.data.pal_rho_array)

        with pytest.raises(Exception, match="Required palatability data not found"):
            data_missing.calc_shuffled_palatability(n_shuffles=5)

    def test_calc_shuffled_palatability_different_parameters(self):
        """Test that n_shuffles and confidence_level are both honoured"""
        for n_shuffles, confidence_level in [(20, 0.90), (10, 0.99)]:
            self.data.calc_shuffled_palatability(
                n_shuffles=n_shuffles, confidence_level=confidence_level)

            assert self.data.pal_shuffled_all.shape[0] == n_shuffles
            assert self.data.pal_shuffled_mean.shape == self.data.pal_rho_array.shape

            # The stored bounds must be the percentiles implied by the
            # requested level, recomputed from the stored null distribution.
            # Comparing intervals across calls would be meaningless because
            # every call draws new shuffles.
            tail = (1 - confidence_level) / 2 * 100
            expected_lower = np.percentile(
                self.data.pal_shuffled_all, tail, axis=0)
            expected_upper = np.percentile(
                self.data.pal_shuffled_all, 100 - tail, axis=0)

            np.testing.assert_allclose(
                self.data.pal_shuffled_ci_lower, expected_lower,
                rtol=0, atol=1e-10)
            np.testing.assert_allclose(
                self.data.pal_shuffled_ci_upper, expected_upper,
                rtol=0, atol=1e-10)
            assert np.all(self.data.pal_shuffled_ci_lower <=
                          self.data.pal_shuffled_ci_upper)

    def test_shuffled_palatability_significance_calculation(self):
        """Test that significance values are reasonable"""
        self.data.calc_shuffled_palatability(n_shuffles=50)

        # Significance values should be between 0 and 1
        assert np.all(self.data.pal_significance >= 0)
        assert np.all(self.data.pal_significance <= 1)

        # For every neuron/time bin, significance is the proportion of
        # shuffled correlations >= the actual correlation
        expected_sig = np.mean(
            self.data.pal_shuffled_all >= self.data.pal_rho_array[np.newaxis],
            axis=0)

        # Allow for small floating point differences
        np.testing.assert_allclose(
            self.data.pal_significance, expected_sig, rtol=0, atol=1e-10)

    def test_shuffled_palatability_statistics(self):
        """Test that shuffled statistics are calculated correctly"""
        # The level is passed explicitly so the percentiles below do not
        # silently depend on the method's default
        self.data.calc_shuffled_palatability(
            n_shuffles=100, confidence_level=0.95)

        # Recompute every summary from the stored null distribution
        shuffled_all = self.data.pal_shuffled_all
        expected_by_attr = {
            'pal_shuffled_mean': np.mean(shuffled_all, axis=0),
            'pal_shuffled_std': np.std(shuffled_all, axis=0),
            'pal_shuffled_ci_lower': np.percentile(shuffled_all, 2.5, axis=0),
            'pal_shuffled_ci_upper': np.percentile(shuffled_all, 97.5, axis=0),
        }

        # Allow for small floating point differences
        for attr_name, expected in expected_by_attr.items():
            np.testing.assert_allclose(
                getattr(self.data, attr_name), expected,
                rtol=0, atol=1e-10, err_msg=attr_name)
152 changes: 152 additions & 0 deletions utils/ephys_data/ephys_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,54 @@
strong_pal_neurons = np.where(np.max(data.pal_array, axis=1) > 0.7)[0]
print(f"Neurons with strong palatability coding: {strong_pal_neurons}")

# Calculate shuffled palatability correlation for significance testing
data.calc_shuffled_palatability(n_shuffles=1000)

# Plot comparison of actual vs shuffled correlations
plt.figure(figsize=(12, 8))

# Plot actual correlation
plt.subplot(2, 2, 1)
plt.imshow(data.pal_rho_array, aspect='auto', cmap='viridis')
plt.colorbar(label='|Actual Palatability Correlation|')
plt.xlabel('Time (bins)')
plt.ylabel('Neuron')
plt.title('Actual Palatability Correlation')

# Plot mean shuffled correlation
plt.subplot(2, 2, 2)
plt.imshow(data.pal_shuffled_mean, aspect='auto', cmap='viridis')
plt.colorbar(label='|Mean Shuffled Correlation|')
plt.xlabel('Time (bins)')
plt.ylabel('Neuron')
plt.title('Mean Shuffled Palatability Correlation')

# Plot significance (p-values)
plt.subplot(2, 2, 3)
plt.imshow(data.pal_significance, aspect='auto', cmap='hot')
plt.colorbar(label='P-value')
plt.xlabel('Time (bins)')
plt.ylabel('Neuron')
plt.title('Palatability Significance (shuffle p-values)')

# Plot actual correlation only where the neuron/time-bin entry is significant
plt.subplot(2, 2, 4)
sig_mask = data.pal_significance < 0.05
significant_correlation = np.where(sig_mask, data.pal_rho_array, np.nan)
plt.imshow(significant_correlation, aspect='auto', cmap='viridis')
plt.colorbar(label='|Significant Palatability Correlation|')
plt.xlabel('Time (bins)')
plt.ylabel('Neuron')
plt.title('Significant Palatability Correlation Only')

plt.tight_layout()
plt.show()

# Find neurons with significant palatability coding
significant_neurons = np.where(np.any(data.pal_significance < 0.05, axis=1))[0]
print(f"Neurons with significant palatability coding: {significant_neurons}")
print(f"Total significant neurons: {len(significant_neurons)} out of {data.pal_rho_array.shape[0]}")

Workflow 5: Time-Frequency Analysis
-----------------------------------------------------
from blech_clust.utils.ephys_data.ephys_data import ephys_data
Expand Down Expand Up @@ -176,6 +224,7 @@
- `firing_rate_method_selector`: Selects the method for firing rate calculation.
- `get_firing_rates`: Converts spikes to firing rates.
- `calc_palatability`: Calculates single neuron palatability from firing rates.
- `calc_shuffled_palatability`: Calculates shuffled palatability correlation for significance testing.
- `separate_laser_firing`: Separates firing rates into laser on and off conditions.
- `get_info_dict`: Loads information from a JSON file.
- `get_region_electrodes`: Extracts electrodes for each region from a JSON file.
Expand Down Expand Up @@ -1119,6 +1168,109 @@ def calc_palatability(self):
self.pal_rho_array = np.abs(pal_rho_array).T
self.pal_p_array = pal_p_array.T

def calc_shuffled_palatability(self, n_shuffles=1000, confidence_level=0.95):
    """
    Calculate shuffled palatability correlation to provide context for actual data correlation

    The per-trial palatability ranks are randomly permuted `n_shuffles` times.
    For every permutation the absolute Spearman correlation between firing and
    the permuted ranks is recomputed for each neuron and time bin. The result
    is a null distribution against which `pal_rho_array` is compared.

    Args:
        n_shuffles (int): Number of random shuffles to perform, at least 1 (default: 1000)
        confidence_level (float): Confidence level for intervals, between 0 and 1 (default: 0.95)

    Requires:
        - calc_palatability() must be called first to set up required data
        - pal_rho_array: Original palatability correlation coefficients (neurons x time_bins)
        - pal_df: Must provide a 'pal_ranks' entry per taste
        - firing_list: One array per taste with trials on the first axis
          (assumed trials x neurons x time_bins, as used by the unit tests)

    Generates:
        - pal_shuffled_mean: Mean correlation across shuffles (neurons x time_bins)
        - pal_shuffled_std: Standard deviation across shuffles (neurons x time_bins)
        - pal_shuffled_ci_lower: Lower confidence bound (neurons x time_bins)
        - pal_shuffled_ci_upper: Upper confidence bound (neurons x time_bins)
        - pal_shuffled_all: All shuffled correlations (n_shuffles x neurons x time_bins)
        - pal_significance: Proportion of shuffles whose correlation is >= the
          actual correlation (neurons x time_bins)

    Raises:
        Exception: If pal_rho_array, pal_df or firing_list is missing
        ValueError: If n_shuffles < 1 or confidence_level is outside [0, 1]
    """

    # Check if required data exists
    if 'pal_rho_array' not in dir(self):
        raise Exception(
            "calc_palatability() must be called before calc_shuffled_palatability()")

    if 'pal_df' not in dir(self) or 'firing_list' not in dir(self):
        raise Exception(
            "Required palatability data not found. Run calc_palatability() first.")

    # Validate up front: an out-of-range confidence level would otherwise only
    # fail inside np.percentile after all shuffles have already been computed
    if n_shuffles < 1:
        raise ValueError(
            f"n_shuffles must be at least 1, got {n_shuffles}")
    if not 0 <= confidence_level <= 1:
        raise ValueError(
            f"confidence_level must be between 0 and 1, got {confidence_level}")

    print(
        f'Calculating shuffled palatability correlation with {n_shuffles} shuffles...')

    # One palatability rank per trial, in the same order as the concatenated trials
    trial_counts = [x.shape[0] for x in self.firing_list]
    pal_vec = np.concatenate(
        [np.repeat(x, y) for x, y in zip(self.pal_df['pal_ranks'], trial_counts)])

    # Concatenate trials across tastes, then reverse the axes so that trials end
    # up on the LAST axis: (trials, neurons, time_bins) -> (time_bins, neurons, trials)
    cat_firing = np.concatenate(self.firing_list, axis=0).T

    # Add the same small noise as in original calculation for consistency
    cat_firing += np.random.normal(0, 1e-6, cat_firing.shape)

    # NOTE: because of the transpose above the leading axes are (time_bins, neurons).
    # Results are accumulated in that layout and transposed on output so that
    # they line up with pal_rho_array (neurons x time_bins).
    time_bins, neurons = cat_firing.shape[:2]
    shuffled_correlations = np.zeros((n_shuffles, time_bins, neurons))

    # The index grid does not depend on the shuffle, so build it once
    inds = list(np.ndindex((time_bins, neurons)))

    # NOTE(review): spearmanr is called n_shuffles * time_bins * neurons times.
    # Ranking the firing once and correlating ranks would likely be much faster.
    for shuffle_idx in tqdm(range(n_shuffles), desc="Shuffling palatability"):
        # Shuffle the palatability vector
        shuffled_pal_vec = np.random.permutation(pal_vec)

        # View into the output array for this shuffle
        shuffle_rho_array = shuffled_correlations[shuffle_idx]
        for this_ind in inds:
            rho, _ = spearmanr(cat_firing[this_ind], shuffled_pal_vec)
            # Use absolute value like original
            shuffle_rho_array[this_ind] = abs(rho)

    # Calculate statistics across shuffles; .T converts (time_bins, neurons)
    # into (neurons, time_bins) to match pal_rho_array
    self.pal_shuffled_mean = np.mean(shuffled_correlations, axis=0).T
    self.pal_shuffled_std = np.std(shuffled_correlations, axis=0).T

    # Calculate confidence intervals (two-sided, equal tails)
    alpha = 1 - confidence_level
    lower_percentile = (alpha / 2) * 100
    upper_percentile = (1 - alpha / 2) * 100

    self.pal_shuffled_ci_lower = np.percentile(
        shuffled_correlations, lower_percentile, axis=0).T
    self.pal_shuffled_ci_upper = np.percentile(
        shuffled_correlations, upper_percentile, axis=0).T

    # Store all shuffled correlations as (n_shuffles, neurons, time_bins)
    self.pal_shuffled_all = shuffled_correlations.transpose(0, 2, 1)

    # Calculate significance: proportion of shuffles where correlation >= actual correlation
    # pal_rho_array.T is (time_bins, neurons); the new leading axis broadcasts it
    # against shuffled_correlations (n_shuffles, time_bins, neurons)
    significance_result = np.mean(
        shuffled_correlations >= self.pal_rho_array.T[np.newaxis, :, :], axis=0)
    # Transpose back to (neurons, time_bins)
    self.pal_significance = significance_result.T

    print('Shuffled palatability calculation complete.')
    print(
        f'Mean shuffled correlation: {np.mean(self.pal_shuffled_mean):.4f} ± {np.mean(self.pal_shuffled_std):.4f}')
    # The mean runs over every neuron x time-bin entry, not over neurons
    print(
        f'Proportion of neuron x time-bin entries with significant palatability coding (p < 0.05): {np.mean(self.pal_significance < 0.05):.3f}')

def separate_laser_firing(self):
"""Separate firing rate arrays into laser on and off conditions

Expand Down
Loading