Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions .github/workflows/matlab-ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name: MATLAB CI

on:
push:
branches:
- master
- main
pull_request:

jobs:
test-getwaveforms-mergepoints:
runs-on: ubuntu-latest
env:
MLM_LICENSE_TOKEN: ${{ secrets.MLM_LICENSE_TOKEN }}

steps:
- name: Checkout repository
uses: actions/checkout@v4

- name: Set up MATLAB
uses: matlab-actions/setup-matlab@v2
with:
products: >
Signal_Processing_Toolbox
Curve_Fitting_Toolbox
Statistics_and_Machine_Learning_Toolbox

- name: Run getWaveformsFromDat mergepoints test
uses: matlab-actions/run-command@v2
with:
command: addpath(genpath(pwd)); results = runtests('tests/test_getWaveformsFromDat_mergepoints.m'); assertSuccess(results);
165 changes: 149 additions & 16 deletions calc_CellMetrics/getWaveformsFromDat.m
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,7 @@
window_interval = wfWin-ceil(wfWinKeep*sr):wfWin-1+ceil(wfWinKeep*sr); % +- 0.8 ms of waveform
window_interval2 = wfWin-ceil(1.5*wfWinKeep*sr):wfWin-1+ceil(1.5*wfWinKeep*sr); % +- 1.20 ms of waveform
t1 = toc(timerVal);
if ~exist(datFile,'file')
error(['Binary file missing: ', datFile])
end
s = dir(datFile);

duration = s.bytes/(2*nChannels*sr);
rawData = memmapfile(datFile,'Format',precision,'writable',false);
% DATA = rawData.Data;
[waveformSource,duration,waveformSourceLabel] = initializeWaveformSource(datFile,basepath,basename,nChannels,sr,precision);

% Fit exponential
g = fittype('a*exp(-x/b)+c','dependent',{'y'},'independent',{'x'},'coefficients',{'a','b','c'});
Expand Down Expand Up @@ -167,10 +160,27 @@
% end

% Pulls the waveforms from all channels from the dat
startIndicies2 = (spkTmp - wfWin)*nChannels+1;
stopIndicies2 = (spkTmp + wfWin)*nChannels;
X2 = cumsum(accumarray(cumsum([1;stopIndicies2(:)-startIndicies2(:)+1]),[startIndicies2(:);0]-[0;stopIndicies2(:)]-1)+1);
wf = LSB * permute(reshape(double(rawData.Data(X2(1:end-1))),nChannels,(wfWin*2),[]),[2,3,1]);
[wf, spkTmp] = extractWaveformsFromSource(spkTmp,wfWin,nChannels,LSB,waveformSource);
if isempty(spkTmp)
warning('No spikes remained for waveform extraction for unit %d after applying file-boundary constraints.',ii)
spikes.rawWaveform{ii} = nan(1,length(window_interval));
spikes.rawWaveform_std{ii} = nan(1,length(window_interval));
spikes.filtWaveform{ii} = nan(1,length(window_interval));
spikes.filtWaveform_std{ii} = nan(1,length(window_interval));
spikes.rawWaveform_all{ii} = nan(nChannels,length(window_interval2));
spikes.filtWaveform_all{ii} = nan(nChannels,length(window_interval2));
spikes.timeWaveform{ii} = ([-ceil(wfWinKeep*sr)*(1/sr):1/sr:(ceil(wfWinKeep*sr)-1)*(1/sr)])*1000;
spikes.timeWaveform_all{ii} = ([-ceil(1.5*wfWinKeep*sr)*(1/sr):1/sr:(ceil(1.5*wfWinKeep*sr)-1)*(1/sr)])*1000;
spikes.peakVoltage(ii) = nan;
spikes.channels_all{ii} = 1:nChannels;
spikes.peakVoltage_sorted{ii} = nan(1,nChannels);
spikes.maxWaveform_all{ii} = nan(1,nChannels);
spikes.maxWaveformCh1(ii) = nan;
spikes.maxWaveformCh(ii) = nan;
spikes.shankID(ii) = nan;
spikes.peakVoltage_expFitLengthConstant(ii) = nan;
continue
end
wfF = zeros((wfWin * 2),length(spkTmp),nChannels);
for jjj = 1 : nChannels
wfF(:,:,jjj) = filtfilt(b1, a1, wf(:,:,jjj));
Expand Down Expand Up @@ -287,21 +297,19 @@
disp('Canceling waveform extraction...')
clear wf wfF wf2 wfF2
clear rawWaveform rawWaveform_std filtWaveform filtWaveform_std
clear rawData
error('Waveform extraction canceled by user by closing figure window.')
end
end
clear wf wfF wf2 wfF2
end

spikes.processinginfo.params.WaveformsSource = 'dat file';
spikes.processinginfo.params.WaveformsSource = waveformSourceLabel;
spikes.processinginfo.params.WaveformsFiltFreq = params.filtFreq;
spikes.processinginfo.params.Waveforms_nPull = params.nPull;
spikes.processinginfo.params.WaveformsWin_sec = wfWin_sec;
spikes.processinginfo.params.WaveformsWinKeep = wfWinKeep;
spikes.processinginfo.params.WaveformsFilterType = 'butter';
clear rawWaveform rawWaveform_std filtWaveform filtWaveform_std
clear rawData

% Plots
if params.showWaveforms && ishandle(fig1)
Expand All @@ -315,4 +323,129 @@
end
end
disp(['Waveform extraction complete. Total duration: ' num2str(round(toc(timerVal)/60)),' minutes'])
end
end

function [waveformSource,duration,waveformSourceLabel] = initializeWaveformSource(datFile,basepath,basename,nChannels,sr,precision)
sampleBytes = getPrecisionBytes(precision);

if exist(datFile,'file')
s = dir(datFile);
waveformSource.mode = 'single';
waveformSource.datFile = datFile;
waveformSource.precision = precision;
duration = s.bytes/(sampleBytes*nChannels*sr);
waveformSourceLabel = 'dat file';
return
end

mergePointsFile = fullfile(basepath,[basename,'.MergePoints.events.mat']);
if ~exist(mergePointsFile,'file')
error(['Binary file missing: ', datFile, newline, 'MergePoints file missing: ', mergePointsFile])
end

mergeData = load(mergePointsFile,'MergePoints');
if ~isfield(mergeData,'MergePoints')
error('MergePoints file is missing the MergePoints struct: %s',mergePointsFile)
end

MergePoints = mergeData.MergePoints;
if ~isfield(MergePoints,'timestamps_samples') || ~isfield(MergePoints,'foldernames')
error('MergePoints file is missing timestamps_samples or foldernames: %s',mergePointsFile)
end

starts = double(MergePoints.timestamps_samples(:,1));
stops = double(MergePoints.timestamps_samples(:,2));
foldernames = MergePoints.foldernames;
if isstring(foldernames)
foldernames = cellstr(foldernames);
end

if numel(foldernames) ~= numel(starts)
error('Mismatch between MergePoints foldernames and timestamps_samples in %s',mergePointsFile)
end

segments = repmat(struct('foldername','','datFile','','startSample',0,'endSample',0,'nSamples',0),1,numel(foldernames));
for i = 1:numel(foldernames)
foldername = foldernames{i};
datPath = fullfile(basepath,foldername,'amplifier.dat');
if ~exist(datPath,'file')
error('Expected amplifier.dat for MergePoints segment is missing: %s',datPath)
end
fileInfo = dir(datPath);
nSamples = fileInfo.bytes/(sampleBytes*nChannels);
expectedSamples = stops(i) - starts(i);
if nSamples ~= expectedSamples
error('Sample count mismatch for %s (MergePoints=%d, amplifier.dat=%d).',datPath,expectedSamples,nSamples)
end
segments(i).foldername = foldername;
segments(i).datFile = datPath;
segments(i).startSample = starts(i);
segments(i).endSample = stops(i);
segments(i).nSamples = nSamples;
end

waveformSource.mode = 'mergepoints';
waveformSource.segments = segments;
waveformSource.precision = precision;
duration = stops(end)/sr;
waveformSourceLabel = 'MergePoints amplifier.dat files';
end

function [wf, spkTmp] = extractWaveformsFromSource(spkTmp,wfWin,nChannels,LSB,waveformSource)
switch waveformSource.mode
case 'single'
wf = readWaveformsFromFile(waveformSource.datFile,spkTmp,wfWin,nChannels,LSB,waveformSource.precision);
case 'mergepoints'
[spkTmp,segmentIds] = filterSpikesForSegments(spkTmp,wfWin,waveformSource.segments);
if isempty(spkTmp)
wf = [];
return
end
wf = zeros(wfWin*2,length(spkTmp),nChannels);
uniqueSegments = unique(segmentIds);
for iSegment = uniqueSegments(:)'
idx = find(segmentIds == iSegment);
localSpikes = spkTmp(idx) - waveformSource.segments(iSegment).startSample;
wf(:,idx,:) = readWaveformsFromFile(waveformSource.segments(iSegment).datFile,localSpikes,wfWin,nChannels,LSB,waveformSource.precision);
end
otherwise
error('Unknown waveform source mode: %s',waveformSource.mode)
end
end

function [spkTmpValid,segmentIds] = filterSpikesForSegments(spkTmp,wfWin,segments)
spkTmpValid = [];
segmentIds = [];
for i = 1:numel(segments)
inSegment = spkTmp > (segments(i).startSample + wfWin) & spkTmp <= (segments(i).endSample - wfWin);
if any(inSegment)
spkTmpValid = [spkTmpValid; spkTmp(inSegment)]; %#ok<AGROW>
segmentIds = [segmentIds; repmat(i,sum(inSegment),1)]; %#ok<AGROW>
end
end
[spkTmpValid,order] = sort(spkTmpValid);
segmentIds = segmentIds(order);
end

function wf = readWaveformsFromFile(datFile,spkTmp,wfWin,nChannels,LSB,precision)
rawData = memmapfile(datFile,'Format',precision,'writable',false);
startIndicies = (spkTmp - wfWin)*nChannels+1;
stopIndicies = (spkTmp + wfWin)*nChannels;
X = cumsum(accumarray(cumsum([1;stopIndicies(:)-startIndicies(:)+1]),[startIndicies(:);0]-[0;stopIndicies(:)]-1)+1);
wf = LSB * permute(reshape(double(rawData.Data(X(1:end-1))),nChannels,(wfWin*2),[]),[2,3,1]);
end

function sampleBytes = getPrecisionBytes(precision)
switch lower(precision)
case {'int16','uint16'}
sampleBytes = 2;
case {'int32','uint32','single','float32'}
sampleBytes = 4;
case {'int64','uint64','double','float64'}
sampleBytes = 8;
case {'int8','uint8','char'}
sampleBytes = 1;
otherwise
error('Unsupported extracellular precision: %s',precision)
end
end
1 change: 1 addition & 0 deletions tests/fixtures/.gitkeep
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

Binary file not shown.
158 changes: 158 additions & 0 deletions tests/helpers/extract_getWaveformsFromDat_templates.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
function fixture = extract_getWaveformsFromDat_templates(varargin)
% Extract a few real Kilosort templates and save a compact fixture .mat file.
%
% This is intended as a one-time helper for building CI fixtures from a real
% recording. The generated .mat file is small and self-contained, so CI does
% not need access to the original R: drive data.
%
% Example:
% fixture = extract_getWaveformsFromDat_templates
%
% fixture = extract_getWaveformsFromDat_templates( ...
% 'basepath', 'R:\ys2375\Test data\test_rec_260509', ...
% 'phyFolder', 'Kilosort_2026-05-09_192146', ...
% 'outputFile', fullfile('tests','fixtures', ...
% 'getWaveformsFromDat_mergepoints_templates.mat'));

p = inputParser;
addParameter(p, 'basepath', 'R:\ys2375\Test data\test_rec_260509', @ischar);
addParameter(p, 'phyFolder', 'Kilosort_2026-05-09_192146', @ischar);
addParameter(p, 'outputFile', fullfile('tests', 'fixtures', 'getWaveformsFromDat_mergepoints_templates.mat'), @ischar);
addParameter(p, 'nTemplates', 3, @(x) isnumeric(x) && isscalar(x) && x >= 1);
addParameter(p, 'nChannelsToKeep', 4, @(x) isnumeric(x) && isscalar(x) && x >= 1);
addParameter(p, 'cropSamples', 21, @(x) isnumeric(x) && isscalar(x) && x >= 5);
addParameter(p, 'targetPeakUv', 120, @(x) isnumeric(x) && isscalar(x) && x > 0);
parse(p, varargin{:});

repoRoot = fileparts(fileparts(fileparts(mfilename('fullpath'))));
addpath(genpath(repoRoot));

basepath = p.Results.basepath;
basename = basenameFromBasepath(basepath);
phyPath = fullfile(basepath, p.Results.phyFolder);
outputFile = p.Results.outputFile;
nTemplates = p.Results.nTemplates;
nChannelsToKeep = p.Results.nChannelsToKeep;
cropSamples = p.Results.cropSamples;
targetPeakUv = p.Results.targetPeakUv;

assert(isfolder(basepath), 'Basepath not found: %s', basepath);
assert(isfolder(phyPath), 'Phy folder not found: %s', phyPath);

sessionData = load(fullfile(basepath, [basename, '.session.mat']), 'session');
session = sessionData.session;
clusterInfo = readtable(fullfile(phyPath, 'cluster_info.tsv'), 'FileType', 'text', 'Delimiter', '\t');
templates = readNPY(fullfile(phyPath, 'templates.npy'));
spikeTemplates = double(readNPY(fullfile(phyPath, 'spike_templates.npy')));
spikeClusters = double(readNPY(fullfile(phyPath, 'spike_clusters.npy')));

assert(ndims(templates) == 3, 'Expected templates.npy to be 3D.');
assert(any(strcmp(clusterInfo.Properties.VariableNames, 'cluster_id')), ...
'cluster_info.tsv must contain cluster_id.');

goodMask = true(height(clusterInfo), 1);
if any(strcmp(clusterInfo.Properties.VariableNames, 'group'))
goodMask = strcmp(string(clusterInfo.group), "good");
elseif any(strcmp(clusterInfo.Properties.VariableNames, 'KSLabel'))
goodMask = strcmp(string(clusterInfo.KSLabel), "good");
end

goodClusters = clusterInfo(goodMask, :);
assert(~isempty(goodClusters), 'No good clusters found in %s', phyPath);

if any(strcmp(goodClusters.Properties.VariableNames, 'amp'))
[~, order] = sort(goodClusters.amp, 'descend');
else
order = 1:height(goodClusters);
end
goodClusters = goodClusters(order, :);

selected = struct( ...
'clusterId', {}, ...
'templateId', {}, ...
'globalChannels', {}, ...
'localPeakCh1', {}, ...
'waveformUv', {}, ...
'waveformRawInt16', {});

usedChannelSets = {};
for i = 1:height(goodClusters)
clusterId = double(goodClusters.cluster_id(i));
templateId = choose_template_for_cluster(clusterId, spikeClusters, spikeTemplates);
templateWaveform = squeeze(double(templates(templateId + 1, :, :))); % Kilosort ids are 0-indexed

[croppedUv, globalChannels, localPeakCh1] = crop_template(templateWaveform, nChannelsToKeep, cropSamples, targetPeakUv);
channelSignature = sprintf('%d_', globalChannels);

if any(strcmp(usedChannelSets, channelSignature))
continue
end

selected(end + 1).clusterId = clusterId; %#ok<AGROW>
selected(end).templateId = templateId;
selected(end).globalChannels = globalChannels;
selected(end).localPeakCh1 = localPeakCh1;
selected(end).waveformUv = croppedUv;
selected(end).waveformRawInt16 = int16(round(croppedUv / session.extracellular.leastSignificantBit));
usedChannelSets{end + 1} = channelSignature; %#ok<AGROW>

if numel(selected) >= nTemplates
break
end
end

assert(~isempty(selected), 'No templates were extracted.');

fixture = struct();
fixture.sourceBasepath = basepath;
fixture.sourcePhyFolder = phyPath;
fixture.basename = basename;
fixture.sr = session.extracellular.sr;
fixture.LSB = session.extracellular.leastSignificantBit;
fixture.precision = 'int16';
fixture.cropSamples = cropSamples;
fixture.nChannelsToKeep = nChannelsToKeep;
fixture.templates = selected;

outputDir = fileparts(outputFile);
if ~isempty(outputDir) && ~isfolder(outputDir)
mkdir(outputDir);
end
save(outputFile, 'fixture');

fprintf('Saved %d templates to %s\n', numel(selected), outputFile);
for i = 1:numel(selected)
fprintf(' Template %d: cluster %d, template %d, channels %s, local peak ch %d\n', ...
i, selected(i).clusterId, selected(i).templateId, mat2str(selected(i).globalChannels), selected(i).localPeakCh1);
end

end

function templateId = choose_template_for_cluster(clusterId, spikeClusters, spikeTemplates)
clusterSpikeIdx = spikeClusters == clusterId;
assert(any(clusterSpikeIdx), 'Cluster %d has no spikes in spike_clusters.npy.', clusterId);
templateId = mode(spikeTemplates(clusterSpikeIdx));
end

function [croppedUv, channelsKept, localPeakCh1] = crop_template(templateWaveform, nChannelsToKeep, cropSamples, targetPeakUv)
% templateWaveform is [time x channels]
channelSpread = range(templateWaveform, 1);
[~, channelOrder] = sort(channelSpread, 'descend');
channelsKept = sort(channelOrder(1:nChannelsToKeep));

waveformSubset = templateWaveform(:, channelsKept);
[~, localPeakCh1] = max(range(waveformSubset, 1));
[~, peakSample] = min(waveformSubset(:, localPeakCh1));

halfWindow = floor(cropSamples / 2);
startSample = max(1, peakSample - halfWindow);
stopSample = min(size(waveformSubset, 1), startSample + cropSamples - 1);
startSample = max(1, stopSample - cropSamples + 1);
croppedUv = waveformSubset(startSample:stopSample, :);

% Normalize to a predictable amplitude while preserving the real shape.
peakAbs = max(abs(croppedUv(:, localPeakCh1)));
if peakAbs > 0
croppedUv = croppedUv * (targetPeakUv / peakAbs);
end
end
Loading
Loading