diff --git a/.github/workflows/matlab-ci.yml b/.github/workflows/matlab-ci.yml new file mode 100644 index 00000000..d815ef0d --- /dev/null +++ b/.github/workflows/matlab-ci.yml @@ -0,0 +1,31 @@ +name: MATLAB CI + +on: + push: + branches: + - master + - main + pull_request: + +jobs: + test-getwaveforms-mergepoints: + runs-on: ubuntu-latest + env: + MLM_LICENSE_TOKEN: ${{ secrets.MLM_LICENSE_TOKEN }} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up MATLAB + uses: matlab-actions/setup-matlab@v2 + with: + products: > + Signal_Processing_Toolbox + Curve_Fitting_Toolbox + Statistics_and_Machine_Learning_Toolbox + + - name: Run getWaveformsFromDat mergepoints test + uses: matlab-actions/run-command@v2 + with: + command: addpath(genpath(pwd)); results = runtests('tests/test_getWaveformsFromDat_mergepoints.m'); assertSuccess(results); diff --git a/calc_CellMetrics/getWaveformsFromDat.m b/calc_CellMetrics/getWaveformsFromDat.m index 6d654b82..0a79f51b 100644 --- a/calc_CellMetrics/getWaveformsFromDat.m +++ b/calc_CellMetrics/getWaveformsFromDat.m @@ -119,14 +119,7 @@ window_interval = wfWin-ceil(wfWinKeep*sr):wfWin-1+ceil(wfWinKeep*sr); % +- 0.8 ms of waveform window_interval2 = wfWin-ceil(1.5*wfWinKeep*sr):wfWin-1+ceil(1.5*wfWinKeep*sr); % +- 1.20 ms of waveform t1 = toc(timerVal); -if ~exist(datFile,'file') - error(['Binary file missing: ', datFile]) -end -s = dir(datFile); - -duration = s.bytes/(2*nChannels*sr); -rawData = memmapfile(datFile,'Format',precision,'writable',false); -% DATA = rawData.Data; +[waveformSource,duration,waveformSourceLabel] = initializeWaveformSource(datFile,basepath,basename,nChannels,sr,precision); % Fit exponential g = fittype('a*exp(-x/b)+c','dependent',{'y'},'independent',{'x'},'coefficients',{'a','b','c'}); @@ -167,10 +160,27 @@ % end % Pulls the waveforms from all channels from the dat - startIndicies2 = (spkTmp - wfWin)*nChannels+1; - stopIndicies2 = (spkTmp + wfWin)*nChannels; - X2 = cumsum(accumarray(cumsum([1;stopIndicies2(:)-startIndicies2(:)+1]),[startIndicies2(:);0]-[0;stopIndicies2(:)]-1)+1); - wf = LSB * permute(reshape(double(rawData.Data(X2(1:end-1))),nChannels,(wfWin*2),[]),[2,3,1]); + [wf, spkTmp] = extractWaveformsFromSource(spkTmp,wfWin,nChannels,LSB,waveformSource); + if isempty(spkTmp) + warning('No spikes remained for waveform extraction for unit %d after applying file-boundary constraints.',ii) + spikes.rawWaveform{ii} = nan(1,length(window_interval)); + spikes.rawWaveform_std{ii} = nan(1,length(window_interval)); + spikes.filtWaveform{ii} = nan(1,length(window_interval)); + spikes.filtWaveform_std{ii} = nan(1,length(window_interval)); + spikes.rawWaveform_all{ii} = nan(nChannels,length(window_interval2)); + spikes.filtWaveform_all{ii} = nan(nChannels,length(window_interval2)); + spikes.timeWaveform{ii} = ([-ceil(wfWinKeep*sr)*(1/sr):1/sr:(ceil(wfWinKeep*sr)-1)*(1/sr)])*1000; + spikes.timeWaveform_all{ii} = ([-ceil(1.5*wfWinKeep*sr)*(1/sr):1/sr:(ceil(1.5*wfWinKeep*sr)-1)*(1/sr)])*1000; + spikes.peakVoltage(ii) = nan; + spikes.channels_all{ii} = 1:nChannels; + spikes.peakVoltage_sorted{ii} = nan(1,nChannels); + spikes.maxWaveform_all{ii} = nan(1,nChannels); + spikes.maxWaveformCh1(ii) = nan; + spikes.maxWaveformCh(ii) = nan; + spikes.shankID(ii) = nan; + spikes.peakVoltage_expFitLengthConstant(ii) = nan; + continue + end wfF = zeros((wfWin * 2),length(spkTmp),nChannels); for jjj = 1 : nChannels wfF(:,:,jjj) = filtfilt(b1, a1, wf(:,:,jjj)); @@ -287,21 +297,19 @@ disp('Canceling waveform extraction...') clear wf wfF wf2 wfF2 clear rawWaveform rawWaveform_std filtWaveform filtWaveform_std - clear rawData error('Waveform extraction canceled by user by closing figure window.') end end clear wf wfF wf2 wfF2 end -spikes.processinginfo.params.WaveformsSource = 'dat file'; +spikes.processinginfo.params.WaveformsSource = waveformSourceLabel; spikes.processinginfo.params.WaveformsFiltFreq = params.filtFreq; spikes.processinginfo.params.Waveforms_nPull = params.nPull; spikes.processinginfo.params.WaveformsWin_sec = wfWin_sec; spikes.processinginfo.params.WaveformsWinKeep = wfWinKeep; spikes.processinginfo.params.WaveformsFilterType = 'butter'; clear rawWaveform rawWaveform_std filtWaveform filtWaveform_std -clear rawData % Plots if params.showWaveforms && ishandle(fig1) @@ -315,4 +323,129 @@ end end disp(['Waveform extraction complete. Total duration: ' num2str(round(toc(timerVal)/60)),' minutes']) -end \ No newline at end of file +end + +function [waveformSource,duration,waveformSourceLabel] = initializeWaveformSource(datFile,basepath,basename,nChannels,sr,precision) +sampleBytes = getPrecisionBytes(precision); + +if exist(datFile,'file') + s = dir(datFile); + waveformSource.mode = 'single'; + waveformSource.datFile = datFile; + waveformSource.precision = precision; + duration = s.bytes/(sampleBytes*nChannels*sr); + waveformSourceLabel = 'dat file'; + return +end + +mergePointsFile = fullfile(basepath,[basename,'.MergePoints.events.mat']); +if ~exist(mergePointsFile,'file') + error(['Binary file missing: ', datFile, newline, 'MergePoints file missing: ', mergePointsFile]) +end + +mergeData = load(mergePointsFile,'MergePoints'); +if ~isfield(mergeData,'MergePoints') + error('MergePoints file is missing the MergePoints struct: %s',mergePointsFile) +end + +MergePoints = mergeData.MergePoints; +if ~isfield(MergePoints,'timestamps_samples') || ~isfield(MergePoints,'foldernames') + error('MergePoints file is missing timestamps_samples or foldernames: %s',mergePointsFile) +end + +starts = double(MergePoints.timestamps_samples(:,1)); +stops = double(MergePoints.timestamps_samples(:,2)); +foldernames = MergePoints.foldernames; +if isstring(foldernames) + foldernames = cellstr(foldernames); +end + +if numel(foldernames) ~= numel(starts) + error('Mismatch between MergePoints foldernames and timestamps_samples in %s',mergePointsFile) +end + +segments = repmat(struct('foldername','','datFile','','startSample',0,'endSample',0,'nSamples',0),1,numel(foldernames)); +for i = 1:numel(foldernames) + foldername = foldernames{i}; + datPath = fullfile(basepath,foldername,'amplifier.dat'); + if ~exist(datPath,'file') + error('Expected amplifier.dat for MergePoints segment is missing: %s',datPath) + end + fileInfo = dir(datPath); + nSamples = fileInfo.bytes/(sampleBytes*nChannels); + expectedSamples = stops(i) - starts(i); + if nSamples ~= expectedSamples + error('Sample count mismatch for %s (MergePoints=%d, amplifier.dat=%d).',datPath,expectedSamples,nSamples) + end + segments(i).foldername = foldername; + segments(i).datFile = datPath; + segments(i).startSample = starts(i); + segments(i).endSample = stops(i); + segments(i).nSamples = nSamples; +end + +waveformSource.mode = 'mergepoints'; +waveformSource.segments = segments; +waveformSource.precision = precision; +duration = stops(end)/sr; +waveformSourceLabel = 'MergePoints amplifier.dat files'; +end + +function [wf, spkTmp] = extractWaveformsFromSource(spkTmp,wfWin,nChannels,LSB,waveformSource) +switch waveformSource.mode + case 'single' + wf = readWaveformsFromFile(waveformSource.datFile,spkTmp,wfWin,nChannels,LSB,waveformSource.precision); + case 'mergepoints' + [spkTmp,segmentIds] = filterSpikesForSegments(spkTmp,wfWin,waveformSource.segments); + if isempty(spkTmp) + wf = []; + return + end + wf = zeros(wfWin*2,length(spkTmp),nChannels); + uniqueSegments = unique(segmentIds); + for iSegment = uniqueSegments(:)' + idx = find(segmentIds == iSegment); + localSpikes = spkTmp(idx) - waveformSource.segments(iSegment).startSample; + wf(:,idx,:) = readWaveformsFromFile(waveformSource.segments(iSegment).datFile,localSpikes,wfWin,nChannels,LSB,waveformSource.precision); + end + otherwise + error('Unknown waveform source mode: %s',waveformSource.mode) +end +end + +function [spkTmpValid,segmentIds] = filterSpikesForSegments(spkTmp,wfWin,segments) +spkTmpValid = []; +segmentIds = []; +for i = 1:numel(segments) + inSegment = spkTmp > (segments(i).startSample + wfWin) & spkTmp <= (segments(i).endSample - wfWin); + if any(inSegment) + spkTmpValid = [spkTmpValid; spkTmp(inSegment)]; %#ok + segmentIds = [segmentIds; repmat(i,sum(inSegment),1)]; %#ok + end +end +[spkTmpValid,order] = sort(spkTmpValid); +segmentIds = segmentIds(order); +end + +function wf = readWaveformsFromFile(datFile,spkTmp,wfWin,nChannels,LSB,precision) +rawData = memmapfile(datFile,'Format',precision,'writable',false); +startIndicies = (spkTmp - wfWin)*nChannels+1; +stopIndicies = (spkTmp + wfWin)*nChannels; +X = cumsum(accumarray(cumsum([1;stopIndicies(:)-startIndicies(:)+1]),[startIndicies(:);0]-[0;stopIndicies(:)]-1)+1); +wf = LSB * permute(reshape(double(rawData.Data(X(1:end-1))),nChannels,(wfWin*2),[]),[2,3,1]); +end + +function sampleBytes = getPrecisionBytes(precision) +switch lower(precision) + case {'int16','uint16'} + sampleBytes = 2; + case {'int32','uint32','single','float32'} + sampleBytes = 4; + case {'int64','uint64','double','float64'} + sampleBytes = 8; + case {'int8','uint8','char'} + sampleBytes = 1; + otherwise + error('Unsupported extracellular precision: %s',precision) +end +end diff --git a/tests/fixtures/.gitkeep b/tests/fixtures/.gitkeep new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/fixtures/.gitkeep @@ -0,0 +1 @@ + diff --git a/tests/fixtures/getWaveformsFromDat_mergepoints_templates.mat b/tests/fixtures/getWaveformsFromDat_mergepoints_templates.mat new file mode 100644 index 00000000..1075dd15 Binary files /dev/null and b/tests/fixtures/getWaveformsFromDat_mergepoints_templates.mat differ diff --git a/tests/helpers/extract_getWaveformsFromDat_templates.m b/tests/helpers/extract_getWaveformsFromDat_templates.m new file mode 100644 index 00000000..46e45358 --- /dev/null +++ b/tests/helpers/extract_getWaveformsFromDat_templates.m @@ -0,0 +1,158 @@ +function fixture = extract_getWaveformsFromDat_templates(varargin) +% Extract a few real Kilosort templates and save a compact fixture .mat file. +% +% This is intended as a one-time helper for building CI fixtures from a real +% recording. The generated .mat file is small and self-contained, so CI does +% not need access to the original R: drive data. +% +% Example: +% fixture = extract_getWaveformsFromDat_templates +% +% fixture = extract_getWaveformsFromDat_templates( ... +% 'basepath', 'R:\ys2375\Test data\test_rec_260509', ... +% 'phyFolder', 'Kilosort_2026-05-09_192146', ... +% 'outputFile', fullfile('tests','fixtures', ... +% 'getWaveformsFromDat_mergepoints_templates.mat')); + +p = inputParser; +addParameter(p, 'basepath', 'R:\ys2375\Test data\test_rec_260509', @ischar); +addParameter(p, 'phyFolder', 'Kilosort_2026-05-09_192146', @ischar); +addParameter(p, 'outputFile', fullfile('tests', 'fixtures', 'getWaveformsFromDat_mergepoints_templates.mat'), @ischar); +addParameter(p, 'nTemplates', 3, @(x) isnumeric(x) && isscalar(x) && x >= 1); +addParameter(p, 'nChannelsToKeep', 4, @(x) isnumeric(x) && isscalar(x) && x >= 1); +addParameter(p, 'cropSamples', 21, @(x) isnumeric(x) && isscalar(x) && x >= 5); +addParameter(p, 'targetPeakUv', 120, @(x) isnumeric(x) && isscalar(x) && x > 0); +parse(p, varargin{:}); + +repoRoot = fileparts(fileparts(fileparts(mfilename('fullpath')))); +addpath(genpath(repoRoot)); + +basepath = p.Results.basepath; +basename = basenameFromBasepath(basepath); +phyPath = fullfile(basepath, p.Results.phyFolder); +outputFile = p.Results.outputFile; +nTemplates = p.Results.nTemplates; +nChannelsToKeep = p.Results.nChannelsToKeep; +cropSamples = p.Results.cropSamples; +targetPeakUv = p.Results.targetPeakUv; + +assert(isfolder(basepath), 'Basepath not found: %s', basepath); +assert(isfolder(phyPath), 'Phy folder not found: %s', phyPath); + +sessionData = load(fullfile(basepath, [basename, '.session.mat']), 'session'); +session = sessionData.session; +clusterInfo = readtable(fullfile(phyPath, 'cluster_info.tsv'), 'FileType', 'text', 'Delimiter', '\t'); +templates = readNPY(fullfile(phyPath, 'templates.npy')); +spikeTemplates = double(readNPY(fullfile(phyPath, 'spike_templates.npy'))); +spikeClusters = double(readNPY(fullfile(phyPath, 'spike_clusters.npy'))); + +assert(ndims(templates) == 3, 'Expected templates.npy to be 3D.'); +assert(any(strcmp(clusterInfo.Properties.VariableNames, 'cluster_id')), ... + 'cluster_info.tsv must contain cluster_id.'); + +goodMask = true(height(clusterInfo), 1); +if any(strcmp(clusterInfo.Properties.VariableNames, 'group')) + goodMask = strcmp(string(clusterInfo.group), "good"); +elseif any(strcmp(clusterInfo.Properties.VariableNames, 'KSLabel')) + goodMask = strcmp(string(clusterInfo.KSLabel), "good"); +end + +goodClusters = clusterInfo(goodMask, :); +assert(~isempty(goodClusters), 'No good clusters found in %s', phyPath); + +if any(strcmp(goodClusters.Properties.VariableNames, 'amp')) + [~, order] = sort(goodClusters.amp, 'descend'); +else + order = 1:height(goodClusters); +end +goodClusters = goodClusters(order, :); + +selected = struct( ... + 'clusterId', {}, ... + 'templateId', {}, ... + 'globalChannels', {}, ... + 'localPeakCh1', {}, ... + 'waveformUv', {}, ... + 'waveformRawInt16', {}); + +usedChannelSets = {}; +for i = 1:height(goodClusters) + clusterId = double(goodClusters.cluster_id(i)); + templateId = choose_template_for_cluster(clusterId, spikeClusters, spikeTemplates); + templateWaveform = squeeze(double(templates(templateId + 1, :, :))); % Kilosort ids are 0-indexed + + [croppedUv, globalChannels, localPeakCh1] = crop_template(templateWaveform, nChannelsToKeep, cropSamples, targetPeakUv); + channelSignature = sprintf('%d_', globalChannels); + + if any(strcmp(usedChannelSets, channelSignature)) + continue + end + + selected(end + 1).clusterId = clusterId; %#ok + selected(end).templateId = templateId; + selected(end).globalChannels = globalChannels; + selected(end).localPeakCh1 = localPeakCh1; + selected(end).waveformUv = croppedUv; + selected(end).waveformRawInt16 = int16(round(croppedUv / session.extracellular.leastSignificantBit)); + usedChannelSets{end + 1} = channelSignature; %#ok + + if numel(selected) >= nTemplates + break + end +end + +assert(~isempty(selected), 'No templates were extracted.'); + +fixture = struct(); +fixture.sourceBasepath = basepath; +fixture.sourcePhyFolder = phyPath; +fixture.basename = basename; +fixture.sr = session.extracellular.sr; +fixture.LSB = session.extracellular.leastSignificantBit; +fixture.precision = 'int16'; +fixture.cropSamples = cropSamples; +fixture.nChannelsToKeep = nChannelsToKeep; +fixture.templates = selected; + +outputDir = fileparts(outputFile); +if ~isempty(outputDir) && ~isfolder(outputDir) + mkdir(outputDir); +end +save(outputFile, 'fixture'); + +fprintf('Saved %d templates to %s\n', numel(selected), outputFile); +for i = 1:numel(selected) + fprintf(' Template %d: cluster %d, template %d, channels %s, local peak ch %d\n', ... + i, selected(i).clusterId, selected(i).templateId, mat2str(selected(i).globalChannels), selected(i).localPeakCh1); +end + +end + +function templateId = choose_template_for_cluster(clusterId, spikeClusters, spikeTemplates) +clusterSpikeIdx = spikeClusters == clusterId; +assert(any(clusterSpikeIdx), 'Cluster %d has no spikes in spike_clusters.npy.', clusterId); +templateId = mode(spikeTemplates(clusterSpikeIdx)); +end + +function [croppedUv, channelsKept, localPeakCh1] = crop_template(templateWaveform, nChannelsToKeep, cropSamples, targetPeakUv) +% templateWaveform is [time x channels] +channelSpread = range(templateWaveform, 1); +[~, channelOrder] = sort(channelSpread, 'descend'); +channelsKept = sort(channelOrder(1:nChannelsToKeep)); + +waveformSubset = templateWaveform(:, channelsKept); +[~, localPeakCh1] = max(range(waveformSubset, 1)); +[~, peakSample] = min(waveformSubset(:, localPeakCh1)); + +halfWindow = floor(cropSamples / 2); +startSample = max(1, peakSample - halfWindow); +stopSample = min(size(waveformSubset, 1), startSample + cropSamples - 1); +startSample = max(1, stopSample - cropSamples + 1); +croppedUv = waveformSubset(startSample:stopSample, :); + +% Normalize to a predictable amplitude while preserving the real shape. +peakAbs = max(abs(croppedUv(:, localPeakCh1))); +if peakAbs > 0 + croppedUv = croppedUv * (targetPeakUv / peakAbs); +end +end diff --git a/tests/helpers/generate_getWaveformsFromDat_fixture.m b/tests/helpers/generate_getWaveformsFromDat_fixture.m new file mode 100644 index 00000000..1df20957 --- /dev/null +++ b/tests/helpers/generate_getWaveformsFromDat_fixture.m @@ -0,0 +1,130 @@ +function fixtureData = generate_getWaveformsFromDat_fixture(varargin) +% Generate a compact synthetic dataset for getWaveformsFromDat CI tests. +% +% The synthetic data uses real templates previously extracted from a real +% recording, but writes a tiny self-contained session with: +% basename.dat +% subfolder_01/amplifier.dat +% subfolder_02/amplifier.dat +% basename.MergePoints.events.mat +% +% Example: +% fixtureData = generate_getWaveformsFromDat_fixture + +p = inputParser; +addParameter(p, 'outputRoot', tempname, @ischar); +addParameter(p, 'basename', 'test_mergepoints_fixture', @ischar); +addParameter(p, 'fixtureFile', fullfile('tests', 'fixtures', 'getWaveformsFromDat_mergepoints_templates.mat'), @ischar); +addParameter(p, 'segmentDurationSec', 5, @(x) isnumeric(x) && isscalar(x) && x > 0); +addParameter(p, 'nChannels', 4, @(x) isnumeric(x) && isscalar(x) && x == 4); +addParameter(p, 'samplingRate', 20000, @(x) isnumeric(x) && isscalar(x) && x > 0); +parse(p, varargin{:}); + +repoRoot = fileparts(fileparts(fileparts(mfilename('fullpath')))); +addpath(genpath(repoRoot)); + +fixtureStruct = load(p.Results.fixtureFile, 'fixture'); +fixture = fixtureStruct.fixture; + +outputRoot = p.Results.outputRoot; +basename = p.Results.basename; +segmentDurationSec = p.Results.segmentDurationSec; +nChannels = p.Results.nChannels; +sr = p.Results.samplingRate; + +segmentSamples = round(segmentDurationSec * sr); +totalSamples = segmentSamples * 2; + +mkdir(outputRoot); +mkdir(fullfile(outputRoot, 'subfolder_01')); +mkdir(fullfile(outputRoot, 'subfolder_02')); + +spikePlan = build_spike_plan(); +rawData = zeros(totalSamples, nChannels, 'int32'); + +for iUnit = 1:numel(spikePlan) + template = fixture.templates(iUnit).waveformRawInt16; + centerSample = ceil(size(template, 1) / 2); + for iSpike = 1:numel(spikePlan(iUnit).timesSec) + spikeSample = round(spikePlan(iUnit).timesSec(iSpike) * sr); + rowIdx = spikeSample - centerSample + 1 : spikeSample - centerSample + size(template, 1); + rawData(rowIdx, :) = rawData(rowIdx, :) + int32(template); + end +end + +rawData = int16(rawData); +segment1 = rawData(1:segmentSamples, :); +segment2 = rawData(segmentSamples + 1:end, :); + +write_binary(fullfile(outputRoot, [basename, '.dat']), rawData); +write_binary(fullfile(outputRoot, 'subfolder_01', 'amplifier.dat'), segment1); +write_binary(fullfile(outputRoot, 'subfolder_02', 'amplifier.dat'), segment2); + +MergePoints = struct(); +MergePoints.timestamps_samples = [ + 0, segmentSamples; + segmentSamples, totalSamples +]; +MergePoints.foldernames = {'subfolder_01', 'subfolder_02'}; +save(fullfile(outputRoot, [basename, '.MergePoints.events.mat']), 'MergePoints'); + +session = struct(); +session.general.name = basename; +session.general.basePath = outputRoot; +session.extracellular.leastSignificantBit = fixture.LSB; +session.extracellular.nChannels = nChannels; +session.extracellular.sr = sr; +session.extracellular.precision = fixture.precision; +session.extracellular.fileName = [basename, '.dat']; +session.extracellular.nElectrodeGroups = 1; +session.extracellular.electrodeGroups.channels = {1:nChannels}; +session.extracellular.spikeGroups.channels = {1:nChannels}; +save(fullfile(outputRoot, [basename, '.session.mat']), 'session'); + +spikes = struct(); +spikes.basename = basename; +spikes.sr = sr; +spikes.UID = 1:numel(spikePlan); +spikes.cluID = 101:100 + numel(spikePlan); +spikes.times = cell(1, numel(spikePlan)); +spikes.ts = cell(1, numel(spikePlan)); +spikes.total = zeros(1, numel(spikePlan)); +for iUnit = 1:numel(spikePlan) + spikes.times{iUnit} = spikePlan(iUnit).timesSec; + spikes.ts{iUnit} = round(spikePlan(iUnit).timesSec * sr); + spikes.total(iUnit) = numel(spikePlan(iUnit).timesSec); +end + +fixtureData = struct(); +fixtureData.basepath = outputRoot; +fixtureData.basename = basename; +fixtureData.session = session; +fixtureData.spikes = spikes; +fixtureData.segmentSamples = segmentSamples; +fixtureData.safeIntervalsSec = [ + 40 / sr, segmentDurationSec - 40 / sr; + segmentDurationSec + 40 / sr, 2 * segmentDurationSec - 40 / sr +]; +fixtureData.boundarySpikeSec = spikePlan(3).timesSec(end-1:end); +fixtureData.expectedSafeTimes = { + spikePlan(1).timesSec, ... + spikePlan(2).timesSec, ... + spikePlan(3).timesSec(1:end-2) +}; +fixtureData.expectedPeakChannels = [fixture.templates(1).localPeakCh1, fixture.templates(2).localPeakCh1, fixture.templates(3).localPeakCh1]; +end + +function spikePlan = build_spike_plan() +spikePlan = struct('timesSec', {}); +spikePlan(1).timesSec = [1.0, 2.0, 3.5, 4.0]; +spikePlan(2).timesSec = [6.0, 7.0, 8.0, 9.0]; +spikePlan(3).timesSec = [2.5, 4.5, 5.5, 7.5, 4.9990, 5.0010]; +end + +function write_binary(filename, data) +fid = fopen(filename, 'w'); +assert(fid > 0, 'Failed to open %s for writing.', filename); +cleanupObj = onCleanup(@() fclose(fid)); %#ok +count = fwrite(fid, data', 'int16'); +assert(count == numel(data), 'Failed to write expected data count to %s.', filename); +end diff --git a/tests/test_getWaveformsFromDat_mergepoints.m b/tests/test_getWaveformsFromDat_mergepoints.m new file mode 100644 index 00000000..51dc18a5 --- /dev/null +++ b/tests/test_getWaveformsFromDat_mergepoints.m @@ -0,0 +1,79 @@ +function tests = test_getWaveformsFromDat_mergepoints +tests = functiontests(localfunctions); +end + +function testMergePointsFallbackMatchesSafeRegion(testCase) +repoRoot = fileparts(fileparts(mfilename('fullpath'))); +addpath(genpath(repoRoot)); + +fixtureData = generate_getWaveformsFromDat_fixture('outputRoot', tempname); +cleanupObj = onCleanup(@() cleanup_fixture(fixtureData.basepath)); %#ok + +directSpikes = run_waveform_extraction(fixtureData.spikes, fixtureData.session); + +datFile = fullfile(fixtureData.basepath, [fixtureData.basename, '.dat']); +backupDatFile = [datFile, '.bak']; +movefile(datFile, backupDatFile); +restoreObj = onCleanup(@() restore_dat(datFile, backupDatFile)); %#ok + +fallbackSpikes = run_waveform_extraction(fixtureData.spikes, fixtureData.session); + +verifyEqual(testCase, directSpikes.processinginfo.params.WaveformsSource, 'dat file'); +verifyEqual(testCase, fallbackSpikes.processinginfo.params.WaveformsSource, 'MergePoints amplifier.dat files'); +verifyEqual(testCase, directSpikes.cluID, fallbackSpikes.cluID); + +for iUnit = 1:numel(fixtureData.spikes.times) + verifyEqual(testCase, directSpikes.maxWaveformCh1(iUnit), fixtureData.expectedPeakChannels(iUnit)); + verifyEqual(testCase, fallbackSpikes.maxWaveformCh1(iUnit), fixtureData.expectedPeakChannels(iUnit)); + + directSafe = restrict_times_to_intervals(directSpikes.waveforms.times{iUnit}, fixtureData.safeIntervalsSec); + fallbackSafe = restrict_times_to_intervals(fallbackSpikes.waveforms.times{iUnit}, fixtureData.safeIntervalsSec); + + verifyEqual(testCase, directSafe(:), fixtureData.expectedSafeTimes{iUnit}(:), 'AbsTol', 1e-12); + verifyEqual(testCase, fallbackSafe(:), fixtureData.expectedSafeTimes{iUnit}(:), 'AbsTol', 1e-12); + verifyEqual(testCase, directSafe(:), fallbackSafe(:), 'AbsTol', 1e-12); + + % Units 1-2 contain only safe spikes, so their mean waveforms should + % match exactly between the direct and fallback paths. + if iUnit <= 2 + verifyEqual(testCase, directSpikes.rawWaveform{iUnit}, fallbackSpikes.rawWaveform{iUnit}, 'AbsTol', 1e-9); + verifyEqual(testCase, directSpikes.filtWaveform{iUnit}, fallbackSpikes.filtWaveform{iUnit}, 'AbsTol', 1e-9); + end +end + +verifyEqual(testCase, fallbackSpikes.waveforms.times{3}(:), fixtureData.expectedSafeTimes{3}(:), 'AbsTol', 1e-12); +verifyTrue(testCase, any(abs(directSpikes.waveforms.times{3} - fixtureData.boundarySpikeSec(1)) < 1e-12)); +verifyTrue(testCase, any(abs(directSpikes.waveforms.times{3} - fixtureData.boundarySpikeSec(2)) < 1e-12)); +verifyFalse(testCase, any(abs(fallbackSpikes.waveforms.times{3} - fixtureData.boundarySpikeSec(1)) < 1e-12)); +verifyFalse(testCase, any(abs(fallbackSpikes.waveforms.times{3} - fixtureData.boundarySpikeSec(2)) < 1e-12)); +end + +function spikesOut = run_waveform_extraction(spikesIn, session) +rng(1); +spikesOut = getWaveformsFromDat( ... + spikesIn, session, ... + 'showWaveforms', false, ... + 'saveMat', false, ... + 'keepWaveforms_raw', true, ... + 'nPull', 1000000); +end + +function restricted = restrict_times_to_intervals(timesSec, intervalsSec) +keep = false(size(timesSec)); +for i = 1:size(intervalsSec, 1) + keep = keep | (timesSec > intervalsSec(i,1) & timesSec <= intervalsSec(i,2)); +end +restricted = timesSec(keep); +end + +function restore_dat(datFile, backupDatFile) +if exist(backupDatFile, 'file') == 2 && exist(datFile, 'file') ~= 2 + movefile(backupDatFile, datFile); +end +end + +function cleanup_fixture(basepath) +if isfolder(basepath) + rmdir(basepath, 's'); +end +end