Skip to content

Commit ded2cb1

Browse files
Merge pull request #19 from Waltham-Data-Science/claude/apply-python-porting-guide-lVXML
Add pydantic validation decorators to epoch and ontology modules
2 parents 82aa2c8 + 7dc970c commit ded2cb1

8 files changed

Lines changed: 52 additions & 15 deletions

File tree

src/ndi/epoch/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
"""
1212

1313
from .epoch import Epoch
14-
from .epochprobemap import EpochProbeMap
14+
from .epochprobemap import EpochProbeMap, build_devicestring, parse_devicestring
1515
from .epochprobemap_daqsystem import EpochProbeMapDAQSystem
1616
from .epochset import EpochSet
1717
from .functions import epochrange, findepochnode
@@ -21,6 +21,8 @@
2121
"EpochSet",
2222
"EpochProbeMap",
2323
"EpochProbeMapDAQSystem",
24+
"build_devicestring",
2425
"epochrange",
2526
"findepochnode",
27+
"parse_devicestring",
2628
]

src/ndi/epoch/epoch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from dataclasses import dataclass, field
1111
from typing import TYPE_CHECKING, Any
1212

13+
import pydantic
14+
from pydantic import ConfigDict
15+
1316
if TYPE_CHECKING:
1417
from ..time import ClockType
1518
from .epochprobemap import EpochProbeMap
@@ -210,6 +213,7 @@ def matches_probe(
210213
return False
211214

212215

216+
@pydantic.validate_call(config=ConfigDict(arbitrary_types_allowed=True))
213217
def is_epoch_or_empty(value: Any) -> bool:
214218
"""
215219
Validate that a value is an Epoch or empty.

src/ndi/epoch/epochprobemap.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from dataclasses import dataclass
1212
from typing import Any
1313

14+
import pydantic
15+
1416

1517
@dataclass
1618
class EpochProbeMap:
@@ -141,6 +143,7 @@ def __hash__(self) -> int:
141143
return hash((self.name, self.reference, self.type))
142144

143145

146+
@pydantic.validate_call
144147
def parse_devicestring(devicestring: str) -> dict[str, str]:
145148
"""
146149
Parse a device string into components.
@@ -161,6 +164,7 @@ def parse_devicestring(devicestring: str) -> dict[str, str]:
161164
}
162165

163166

167+
@pydantic.validate_call
164168
def build_devicestring(
165169
name: str,
166170
deviceclass: str = "",

src/ndi/epoch/epochprobemap_daqsystem.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from pathlib import Path
1414
from typing import Any
1515

16+
import pydantic
17+
1618
from ..daq.daqsystemstring import DAQSystemString
1719
from .epochprobemap import EpochProbeMap
1820

@@ -84,6 +86,7 @@ def serialize(self) -> str:
8486
]
8587
)
8688

89+
@pydantic.validate_call
8790
def savetofile(self, filename: str) -> None:
8891
"""
8992
Write this epoch probe map to a file.

src/ndi/epoch/epochset.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99

1010
import hashlib
1111
from abc import ABC, abstractmethod
12-
from typing import TYPE_CHECKING, Any
12+
from typing import Annotated, Any
1313

1414
import numpy as np
15+
import pydantic
16+
from pydantic import Field
1517

16-
if TYPE_CHECKING:
17-
from ..time import ClockType
18+
from ..time import ClockType
1819

1920

2021
class EpochSet(ABC):
@@ -149,7 +150,8 @@ def numepochs(self) -> int:
149150
et, _ = self.epochtable()
150151
return len(et)
151152

152-
def epochclock(self, epoch_number: int) -> list[ClockType]:
153+
@pydantic.validate_call
154+
def epochclock(self, epoch_number: Annotated[int, Field(ge=1)]) -> list[ClockType]:
153155
"""
154156
Get clock types for an epoch.
155157
@@ -163,13 +165,14 @@ def epochclock(self, epoch_number: int) -> list[ClockType]:
163165
IndexError: If epoch_number is out of range
164166
"""
165167
et, _ = self.epochtable()
166-
if epoch_number < 1 or epoch_number > len(et):
168+
if epoch_number > len(et):
167169
raise IndexError(f"Epoch {epoch_number} out of range (1..{len(et)})")
168170

169171
entry = et[epoch_number - 1]
170172
return entry.get("epoch_clock", [])
171173

172-
def t0_t1(self, epoch_number: int) -> list[tuple[float, float]]:
174+
@pydantic.validate_call
175+
def t0_t1(self, epoch_number: Annotated[int, Field(ge=1)]) -> list[tuple[float, float]]:
173176
"""
174177
Get time range for an epoch.
175178
@@ -183,13 +186,14 @@ def t0_t1(self, epoch_number: int) -> list[tuple[float, float]]:
183186
IndexError: If epoch_number is out of range
184187
"""
185188
et, _ = self.epochtable()
186-
if epoch_number < 1 or epoch_number > len(et):
189+
if epoch_number > len(et):
187190
raise IndexError(f"Epoch {epoch_number} out of range (1..{len(et)})")
188191

189192
entry = et[epoch_number - 1]
190193
return entry.get("t0_t1", [(np.nan, np.nan)])
191194

192-
def epochid(self, epoch_number: int) -> str:
195+
@pydantic.validate_call
196+
def epochid(self, epoch_number: Annotated[int, Field(ge=1)]) -> str:
193197
"""
194198
Get epoch ID for an epoch number.
195199
@@ -203,11 +207,12 @@ def epochid(self, epoch_number: int) -> str:
203207
IndexError: If epoch_number is out of range
204208
"""
205209
et, _ = self.epochtable()
206-
if epoch_number < 1 or epoch_number > len(et):
210+
if epoch_number > len(et):
207211
raise IndexError(f"Epoch {epoch_number} out of range (1..{len(et)})")
208212

209213
return et[epoch_number - 1].get("epoch_id", "")
210214

215+
@pydantic.validate_call
211216
def epochnumber(self, epoch_id: str) -> int:
212217
"""
213218
Get epoch number for an epoch ID.
@@ -255,7 +260,8 @@ def matchedepochtable(
255260

256261
return matches
257262

258-
def epochtableentry(self, epoch_number: int) -> dict[str, Any]:
263+
@pydantic.validate_call
264+
def epochtableentry(self, epoch_number: Annotated[int, Field(ge=1)]) -> dict[str, Any]:
259265
"""
260266
Get a single epoch table entry.
261267
@@ -269,7 +275,7 @@ def epochtableentry(self, epoch_number: int) -> dict[str, Any]:
269275
IndexError: If epoch_number is out of range
270276
"""
271277
et, _ = self.epochtable()
272-
if epoch_number < 1 or epoch_number > len(et):
278+
if epoch_number > len(et):
273279
raise IndexError(f"Epoch {epoch_number} out of range (1..{len(et)})")
274280

275281
return et[epoch_number - 1]

src/ndi/epoch/functions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,13 @@
1010

1111
from typing import Any
1212

13+
import pydantic
14+
from pydantic import ConfigDict
15+
1316
from ..time import ClockType
1417

1518

19+
@pydantic.validate_call(config=ConfigDict(arbitrary_types_allowed=True))
1620
def epochrange(
1721
epochset_obj: Any,
1822
clocktype: ClockType,
@@ -118,6 +122,7 @@ def _resolve_epoch_index(
118122
raise ValueError(f"Epoch ID '{epoch}' not found")
119123

120124

125+
@pydantic.validate_call(config=ConfigDict(arbitrary_types_allowed=True))
121126
def findepochnode(
122127
epoch_node: dict[str, Any],
123128
epoch_node_array: list[dict[str, Any]],

src/ndi/ontology/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
from __future__ import annotations
1616

1717
import json
18-
from functools import lru_cache
19-
from pathlib import Path
20-
from typing import Any, Dict, List, Optional, Tuple
18+
from typing import Any
19+
20+
import pydantic
2121

2222
from .providers import PROVIDER_REGISTRY
2323

@@ -115,6 +115,7 @@ def _load_prefix_map() -> dict[str, str]:
115115
_CACHE_MAX = 100
116116

117117

118+
@pydantic.validate_call
118119
def lookup(lookup_string: str) -> OntologyResult:
119120
"""Look up a term in the appropriate ontology.
120121

src/ndi/ontology/providers.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from typing import Any
1414
from urllib.parse import quote
1515

16+
import pydantic
17+
1618
# Registry populated at module load
1719
PROVIDER_REGISTRY: dict[str, type[OntologyProvider]] = {}
1820

@@ -22,6 +24,7 @@ class OntologyProvider:
2224

2325
name: str = ""
2426

27+
@pydantic.validate_call
2528
def lookup_term(self, term: str, prefix: str = "") -> Any:
2629
"""Look up a term by ID or name. Override in subclasses."""
2730
from . import OntologyResult
@@ -49,6 +52,7 @@ class OLSProvider(OntologyProvider):
4952
ols_ontology: str = ""
5053
ols_prefix: str = ""
5154

55+
@pydantic.validate_call
5256
def lookup_term(self, term: str, prefix: str = "") -> Any:
5357

5458
prefix = prefix or self.ols_prefix
@@ -139,6 +143,7 @@ class OMProvider(OLSProvider):
139143
ols_ontology = "om"
140144
ols_prefix = "OM"
141145

146+
@pydantic.validate_call
142147
def lookup_term(self, term: str, prefix: str = "") -> Any:
143148

144149
# OM doesn't support numeric ID lookups
@@ -208,6 +213,7 @@ def _load_data(self) -> list[dict[str, str]]:
208213
NDICProvider._data = []
209214
return []
210215

216+
@pydantic.validate_call
211217
def lookup_term(self, term: str, prefix: str = "") -> Any:
212218
from . import OntologyResult
213219

@@ -230,6 +236,7 @@ class NCImProvider(OntologyProvider):
230236
name = "NCIm"
231237
_CUI_PATTERN = re.compile(r"^C\d{7}$")
232238

239+
@pydantic.validate_call
233240
def lookup_term(self, term: str, prefix: str = "") -> Any:
234241
from . import OntologyResult
235242

@@ -277,6 +284,7 @@ class NCBITaxonProvider(OntologyProvider):
277284

278285
name = "NCBITaxon"
279286

287+
@pydantic.validate_call
280288
def lookup_term(self, term: str, prefix: str = "") -> Any:
281289
from . import OntologyResult
282290

@@ -344,6 +352,7 @@ class WBStrainProvider(OntologyProvider):
344352

345353
name = "WBStrain"
346354

355+
@pydantic.validate_call
347356
def lookup_term(self, term: str, prefix: str = "") -> Any:
348357
from . import OntologyResult
349358

@@ -392,6 +401,7 @@ class RRIDProvider(OntologyProvider):
392401

393402
name = "RRID"
394403

404+
@pydantic.validate_call
395405
def lookup_term(self, term: str, prefix: str = "") -> Any:
396406
from . import OntologyResult
397407

@@ -423,6 +433,7 @@ class PubChemProvider(OntologyProvider):
423433

424434
name = "PubChem"
425435

436+
@pydantic.validate_call
426437
def lookup_term(self, term: str, prefix: str = "") -> Any:
427438
from . import OntologyResult
428439

@@ -541,6 +552,7 @@ def _load_ontology(self) -> dict[str, Any]:
541552
}
542553
return EMPTYProvider._cache
543554

555+
@pydantic.validate_call
544556
def lookup_term(self, term: str, prefix: str = "") -> Any:
545557
from ndi.fun.name_utils import name_to_variable_name
546558

0 commit comments

Comments
 (0)