Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/ndi/fun/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

MATLAB equivalent: +ndi/+fun/

Provides document, epoch, file, data, stimulus, session, and dataset utilities.
Provides document, epoch, file, data, stimulus, session, dataset,
and probe utilities.
"""

from __future__ import annotations

from . import probe # noqa: F401 — make ndi.fun.probe accessible
from .utils import (
channelname2prefixnumber,
name2variable_name,
Expand All @@ -18,6 +20,7 @@
__all__ = [
"channelname2prefixnumber",
"name2variable_name",
"probe",
"pseudorandomint",
"timestamp",
]
19 changes: 19 additions & 0 deletions src/ndi/fun/probe/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""
ndi.fun.probe - Probe utility functions.

MATLAB equivalent: +ndi/+fun/+probe/

Provides utility functions for exporting probe data and finding
probe location documents.
"""

from __future__ import annotations

from .export_binary import export_all_binary, export_binary
from .location import location

__all__ = [
"export_all_binary",
"export_binary",
"location",
]
171 changes: 171 additions & 0 deletions src/ndi/fun/probe/export_binary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
"""
ndi.fun.probe.export_binary - Export probe data to binary files.

MATLAB equivalents:
+ndi/+fun/+probe/export_binary.m
+ndi/+fun/+probe/export_all_binary.m
"""

from __future__ import annotations

from pathlib import Path
from typing import Any

import numpy as np


def export_binary(
probe: Any,
outputfile: str | Path,
*,
multiplier: float = 1.0,
verbose: bool = True,
precision: str = "int16",
) -> None:
"""Export data from a probe to a binary file.

MATLAB equivalent: ndi.fun.probe.export_binary

Exports data from *probe* (an :class:`ndi.element.Element` or
:class:`ndi.probe.Probe` of type ``n-trode``) to a binary file.
Before converting to the output precision the data are scaled by
*multiplier*. A text metadata file is created alongside *outputfile*
with the extension ``.metadata``.

Args:
probe: An NDI probe/element object with ``epochtable``,
``times2samples``, ``readtimeseries``, ``samplerate``, and
``elementstring`` methods.
outputfile: Path for the output binary file.
multiplier: Scaling factor applied to data before conversion.
verbose: If ``True``, print progress messages.
precision: NumPy-compatible dtype string for the output
(default ``'int16'``).
"""
outputfile = Path(outputfile)
metafile = outputfile.with_suffix(outputfile.suffix + ".metadata")

et = probe.epochtable()
if isinstance(et, tuple):
et = et[0]

dtype = np.dtype(precision)
chunk_duration = 100 # seconds

epoch_sample_counts: list[int] = []
epoch_sample_rates: list[float] = []
num_channels = 0

with open(outputfile, "wb") as fid:
for e_idx, entry in enumerate(et):
epoch_id = entry.get("epoch_id", e_idx + 1)
if verbose:
print(f"Processing epoch {e_idx + 1} of {len(et)}.")

t0_t1 = entry.get("t0_t1", [])
if isinstance(t0_t1, list) and len(t0_t1) > 0:
t0_t1_pair = t0_t1[0]
else:
t0_t1_pair = t0_t1

if isinstance(t0_t1_pair, (list, tuple, np.ndarray)) and len(t0_t1_pair) >= 2:
t_start = float(t0_t1_pair[0])
t_end = float(t0_t1_pair[1])
else:
continue

samples = probe.times2samples(epoch_id, np.array([t_start, t_end]))
sample_count = int(samples[1] - samples[0] + 1)
epoch_sample_counts.append(sample_count)

sr = probe.samplerate(epoch_id)
epoch_sample_rates.append(float(sr))
single_sample_time = 1.0 / sr if sr > 0 else 0.0

chunk_starts = np.arange(t_start, t_end, chunk_duration)
for c_idx, cs in enumerate(chunk_starts):
if verbose:
print(
f" Processing epoch {e_idx + 1}, "
f"chunk {c_idx + 1} of {len(chunk_starts)}."
)
start_time = float(cs)
end_time = min(cs + chunk_duration - single_sample_time, t_end)

data, _t, _tr = probe.readtimeseries(epoch=epoch_id, t0=start_time, t1=end_time)
if data is None or len(data) == 0:
continue

num_channels = data.shape[1] if data.ndim == 2 else 1

# Scale and convert — write channel-interleaved (transposed)
scaled = (multiplier * data).T
out = scaled.astype(dtype)
fid.write(out.tobytes())

# Write metadata file
probe_name = probe.elementstring()
with open(metafile, "w") as mf:
mf.write(f"epoch_sample_counts: {epoch_sample_counts}\n")
mf.write(f"epoch_sample_rates: {epoch_sample_rates}\n")
mf.write(f"multiplier: {multiplier}\n")
mf.write(f"num_channels: {num_channels}\n")
mf.write(f"probe_name: {probe_name}\n")


def export_all_binary(
session: Any,
*,
kilosort_dir: str = "kilosort",
verbose: bool = True,
multiplier: float = 1 / 0.195,
) -> None:
"""Export all n-trode probes in a session to binary files.

MATLAB equivalent: ndi.fun.probe.export_all_binary

Creates a *kilosort_dir* directory inside the session path. For each
probe of type ``n-trode``, a subdirectory named after the probe's
element string is created and a ``kilosort.bin`` file is written using
:func:`export_binary`.

Args:
session: An NDI session object (must have ``path`` and
``getprobes`` attributes).
kilosort_dir: Name of the output subdirectory (default
``'kilosort'``).
verbose: If ``True``, print progress messages.
multiplier: Scaling factor (default ``1/0.195``, assumes Intan
data).
"""
if verbose:
print(f"About to look for probes in {session.reference}")

probe_list = session.getprobes(type="n-trode")

if verbose:
print(f"Found {len(probe_list)} probe(s) of type 'n-trode'.")

kilosort_path = Path(session.path) / kilosort_dir
kilosort_path.mkdir(parents=True, exist_ok=True)

for probe in probe_list:
elestr = probe.elementstring()
if verbose:
print(f"Now working on probe {elestr}.")

# Replace spaces with underscores for directory name
safe_name = elestr.replace(" ", "_")
this_path = kilosort_path / safe_name
this_path.mkdir(parents=True, exist_ok=True)

outfile = this_path / "kilosort.bin"
export_binary(
probe,
outfile,
multiplier=multiplier,
verbose=verbose,
)

if verbose:
print(f"Done processing {session.reference}")
69 changes: 69 additions & 0 deletions src/ndi/fun/probe/location.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""
ndi.fun.probe.location - Find probe location documents for an element.

MATLAB equivalent: +ndi/+fun/+probe/location.m
"""

from __future__ import annotations

from typing import Any


def location(
session: Any,
element: Any | str,
) -> tuple[list[Any], Any | None]:
"""Find probe location documents and probe object for an NDI element.

MATLAB equivalent: ndi.fun.probe.location

Given an NDI element *element*, traverse down the ``underlying_element``
dependency tree until an :class:`ndi.probe.Probe` object is found, then
return all ``probe_location`` documents associated with that probe.

Args:
session: An NDI session or dataset object.
element: An :class:`ndi.element.Element` object **or** the string
identifier of an element.

Returns:
Tuple of ``(probe_locations, probe_obj)`` where *probe_locations*
is a list of probe-location documents and *probe_obj* is the
:class:`ndi.probe.Probe` found (or ``None`` if none was found).
"""
from ndi.database_fun import ndi_document2ndi_object
from ndi.probe import Probe
from ndi.query import Query

# Step 1: resolve string identifier to an element object
if isinstance(element, str):
docs = session.database_search(Query("base.id", "exact_string", element, ""))
if not docs:
raise ValueError(f"Could not find an element with id '{element}'.")
element = ndi_document2ndi_object(docs[0], session)

# Step 2: traverse down to the probe
current = element
while not isinstance(current, Probe):
underlying = getattr(current, "underlying_element", None)
if underlying is None:
break
if callable(underlying) and not isinstance(underlying, property):
underlying = underlying()
current = underlying

probe_obj: Any | None = current if isinstance(current, Probe) else None

if probe_obj is None:
return [], None

# Step 3: get probe identifier
probe_id = probe_obj.id
if callable(probe_id):
probe_id = probe_id()

# Step 4: query for probe_location documents
q = Query("", "depends_on", "probe_id", probe_id) & Query("", "isa", "probe_location")
probe_locations = session.database_search(q)

return probe_locations, probe_obj