class TestShuffledPalatability:
    """Tests for ``ephys_data.calc_shuffled_palatability``.

    The fixture bypasses real data loading: an ``ephys_data`` instance is
    created on an empty temporary directory and the attributes read by
    ``calc_shuffled_palatability`` (``pal_df``, ``firing_list``,
    ``pal_rho_array``) are assigned by hand.
    """

    def setup_method(self):
        """Build a reproducible mock dataset: 3 tastes, 3 neurons, 10 time bins."""
        # Seed the global NumPy RNG so both the fixture data below and the
        # shuffles (which draw from np.random) are reproducible between runs.
        np.random.seed(0)

        # Create a temporary directory for testing
        self.temp_dir = tempfile.mkdtemp()

        # Create a dummy (empty) HDF5 file in the temp directory
        hdf5_path = os.path.join(self.temp_dir, 'test.h5')
        with open(hdf5_path, 'w') as f:
            f.write('')

        # Create ephys_data instance with required attributes
        self.data = ephys_data(data_dir=self.temp_dir)

        # Mock essential data structures
        self.data.pal_df = pd.DataFrame({
            'dig_ins': ['taste1', 'taste2', 'taste3'],
            'dig_in_nums': [1, 2, 3],
            'taste_names': ['sucrose', 'quinine', 'water'],
            'pal_ranks': [5, 1, 3],  # palatability rankings
        })

        # Mock firing data, one array per taste: trials x neurons x time_bins.
        # Trial counts differ on purpose (5, 3, 4).
        self.data.firing_list = [
            np.random.rand(5, 3, 10) * 10,
            np.random.rand(3, 3, 10) * 8,
            np.random.rand(4, 3, 10) * 6,
        ]

        # Mock palatability correlation results: neurons x time_bins
        self.data.pal_rho_array = np.random.rand(3, 10) * 0.8
        self.data.pal_p_array = np.random.rand(3, 10) * 0.1

    def teardown_method(self):
        """Clean up temporary directory"""
        import shutil
        if hasattr(self, 'temp_dir') and os.path.exists(self.temp_dir):
            shutil.rmtree(self.temp_dir)

    def test_calc_shuffled_palatability_basic(self):
        """Test basic shuffled palatability calculation"""
        # Small number of shuffles keeps the test fast
        self.data.calc_shuffled_palatability(
            n_shuffles=10, confidence_level=0.95)

        # Check that all required attributes were created
        for attr_name in (
                'pal_shuffled_mean',
                'pal_shuffled_std',
                'pal_shuffled_ci_lower',
                'pal_shuffled_ci_upper',
                'pal_shuffled_all',
                'pal_significance'):
            assert hasattr(self.data, attr_name)

        # Per-(neuron, time bin) summaries share the shape of pal_rho_array
        target_shape = self.data.pal_rho_array.shape
        assert self.data.pal_shuffled_mean.shape == target_shape
        assert self.data.pal_shuffled_std.shape == target_shape
        assert self.data.pal_shuffled_ci_lower.shape == target_shape
        assert self.data.pal_shuffled_ci_upper.shape == target_shape
        assert self.data.pal_significance.shape == target_shape

        # Check shuffled correlations array shape
        expected_shape = (10, 3, 10)  # n_shuffles x neurons x time_bins
        assert self.data.pal_shuffled_all.shape == expected_shape

        # Shuffled correlations are stored as absolute values
        assert np.all(self.data.pal_shuffled_mean >= 0)
        assert np.all(self.data.pal_shuffled_std >= 0)

        # Check confidence interval logic
        assert np.all(self.data.pal_shuffled_ci_lower <=
                      self.data.pal_shuffled_ci_upper)

    def test_calc_shuffled_palatability_without_prerequisites(self):
        """Test that error is raised when pal_rho_array is missing"""
        from types import SimpleNamespace
        from utils.ephys_data.ephys_data import ephys_data

        # A plain namespace only exposes the attributes given here, so the
        # test holds whether the method checks via dir() or hasattr().
        # (A Mock auto-creates attributes and would defeat a hasattr check.)
        data_no_pre = SimpleNamespace(
            pal_df=self.data.pal_df,
            firing_list=self.data.firing_list,
        )  # intentionally no pal_rho_array

        with pytest.raises(
                Exception,
                match=r"calc_palatability\(\) must be called before "
                      r"calc_shuffled_palatability\(\)"):
            ephys_data.calc_shuffled_palatability(data_no_pre, n_shuffles=5)

    def test_calc_shuffled_palatability_missing_data(self):
        """Test that error is raised when pal_df / firing_list are missing"""
        from types import SimpleNamespace
        from utils.ephys_data.ephys_data import ephys_data

        # Only pal_rho_array is present; pal_df and firing_list are not
        data_missing = SimpleNamespace(pal_rho_array=self.data.pal_rho_array)

        with pytest.raises(
                Exception, match="Required palatability data not found"):
            ephys_data.calc_shuffled_palatability(data_missing, n_shuffles=5)

    def test_calc_shuffled_palatability_different_parameters(self):
        """Test shuffled palatability with different parameters"""
        # 90% interval
        np.random.seed(1)
        self.data.calc_shuffled_palatability(
            n_shuffles=20, confidence_level=0.90)

        assert self.data.pal_shuffled_all.shape[0] == 20  # n_shuffles
        assert self.data.pal_shuffled_mean.shape == self.data.pal_rho_array.shape
        lower_90 = self.data.pal_shuffled_ci_lower.copy()
        upper_90 = self.data.pal_shuffled_ci_upper.copy()

        # 99% interval on the SAME null distribution: re-seeding replays the
        # same noise and permutations, so the wider interval must contain
        # the narrower one.
        np.random.seed(1)
        self.data.calc_shuffled_palatability(
            n_shuffles=20, confidence_level=0.99)

        assert np.all(self.data.pal_shuffled_ci_lower <=
                      self.data.pal_shuffled_ci_upper)
        tol = 1e-12  # guard against last-bit rounding in the interpolation
        assert np.all(self.data.pal_shuffled_ci_lower <= lower_90 + tol)
        assert np.all(self.data.pal_shuffled_ci_upper >= upper_90 - tol)

        # A different number of shuffles is reflected in the stored null
        self.data.calc_shuffled_palatability(
            n_shuffles=10, confidence_level=0.99)
        assert self.data.pal_shuffled_all.shape[0] == 10

    def test_shuffled_palatability_significance_calculation(self):
        """Test that significance values are reasonable"""
        self.data.calc_shuffled_palatability(n_shuffles=50)

        # Significance values should be between 0 and 1
        assert np.all(self.data.pal_significance >= 0)
        assert np.all(self.data.pal_significance <= 1)

        # For every neuron/time point, significance is the proportion of
        # shuffled correlations >= the actual correlation
        expected_sig = np.mean(
            self.data.pal_shuffled_all
            >= self.data.pal_rho_array[np.newaxis, :, :],
            axis=0)
        np.testing.assert_allclose(
            self.data.pal_significance, expected_sig, rtol=0, atol=1e-10)

    def test_shuffled_palatability_statistics(self):
        """Test that shuffled statistics are calculated correctly"""
        # Default confidence_level is 0.95 -> 2.5th / 97.5th percentiles
        self.data.calc_shuffled_palatability(n_shuffles=100)

        shuffled = self.data.pal_shuffled_all
        expected_by_attr = {
            'pal_shuffled_mean': np.mean(shuffled, axis=0),
            'pal_shuffled_std': np.std(shuffled, axis=0),
            'pal_shuffled_ci_lower': np.percentile(shuffled, 2.5, axis=0),
            'pal_shuffled_ci_upper': np.percentile(shuffled, 97.5, axis=0),
        }

        # Compare manually calculated statistics with stored ones,
        # allowing for small floating point differences
        for attr_name, expected in expected_by_attr.items():
            np.testing.assert_allclose(
                getattr(self.data, attr_name), expected,
                rtol=0, atol=1e-10, err_msg=attr_name)
strong palatability coding: {strong_pal_neurons}") +# Calculate shuffled palatability correlation for significance testing +data.calc_shuffled_palatability(n_shuffles=1000) + +# Plot comparison of actual vs shuffled correlations +plt.figure(figsize=(12, 8)) + +# Plot actual correlation +plt.subplot(2, 2, 1) +plt.imshow(data.pal_rho_array, aspect='auto', cmap='viridis') +plt.colorbar(label='|Actual Palatability Correlation|') +plt.xlabel('Time (bins)') +plt.ylabel('Neuron') +plt.title('Actual Palatability Correlation') + +# Plot mean shuffled correlation +plt.subplot(2, 2, 2) +plt.imshow(data.pal_shuffled_mean, aspect='auto', cmap='viridis') +plt.colorbar(label='|Mean Shuffled Correlation|') +plt.xlabel('Time (bins)') +plt.ylabel('Neuron') +plt.title('Mean Shuffled Palatability Correlation') + +# Plot significance (p-values) +plt.subplot(2, 2, 3) +plt.imshow(data.pal_significance, aspect='auto', cmap='hot') +plt.colorbar(label='P-value') +plt.xlabel('Time (bins)') +plt.ylabel('Neuron') +plt.title('Palatability Significance (p < 0.05)') + +# Plot significant neurons only +plt.subplot(2, 2, 4) +sig_mask = data.pal_significance < 0.05 +significant_correlation = np.where(sig_mask, data.pal_rho_array, np.nan) +plt.imshow(significant_correlation, aspect='auto', cmap='viridis') +plt.colorbar(label='|Significant Palatability Correlation|') +plt.xlabel('Time (bins)') +plt.ylabel('Neuron') +plt.title('Significant Palatability Correlation Only') + +plt.tight_layout() +plt.show() + +# Find neurons with significant palatability coding +significant_neurons = np.where(np.any(data.pal_significance < 0.05, axis=1))[0] +print(f"Neurons with significant palatability coding: {significant_neurons}") +print(f"Total significant neurons: {len(significant_neurons)} out of {data.pal_rho_array.shape[0]}") + Workflow 5: Time-Frequency Analysis ----------------------------------------------------- from blech_clust.utils.ephys_data.ephys_data import ephys_data @@ -176,6 +224,7 @@ - 
def calc_shuffled_palatability(self, n_shuffles=1000, confidence_level=0.95):
    """
    Calculate shuffled palatability correlation to provide context for actual data correlation

    The palatability vector (one rank per trial) is randomly permuted
    n_shuffles times. For every permutation the absolute Spearman correlation
    between firing and the permuted ranks is computed for each neuron and
    time bin, giving a null distribution against which pal_rho_array can be
    compared.

    Spearman's rho is the Pearson correlation of the ranks. The firing ranks
    do not change between shuffles, so they are computed once and each
    shuffle reduces to a single matrix product.

    Args:
        n_shuffles (int): Number of random shuffles to perform (default: 1000)
        confidence_level (float): Confidence level for intervals, as a
            fraction in [0, 1] (default: 0.95)

    Requires:
        - calc_palatability() must be called first to set up required data
        - pal_rho_array: Original palatability correlation coefficients
        - pal_df: must contain a 'pal_ranks' column (one rank per taste)
        - firing_list: one array per taste, trials first

    Generates:
        - pal_shuffled_mean: Mean correlation across shuffles (neurons x time_bins)
        - pal_shuffled_std: Standard deviation across shuffles (neurons x time_bins)
        - pal_shuffled_ci_lower: Lower confidence bound (neurons x time_bins)
        - pal_shuffled_ci_upper: Upper confidence bound (neurons x time_bins)
        - pal_shuffled_all: All shuffled correlations (n_shuffles x neurons x time_bins)
        - pal_significance: Proportion of shuffles whose correlation is >= the
          actual correlation (neurons x time_bins)

    Raises:
        Exception: if pal_rho_array, pal_df or firing_list is missing
        ValueError: if n_shuffles < 1 or confidence_level is outside [0, 1]
    """
    # Local import: only this method needs rankdata
    from scipy.stats import rankdata

    # Check if required data exists.
    # NOTE: membership in dir(self) is kept deliberately (rather than
    # hasattr) so objects that auto-create attributes are still rejected.
    if 'pal_rho_array' not in dir(self):
        raise Exception(
            "calc_palatability() must be called before calc_shuffled_palatability()")

    if 'pal_df' not in dir(self) or 'firing_list' not in dir(self):
        raise Exception(
            "Required palatability data not found. Run calc_palatability() first.")

    if n_shuffles < 1:
        raise ValueError(
            f"n_shuffles must be at least 1, got {n_shuffles}")
    if not 0 <= confidence_level <= 1:
        raise ValueError(
            f"confidence_level must be between 0 and 1, got {confidence_level}")

    print(
        f'Calculating shuffled palatability correlation with {n_shuffles} shuffles...')

    # One palatability rank per trial, in the same order as the
    # concatenated trials
    trial_counts = [x.shape[0] for x in self.firing_list]
    pal_vec = np.concatenate(
        [np.repeat(rank, count)
         for rank, count in zip(self.pal_df['pal_ranks'], trial_counts)])

    # Concatenate over trials, then reverse the axes:
    # (trials, neurons, time_bins) -> (time_bins, neurons, trials)
    cat_firing = np.concatenate(self.firing_list, axis=0).T

    # Small noise, matching the magnitude used here previously.
    # In-place is safe: concatenate returned a fresh array, so
    # self.firing_list is not modified.
    cat_firing += np.random.normal(0, 1e-6, cat_firing.shape)

    time_bins, neurons = cat_firing.shape[:2]

    # Rank once along the trial axis and centre; these are loop invariants
    firing_ranks = rankdata(cat_firing, axis=-1)
    firing_dev = firing_ranks - firing_ranks.mean(axis=-1, keepdims=True)
    pal_ranks = rankdata(pal_vec)
    pal_dev = pal_ranks - pal_ranks.mean()

    # Pearson denominator; unchanged by permuting the palatability ranks
    denom = np.sqrt((firing_dev ** 2).sum(axis=-1) * (pal_dev ** 2).sum())

    # Stored directly in the output layout: (n_shuffles, neurons, time_bins)
    shuffled_correlations = np.zeros((n_shuffles, neurons, time_bins))

    # A zero denominator (constant input) yields NaN, silenced here so it
    # does not emit one warning per shuffle
    with np.errstate(divide='ignore', invalid='ignore'):
        for shuffle_idx in tqdm(range(n_shuffles), desc="Shuffling palatability"):
            # Permuting the centred ranks is equivalent to ranking the
            # permuted palatability vector
            shuffled_pal_dev = np.random.permutation(pal_dev)

            # (time_bins, neurons, trials) @ (trials,) -> (time_bins, neurons)
            rho = (firing_dev @ shuffled_pal_dev) / denom

            # Clip rounding overshoot, take the absolute value and
            # transpose to (neurons, time_bins) to match pal_rho_array
            shuffled_correlations[shuffle_idx] = np.abs(
                np.clip(rho, -1, 1)).T

    # Store all shuffled correlations
    self.pal_shuffled_all = shuffled_correlations

    # Calculate statistics across shuffles -> (neurons, time_bins)
    self.pal_shuffled_mean = np.mean(shuffled_correlations, axis=0)
    self.pal_shuffled_std = np.std(shuffled_correlations, axis=0)

    # Calculate confidence intervals
    alpha = 1 - confidence_level
    lower_percentile = (alpha / 2) * 100
    upper_percentile = (1 - alpha / 2) * 100

    self.pal_shuffled_ci_lower = np.percentile(
        shuffled_correlations, lower_percentile, axis=0)
    self.pal_shuffled_ci_upper = np.percentile(
        shuffled_correlations, upper_percentile, axis=0)

    # Calculate significance: proportion of shuffles where correlation >= actual correlation
    # (n_shuffles, neurons, time_bins) is compared against (1, neurons, time_bins)
    self.pal_significance = np.mean(
        shuffled_correlations >= self.pal_rho_array[np.newaxis, :, :], axis=0)

    print('Shuffled palatability calculation complete.')
    print(
        f'Mean shuffled correlation: {np.mean(self.pal_shuffled_mean):.4f} ± {np.mean(self.pal_shuffled_std):.4f}')
    # The mean is taken over every (neuron, time bin) entry, not per neuron
    print(
        f'Proportion of neuron/time-bin pairs with significant palatability coding (p < 0.05): {np.mean(self.pal_significance < 0.05):.3f}')