diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 6936d3c..c201806 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -63,5 +63,5 @@ jobs: - name: Install and Test with pytest run: | export PATH="$pythonLocation:$PATH" - python -m pip install -e .[Dev,Orso] + python -m pip install -e .[dev,orso] pytest tests/ --cov=ratapi --cov-report=term diff --git a/.gitignore b/.gitignore index fde40ba..51a84ff 100644 --- a/.gitignore +++ b/.gitignore @@ -40,3 +40,6 @@ dist/* # Jupyter notebook checkpoints .ipynb_checkpoints/* + +# Lock file for uv env +uv.lock diff --git a/README.md b/README.md index e456670..46cc348 100644 --- a/README.md +++ b/README.md @@ -12,11 +12,11 @@ To install in local directory: matlabengine is an optional dependency only required for Matlab custom functions. The version of matlabengine should match the version of Matlab installed on the machine. This can be installed as shown below: - pip install -e .[Matlab-2023a] + pip install -e .[matlab-2023a] Development dependencies can be installed as shown below - pip install -e .[Dev] + pip install -e .[dev] To build wheel: diff --git a/cpp/RAT b/cpp/RAT index aae3dc1..7993771 160000 --- a/cpp/RAT +++ b/cpp/RAT @@ -1 +1 @@ -Subproject commit aae3dc141b6a10c6e10dfb47cd62e07a2a11857d +Subproject commit 7993771968fa7335528c4f14ef44393f0b607953 diff --git a/pyproject.toml b/pyproject.toml index d52f72a..54a6c8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,53 @@ requires = [ ] build-backend = 'setuptools.build_meta' +[project] +name = "ratapi" +version = "0.0.0.dev8" +description = "Python extension for the Reflectivity Analysis Toolbox (RAT)" +readme = "README.md" +requires-python = ">=3.10" +dependencies = [ + "matplotlib>=3.8.3", + "numpy>=1.20", + "prettytable>=3.9.0", + "pydantic>=2.7.2", + "scipy>=1.13.1", + "strenum>=0.4.15 ; python_full_version < '3.11'", + "tqdm>=4.66.5", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.4.0", + "pytest-cov>=4.1.0", + "ruff>=0.4.10" +] +orso = [ + "orsopy>=1.2.1", + "pint>=0.24.4" +] +matlab_latest = ["matlabengine"] +matlab_2025b = ["matlabengine == 25.2.*"] +matlab_2025a = ["matlabengine == 25.1.2"] +matlab_2024b = ["matlabengine == 24.2.2"] +matlab_2024a = ["matlabengine == 24.1.4"] +matlab_2023b = ["matlabengine == 23.2.3"] +matlab_2023a = ["matlabengine == 9.14.3"] + +[tool.uv] +conflicts = [ + [ + { extra = "matlab_latest" }, + { extra = "matlab_2025b" }, + { extra = "matlab_2025a" }, + { extra = "matlab_2024b" }, + { extra = "matlab_2024a" }, + { extra = "matlab_2023b" }, + { extra = "matlab_2023a" }, + ], +] + [tool.ruff] line-length = 120 extend-exclude = ["*.ipynb"] @@ -24,7 +71,8 @@ ignore = ["SIM103", # needless bool "D105", # undocumented __init__ "D107", # undocumented magic method "D203", # blank line before class docstring - "D213"] # multi line summary should start at second line + "D213", # multi line summary should start at second line + "UP038"] # non pep604 isinstance - to be removed # ignore docstring lints in the tests and install script [tool.ruff.lint.per-file-ignores] diff --git a/ratapi/classlist.py b/ratapi/classlist.py index f0a61d3..29637a5 100644 --- a/ratapi/classlist.py +++ b/ratapi/classlist.py @@ -5,7 +5,7 @@ import importlib import warnings from collections.abc import Sequence -from typing import Any, Generic, TypeVar, Union +from typing import Any, Generic, TypeVar import numpy as np import prettytable @@ -38,7 +38,7 @@ class ClassList(collections.UserList, Generic[T]): """ - def __init__(self, init_list: Union[Sequence[T], T] = None, name_field: str = "name") -> None: + def __init__(self, init_list: Sequence[T] | T = None, name_field: str = "name") -> None: self.name_field = name_field # Set input as list if necessary @@ -114,7 +114,7 @@ def __str__(self): output = str(self.data) return output - def __getitem__(self, index: Union[int, slice, str, T]) -> T: + def __getitem__(self, index: int | slice | str | T) -> T: """Get an item by its index, name, a slice, or the object itself.""" if isinstance(index, (int, slice)): return self.data[index] @@ -262,12 +262,12 @@ def insert(self, index: int, obj: T = None, **kwargs) -> None: self._validate_name_field(kwargs) self.data.insert(index, self._class_handle(**kwargs)) - def remove(self, item: Union[T, str]) -> None: + def remove(self, item: T | str) -> None: """Remove an object from the ClassList using either the object itself or its ``name_field`` value.""" item = self._get_item_from_name_field(item) self.data.remove(item) - def count(self, item: Union[T, str]) -> int: + def count(self, item: T | str) -> int: """Return the number of times an object appears in the ClassList. This method can use either the object itself or its ``name_field`` value. @@ -276,7 +276,7 @@ def count(self, item: Union[T, str]) -> int: item = self._get_item_from_name_field(item) return self.data.count(item) - def index(self, item: Union[T, str], offset: bool = False, *args) -> int: + def index(self, item: T | str, offset: bool = False, *args) -> int: """Return the index of a particular object in the ClassList. This method can use either the object itself or its ``name_field`` value. @@ -309,7 +309,7 @@ def union(self, other: Sequence[T]) -> None: ] ) - def set_fields(self, index: Union[int, slice, str, T], **kwargs) -> None: + def set_fields(self, index: int | slice | str | T, **kwargs) -> None: """Assign the values of an existing object's attributes using keyword arguments.""" self._validate_name_field(kwargs) pydantic_object = False @@ -519,7 +519,7 @@ def _check_classes(self, input_list: Sequence[T]) -> None: f"In the input list:\n{newline.join(error for error in error_list)}\n" ) - def _get_item_from_name_field(self, value: Union[T, str]) -> Union[T, str]: + def _get_item_from_name_field(self, value: T | str) -> T | str: """Return the object with the given value of the ``name_field`` attribute in the ClassList. Parameters @@ -577,11 +577,12 @@ def _determine_class_handle(input_list: Sequence[T]): @classmethod def __get_pydantic_core_schema__(cls, source: Any, handler): # import here so that the ClassList can be instantiated and used without Pydantic installed + from typing import get_args, get_origin + from pydantic import ValidatorFunctionWrapHandler from pydantic.types import ( core_schema, # import core_schema through here rather than making pydantic_core a dependency ) - from typing_extensions import get_args, get_origin # if annotated with a class, get the item type of that class origin = get_origin(source) diff --git a/ratapi/controls.py b/ratapi/controls.py index ea62f5d..06c457b 100644 --- a/ratapi/controls.py +++ b/ratapi/controls.py @@ -5,7 +5,6 @@ import tempfile import warnings from pathlib import Path -from typing import Union import prettytable from pydantic import ( @@ -233,7 +232,7 @@ def delete_IPC(self): os.remove(self._IPCFilePath) return None - def save(self, filepath: Union[str, Path] = "./controls.json"): + def save(self, filepath: str | Path = "./controls.json"): """Save a controls object to a JSON file. Parameters @@ -245,7 +244,7 @@ def save(self, filepath: Union[str, Path] = "./controls.json"): filepath.write_text(self.model_dump_json()) @classmethod - def load(cls, path: Union[str, Path]) -> "Controls": + def load(cls, path: str | Path) -> "Controls": """Load a controls object from file. Parameters diff --git a/ratapi/events.py b/ratapi/events.py index 71993dd..2383159 100644 --- a/ratapi/events.py +++ b/ratapi/events.py @@ -1,12 +1,12 @@ """Hooks for connecting to run callback events.""" import os -from typing import Callable, Union +from collections.abc import Callable from ratapi.rat_core import EventBridge, EventTypes, PlotEventData, ProgressEventData -def notify(event_type: EventTypes, data: Union[str, PlotEventData, ProgressEventData]) -> None: +def notify(event_type: EventTypes, data: str | PlotEventData | ProgressEventData) -> None: """Call registered callbacks with data when event type has been triggered. Parameters @@ -22,7 +22,7 @@ def notify(event_type: EventTypes, data: Union[str, PlotEventData, ProgressEvent callback(data) -def get_event_callback(event_type: EventTypes) -> list[Callable[[Union[str, PlotEventData, ProgressEventData]], None]]: +def get_event_callback(event_type: EventTypes) -> list[Callable[[str | PlotEventData | ProgressEventData], None]]: """Return all callbacks registered for the given event type. Parameters @@ -39,7 +39,7 @@ def get_event_callback(event_type: EventTypes) -> list[Callable[[Union[str, Plot return list(__event_callbacks[event_type]) -def register(event_type: EventTypes, callback: Callable[[Union[str, PlotEventData, ProgressEventData]], None]) -> None: +def register(event_type: EventTypes, callback: Callable[[str | PlotEventData | ProgressEventData], None]) -> None: """Register a new callback for the event type. Parameters diff --git a/ratapi/inputs.py b/ratapi/inputs.py index 7d0a872..a537b21 100644 --- a/ratapi/inputs.py +++ b/ratapi/inputs.py @@ -3,7 +3,7 @@ import importlib import os import pathlib -from typing import Callable, Union +from collections.abc import Callable import numpy as np @@ -23,7 +23,7 @@ } -def get_python_handle(file_name: str, function_name: str, path: Union[str, pathlib.Path] = "") -> Callable: +def get_python_handle(file_name: str, function_name: str, path: str | pathlib.Path = "") -> Callable: """Get the function handle from a function defined in a python module located anywhere within the filesystem. Parameters diff --git a/ratapi/outputs.py b/ratapi/outputs.py index b60547d..2add165 100644 --- a/ratapi/outputs.py +++ b/ratapi/outputs.py @@ -3,7 +3,7 @@ import json from dataclasses import dataclass from pathlib import Path -from typing import Any, Optional, Union +from typing import Any, Union import numpy as np @@ -244,7 +244,7 @@ def __str__(self): output += get_field_string(key, value, 100) return output - def save(self, filepath: Union[str, Path] = "./results.json"): + def save(self, filepath: str | Path = "./results.json"): """Save the Results object to a JSON file. Parameters @@ -258,7 +258,7 @@ def save(self, filepath: Union[str, Path] = "./results.json"): filepath.write_text(json.dumps(json_dict)) @classmethod - def load(cls, path: Union[str, Path]) -> Union["Results", "BayesResults"]: + def load(cls, path: str | Path) -> Union["Results", "BayesResults"]: """Load a Results object from file. Parameters @@ -538,7 +538,7 @@ class BayesResults(Results): nestedSamplerOutput: NestedSamplerOutput chain: np.ndarray - def save(self, filepath: Union[str, Path] = "./results.json"): + def save(self, filepath: str | Path = "./results.json"): """Save the BayesResults object to a JSON file. Parameters @@ -574,7 +574,7 @@ def save(self, filepath: Union[str, Path] = "./results.json"): filepath.write_text(json.dumps(json_dict)) -def write_core_results_fields(results: Union[Results, BayesResults], json_dict: Optional[dict] = None) -> dict: +def write_core_results_fields(results: Results | BayesResults, json_dict: dict | None = None) -> dict: """Modify the values of the fields that appear in both Results and BayesResults when saving to a json file. Parameters @@ -684,8 +684,8 @@ def read_bayes_results_fields(results_dict: dict) -> dict: def make_results( procedure: Procedures, output_results: ratapi.rat_core.OutputResult, - bayes_results: Optional[ratapi.rat_core.OutputBayesResult] = None, -) -> Union[Results, BayesResults]: + bayes_results: ratapi.rat_core.OutputBayesResult | None = None, +) -> Results | BayesResults: """Initialise a python Results or BayesResults object using the outputs from a RAT calculation. Parameters diff --git a/ratapi/project.py b/ratapi/project.py index 8855553..fcbbaf8 100644 --- a/ratapi/project.py +++ b/ratapi/project.py @@ -5,10 +5,11 @@ import functools import json import warnings +from collections.abc import Callable from enum import Enum from pathlib import Path from textwrap import indent -from typing import Annotated, Any, Callable, Union +from typing import Annotated, Any, get_args, get_origin import numpy as np from pydantic import ( @@ -21,7 +22,6 @@ field_validator, model_validator, ) -from typing_extensions import get_args, get_origin import ratapi.models from ratapi.classlist import ClassList @@ -248,10 +248,10 @@ class Project(BaseModel, validate_assignment=True, extra="forbid", use_attribute data: ClassList[ratapi.models.Data] = ClassList() """Experimental data for a model.""" - layers: Union[ - Annotated[ClassList[ratapi.models.Layer], Tag("no_abs")], - Annotated[ClassList[ratapi.models.AbsorptionLayer], Tag("abs")], - ] = Field( + layers: ( + Annotated[ClassList[ratapi.models.Layer], Tag("no_abs")] + | Annotated[ClassList[ratapi.models.AbsorptionLayer], Tag("abs")] + ) = Field( default=ClassList(), discriminator=Discriminator( discriminate_layers, @@ -265,10 +265,10 @@ class Project(BaseModel, validate_assignment=True, extra="forbid", use_attribute domain_contrasts: ClassList[ratapi.models.DomainContrast] = ClassList() """The groups of layers required by each domain in a domains model.""" - contrasts: Union[ - Annotated[ClassList[ratapi.models.Contrast], Tag("no_ratio")], - Annotated[ClassList[ratapi.models.ContrastWithRatio], Tag("ratio")], - ] = Field( + contrasts: ( + Annotated[ClassList[ratapi.models.Contrast], Tag("no_ratio")] + | Annotated[ClassList[ratapi.models.ContrastWithRatio], Tag("ratio")] + ) = Field( default=ClassList(), discriminator=Discriminator( discriminate_contrasts, @@ -577,7 +577,7 @@ def update_renamed_models(self) -> "Project": old_names = self._all_names[class_list] new_names = getattr(self, class_list).get_names() if len(old_names) == len(new_names): - name_diff = [(old, new) for (old, new) in zip(old_names, new_names) if old != new] + name_diff = [(old, new) for (old, new) in zip(old_names, new_names, strict=False) if old != new] for old_name, new_name in name_diff: for field in fields_to_update: project_field = getattr(self, field.attribute) @@ -927,7 +927,7 @@ def classlist_script(name, classlist): + "\n)" ) - def save(self, filepath: Union[str, Path] = "./project.json"): + def save(self, filepath: str | Path = "./project.json"): """Save a project to a JSON file. Parameters @@ -973,7 +973,7 @@ def make_custom_file_dict(item): filepath.write_text(json.dumps(json_dict)) @classmethod - def load(cls, path: Union[str, Path]) -> "Project": + def load(cls, path: str | Path) -> "Project": """Load a project from file. Parameters diff --git a/ratapi/utils/convert.py b/ratapi/utils/convert.py index b689317..4e2649a 100644 --- a/ratapi/utils/convert.py +++ b/ratapi/utils/convert.py @@ -4,7 +4,6 @@ from collections.abc import Iterable from os import PathLike from pathlib import Path -from typing import Union from numpy import array, empty from scipy.io.matlab import MatlabOpaque, loadmat @@ -15,7 +14,7 @@ from ratapi.utils.enums import Geometries, Languages, LayerModels -def r1_to_project(filename: Union[str, PathLike]) -> Project: +def r1_to_project(filename: str | PathLike) -> Project: """Read a RasCAL1 project struct as a Python `Project`. Parameters @@ -43,7 +42,7 @@ def r1_to_project(filename: Union[str, PathLike]) -> Project: layer_model = LayerModels.CustomXY layer_model = LayerModels(layer_model) - def zip_if_several(*params) -> Union[tuple, list[tuple]]: + def zip_if_several(*params) -> tuple | list[tuple]: """Zips parameters if necessary, but can handle single-item parameters. Examples @@ -64,7 +63,7 @@ def zip_if_several(*params) -> Union[tuple, list[tuple]]: """ if all(isinstance(param, Iterable) and not isinstance(param, str) for param in params): - return zip(*params) + return zip(*params, strict=False) return [params] def read_param(names, constrs, values, fits): @@ -319,8 +318,8 @@ def fix_invalid_constraints(name: str, constrs: tuple[float, float], value: floa def project_to_r1( - project: Project, filename: Union[str, PathLike] = "RAT_project", return_struct: bool = False -) -> Union[dict, None]: + project: Project, filename: str | PathLike = "RAT_project", return_struct: bool = False +) -> dict | None: """Convert a RAT Project to a RasCAL1 project struct. Parameters diff --git a/ratapi/utils/custom_errors.py b/ratapi/utils/custom_errors.py index 425cf9e..83bf084 100644 --- a/ratapi/utils/custom_errors.py +++ b/ratapi/utils/custom_errors.py @@ -1,13 +1,11 @@ """Defines routines for custom error handling in RAT.""" -from typing import Optional - import pydantic_core def custom_pydantic_validation_error( error_list: list[pydantic_core.ErrorDetails], - custom_error_msgs: Optional[dict[str, str]] = None, + custom_error_msgs: dict[str, str] | None = None, ) -> list[pydantic_core.ErrorDetails]: """Give Pydantic errors a better custom message with extraneous information removed. diff --git a/ratapi/utils/enums.py b/ratapi/utils/enums.py index 24c50cb..313f04c 100644 --- a/ratapi/utils/enums.py +++ b/ratapi/utils/enums.py @@ -1,7 +1,5 @@ """The Enum values used in the parameters of various ratapi classes and functions.""" -from typing import Union - try: from enum import StrEnum except ImportError: @@ -92,7 +90,7 @@ class Strategies(RATEnum): or a pure recombination of parent parameter values.""" @classmethod - def _missing_(cls, value: Union[int, str]): + def _missing_(cls, value: int | str): # legacy compatibility with strategies being 1-indexed ints under the hood if isinstance(value, int): if value < 1 or value > 6: diff --git a/ratapi/utils/orso.py b/ratapi/utils/orso.py index 403597b..2d0345f 100644 --- a/ratapi/utils/orso.py +++ b/ratapi/utils/orso.py @@ -4,7 +4,6 @@ from itertools import count from pathlib import Path from textwrap import shorten -from typing import Union import orsopy import prettytable @@ -26,7 +25,7 @@ class ORSOProject: """ - def __init__(self, filepath: Union[str, Path], absorption: bool = False): + def __init__(self, filepath: str | Path, absorption: bool = False): ort_data = load_orso(filepath) datasets = [Data(name=dataset.info.data_source.sample.name, data=dataset.data) for dataset in ort_data] # orso datasets in the same file can have repeated names! @@ -75,7 +74,7 @@ class ORSOSample: bulk_in: Parameter bulk_out: Parameter parameters: ClassList[Parameter] - layers: Union[ClassList[Layer], ClassList[AbsorptionLayer]] + layers: ClassList[Layer] | ClassList[AbsorptionLayer] model: list[str] def __str__(self): @@ -94,8 +93,8 @@ def __str__(self): def orso_model_to_rat( - model: Union[orsopy.fileio.model_language.SampleModel, str], absorption: bool = False -) -> Union[ORSOSample, None]: + model: orsopy.fileio.model_language.SampleModel | str, absorption: bool = False +) -> ORSOSample | None: """Get information from an ORSO SampleModel object. Parameters diff --git a/ratapi/utils/plotting.py b/ratapi/utils/plotting.py index c2823b8..50e0d44 100644 --- a/ratapi/utils/plotting.py +++ b/ratapi/utils/plotting.py @@ -2,12 +2,14 @@ import copy import types +from collections.abc import Callable from functools import partial, wraps from math import ceil, floor, sqrt from statistics import stdev -from typing import Callable, Literal, Optional, Union +from typing import Literal import matplotlib +import matplotlib.figure import matplotlib.pyplot as plt import matplotlib.transforms as mtransforms import numpy as np @@ -47,7 +49,9 @@ def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool if shift_value < 1 or shift_value > 100: raise ValueError("Parameter `shift_value` must be between 1 and 100") - for i, (r, data, sld) in enumerate(zip(event_data.reflectivity, event_data.shiftedData, event_data.sldProfiles)): + for i, (r, data, sld) in enumerate( + zip(event_data.reflectivity, event_data.shiftedData, event_data.sldProfiles, strict=False) + ): # Calculate the divisor div = 1 if i == 0 and not q4 else 10 ** ((i / 100) * shift_value) q4_data = 1 if not q4 or not event_data.dataPresent[i] else data[:, 0] ** 4 @@ -94,9 +98,9 @@ def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool def plot_ref_sld_helper( data: PlotEventData, - fig: matplotlib.pyplot.figure, + fig: matplotlib.figure.Figure, delay: bool = True, - confidence_intervals: Union[dict, None] = None, + confidence_intervals: dict | None = None, linear_x: bool = False, q4: bool = False, show_error_bar: bool = True, @@ -112,7 +116,7 @@ def plot_ref_sld_helper( data : PlotEventData The plot event data that contains all the information to generate the ref and sld plots - fig : matplotlib.pyplot.figure + fig : matplotlib.figure.Figure The figure object that has two subplots delay : bool, default: True Controls whether to delay 0.005s after plot is created @@ -230,9 +234,9 @@ def plot_ref_sld_helper( def plot_ref_sld( project: ratapi.Project, - results: Union[ratapi.outputs.Results, ratapi.outputs.BayesResults], + results: ratapi.outputs.Results | ratapi.outputs.BayesResults, block: bool = False, - fig: Optional[matplotlib.pyplot.figure] = None, + fig: matplotlib.figure.Figure | None = None, return_fig: bool = False, bayes: Literal[65, 95, None] = None, linear_x: bool = False, @@ -241,7 +245,7 @@ def plot_ref_sld( show_grid: bool = False, show_legend: bool = True, shift_value: float = 100, -) -> Union[plt.Figure, None]: +) -> plt.Figure | None: """Plot the reflectivity and SLD profiles. Parameters @@ -252,7 +256,7 @@ def plot_ref_sld( The result from the calculation block : bool, default: False Indicates the plot should block until it is closed - fig : matplotlib.pyplot.figure, optional + fig : matplotlib.figure.Figure, optional The figure object that has two subplots return_fig : bool, default False If True, return the figure instead of displaying it. @@ -319,10 +323,12 @@ def plot_ref_sld( ], } # For a shaded plot, use the mean values from predictionIntervals - for reflectivity, mean_reflectivity in zip(data.reflectivity, results.predictionIntervals.reflectivity): + for reflectivity, mean_reflectivity in zip( + data.reflectivity, results.predictionIntervals.reflectivity, strict=False + ): reflectivity[:, 1] = mean_reflectivity[2] - for sldProfile, mean_sld_profile in zip(data.sldProfiles, results.predictionIntervals.sld): - for sld, mean_sld in zip(sldProfile, mean_sld_profile): + for sldProfile, mean_sld_profile in zip(data.sldProfiles, results.predictionIntervals.sld, strict=False): + for sld, mean_sld in zip(sldProfile, mean_sld_profile, strict=False): sld[:, 1] = mean_sld[2] else: raise ValueError( @@ -366,7 +372,7 @@ class BlittingSupport: data : PlotEventData The plot event data that contains all the information to generate the ref and sld plots - fig : matplotlib.pyplot.figure, optional + fig : matplotlib.figure.Figure, optional The figure class that has two subplots linear_x : bool, default: False Controls whether the x-axis on reflectivity plot uses the linear scale @@ -471,7 +477,9 @@ def adjust_error_bar(self, error_bar_container, x, y, y_error): y_error_top = y_base + y_error y_error_bottom = y_base - y_error - new_segments_y = [np.array([[x, yt], [x, yb]]) for x, yt, yb in zip(x_base, y_error_top, y_error_bottom)] + new_segments_y = [ + np.array([[x, yt], [x, yb]]) for x, yt, yb in zip(x_base, y_error_top, y_error_bottom, strict=False) + ] bars_y.set_segments(new_segments_y) def update_plot(self, data): @@ -628,7 +636,7 @@ def inner(results, *args, **kwargs): return decorator -def name_to_index(param: Union[str, int], names: list[str]): +def name_to_index(param: str | int, names: list[str]): """Convert parameter names to indices.""" if isinstance(param, str): if param not in names: @@ -645,14 +653,14 @@ def name_to_index(param: Union[str, int], names: list[str]): @assert_bayesian("Corner") def plot_corner( results: ratapi.outputs.BayesResults, - params: Union[list[Union[int, str]], None] = None, + params: list[int | str] | None = None, smooth: bool = True, block: bool = False, - fig: Optional[matplotlib.pyplot.figure] = None, + fig: matplotlib.figure.Figure | None = None, return_fig: bool = False, - hist_kwargs: Union[dict, None] = None, - hist2d_kwargs: Union[dict, None] = None, - progress_callback: Union[Callable[[int, int], None], None] = None, + hist_kwargs: dict | None = None, + hist2d_kwargs: dict | None = None, + progress_callback: Callable[[int, int], None] | None = None, ): """Create a corner plot from a Bayesian analysis. @@ -667,7 +675,7 @@ def plot_corner( Whether to apply Gaussian smoothing to the corner plot. block : bool, default False Whether Python should block until the plot is closed. - fig : matplotlib.pyplot.figure, optional + fig : matplotlib.figure.Figure, optional The figure object to use for plot. return_fig: bool, default False If True, return the figure as an object instead of showing it. @@ -750,11 +758,11 @@ def plot_corner( @assert_bayesian("Histogram") def plot_one_hist( results: ratapi.outputs.BayesResults, - param: Union[int, str], + param: int | str, smooth: bool = True, - sigma: Union[float, None] = None, + sigma: float | None = None, estimated_density: Literal["normal", "lognor", "kernel", None] = None, - axes: Union[Axes, None] = None, + axes: Axes | None = None, block: bool = False, return_fig: bool = False, **hist_settings, @@ -901,11 +909,11 @@ def _y_update_offset_text_position(axis, _bboxes, bboxes2): @assert_bayesian("Contour") def plot_contour( results: ratapi.outputs.BayesResults, - x_param: Union[int, str], - y_param: Union[int, str], + x_param: int | str, + y_param: int | str, smooth: bool = True, - sigma: Union[tuple[float], None] = None, - axes: Union[Axes, None] = None, + sigma: tuple[float] | None = None, + axes: Axes | None = None, block: bool = False, return_fig: bool = False, **hist2d_settings, @@ -974,7 +982,7 @@ def plot_contour( def panel_plot_helper( - plot_func: Callable, indices: list[int], fig: Optional[matplotlib.pyplot.figure] = None + plot_func: Callable, indices: list[int], fig: matplotlib.figure.Figure | None = None ) -> matplotlib.figure.Figure: """Generate a panel-based plot from a single plot function. @@ -984,7 +992,7 @@ def panel_plot_helper( A function which plots one parameter on an Axes object, given its index. indices : list[int] The list of indices to pass into ``plot_func``. - fig : matplotlib.pyplot.figure, optional + fig : matplotlib.figure.Figure, optional The figure object to use for plot. Returns @@ -1020,14 +1028,13 @@ def panel_plot_helper( @assert_bayesian("Histogram") def plot_hists( results: ratapi.outputs.BayesResults, - params: Union[list[Union[int, str]], None] = None, + params: list[int | str] | None = None, smooth: bool = True, - sigma: Union[float, None] = None, - estimated_density: Union[ - dict[Literal["normal", "lognor", "kernel", None]], Literal["normal", "lognor", "kernel", None] - ] = None, + sigma: float | None = None, + estimated_density: dict[Literal["normal", "lognor", "kernel", None]] + | Literal["normal", "lognor", "kernel", None] = None, block: bool = False, - fig: Optional[matplotlib.pyplot.figure] = None, + fig: matplotlib.figure.Figure | None = None, return_fig: bool = False, **hist_settings, ): @@ -1061,7 +1068,7 @@ def plot_hists( e.g. to apply 'normal' to all unset parameters, set `estimated_density = {'default': 'normal'}`. block : bool, default False Whether Python should block until the plot is closed. - fig : matplotlib.pyplot.figure, optional + fig : matplotlib.figure.Figure, optional The figure object to use for plot. return_fig: bool, default False If True, return the figure as an object instead of showing it. @@ -1085,7 +1092,7 @@ def plot_hists( if estimated_density is not None: - def validate_dens_type(dens_type: Union[str, None], param: str): + def validate_dens_type(dens_type: str | None, param: str): """Check estimated density is a supported type.""" if dens_type not in [None, "normal", "lognor", "kernel"]: raise ValueError( @@ -1132,10 +1139,10 @@ def validate_dens_type(dens_type: Union[str, None], param: str): @assert_bayesian("Chain") def plot_chain( results: ratapi.outputs.BayesResults, - params: Union[list[Union[int, str]], None] = None, + params: list[int | str] | None = None, maxpoints: int = 15000, block: bool = False, - fig: Optional[matplotlib.pyplot.figure] = None, + fig: matplotlib.figure.Figure | None = None, return_fig: bool = False, ): """Plot the MCMC chain for each parameter of a Bayesian analysis. @@ -1151,7 +1158,7 @@ def plot_chain( The maximum number of points to plot for each parameter. block : bool, default False Whether Python should block until the plot is closed. - fig : matplotlib.pyplot.figure, optional + fig : matplotlib.figure.Figure, optional The figure object to use for plot. return_fig: bool, default False If True, return the figure as an object instead of showing it. diff --git a/ratapi/wrappers.py b/ratapi/wrappers.py index 39021e2..74eda41 100644 --- a/ratapi/wrappers.py +++ b/ratapi/wrappers.py @@ -2,8 +2,8 @@ import os import pathlib +from collections.abc import Callable from contextlib import suppress -from typing import Callable import numpy as np from numpy.typing import ArrayLike diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index cee9a79..0000000 --- a/requirements.txt +++ /dev/null @@ -1,14 +0,0 @@ -numpy >= 1.20 -scipy >= 1.13.1 -prettytable >= 3.9.0 -pybind11 >= 2.4 -pydantic >= 2.7.2 -pytest >= 7.4.0 -pytest-cov >= 4.1.0 -matplotlib >= 3.8.3 -StrEnum >= 0.4.15; python_version < '3.11' -ruff >= 0.4.10 -scipy >= 1.13.1 -tqdm >= 4.66.5 -orsopy >= 1.2.1 -pint >= 0.24.4 diff --git a/setup.py b/setup.py index 4c99636..3c21a01 100644 --- a/setup.py +++ b/setup.py @@ -165,25 +165,5 @@ def build_libraries(self, libraries): cmdclass={"build_clib": BuildClib, "build_ext": BuildExt}, libraries=[libevent], ext_modules=ext_modules, - python_requires=">=3.10", - install_requires=[ - "numpy >= 1.20", - "prettytable >= 3.9.0", - "pydantic >= 2.7.2", - "matplotlib >= 3.8.3", - "scipy >= 1.13.1", - "tqdm >= 4.66.5", - ], - extras_require={ - ':python_version < "3.11"': ["StrEnum >= 0.4.15"], - "Dev": ["pytest>=7.4.0", "pytest-cov>=4.1.0", "ruff>=0.4.10"], - "Orso": ["orsopy>=1.2.1", "pint>=0.24.4"], - "Matlab_latest": ["matlabengine"], - "Matlab_2025a": ["matlabengine == 25.1.*"], - "Matlab_2024b": ["matlabengine == 24.2.2"], - "Matlab_2024a": ["matlabengine == 24.1.4"], - "Matlab_2023b": ["matlabengine == 23.2.3"], - "Matlab_2023a": ["matlabengine == 9.14.3"], - }, zip_safe=False, ) diff --git a/tests/test_classlist.py b/tests/test_classlist.py index 04130ff..98d9440 100644 --- a/tests/test_classlist.py +++ b/tests/test_classlist.py @@ -5,7 +5,7 @@ import warnings from collections import deque from collections.abc import Iterable, Sequence -from typing import Any, Union +from typing import Any import prettytable import pytest @@ -611,7 +611,7 @@ def test_insert_kwargs_same_name(two_name_class_list: ClassList, new_values: dic (InputAttributes(name="Bob")), ], ) -def test_remove(two_name_class_list: ClassList, remove_value: Union[object, str]) -> None: +def test_remove(two_name_class_list: ClassList, remove_value: object | str) -> None: """We should be able to remove an object either by the value of the name_field or by specifying the object itself. """ @@ -626,7 +626,7 @@ def test_remove(two_name_class_list: ClassList, remove_value: Union[object, str] (InputAttributes(name="Eve")), ], ) -def test_remove_not_present(two_name_class_list: ClassList, remove_value: Union[object, str]) -> None: +def test_remove_not_present(two_name_class_list: ClassList, remove_value: object | str) -> None: """If we remove an object not included in the ClassList we should raise a ValueError.""" with pytest.raises(ValueError, match=re.escape("list.remove(x): x not in list")): two_name_class_list.remove(remove_value) @@ -641,7 +641,7 @@ def test_remove_not_present(two_name_class_list: ClassList, remove_value: Union[ (InputAttributes(name="Eve"), 0), ], ) -def test_count(two_name_class_list: ClassList, count_value: Union[object, str], expected_count: int) -> None: +def test_count(two_name_class_list: ClassList, count_value: object | str, expected_count: int) -> None: """We should be able to determine the number of times an object is in the ClassList using either the object itself or its name_field value. """ @@ -655,7 +655,7 @@ def test_count(two_name_class_list: ClassList, count_value: Union[object, str], (InputAttributes(name="Bob"), 1), ], ) -def test_index(two_name_class_list: ClassList, index_value: Union[object, str], expected_index: int) -> None: +def test_index(two_name_class_list: ClassList, index_value: object | str, expected_index: int) -> None: """We should be able to find the index of an object in the ClassList either by its name_field value or by specifying the object itself. """ @@ -671,7 +671,7 @@ def test_index(two_name_class_list: ClassList, index_value: Union[object, str], ) def test_index_offset( two_name_class_list: ClassList, - index_value: Union[object, str], + index_value: object | str, offset: int, expected_index: int, ) -> None: @@ -688,7 +688,7 @@ def test_index_offset( (InputAttributes(name="Eve")), ], ) -def test_index_not_present(two_name_class_list: ClassList, index_value: Union[object, str]) -> None: +def test_index_not_present(two_name_class_list: ClassList, index_value: object | str) -> None: """If we try to find the index of an object not included in the ClassList we should raise a ValueError.""" # with pytest.raises(ValueError, match=f"'{index_value}' is not in list") as e: with pytest.raises(ValueError): @@ -741,7 +741,7 @@ def test_extend_empty_classlist(extended_list: Sequence, one_name_class_list: Cl ], ) def test_set_fields( - two_name_class_list: ClassList, index: Union[int, str], new_values: dict[str, Any], expected_classlist: ClassList + two_name_class_list: ClassList, index: int | str, new_values: dict[str, Any], expected_classlist: ClassList ) -> None: """We should be able to set field values in an element of a ClassList using keyword arguments.""" class_list = two_name_class_list @@ -963,7 +963,7 @@ def test__check_classes_different_classes(input_list: Sequence) -> None: def test__get_item_from_name_field( two_name_class_list: ClassList, value: str, - expected_output: Union[object, str], + expected_output: object | str, ) -> None: """When we input the name_field value of an object defined in the ClassList, we should return the object. If the value is not the name_field of an object defined in the ClassList, we should return the value. @@ -1044,7 +1044,7 @@ class NestedModel(pydantic.BaseModel): submodels_list = [{"i": 3, "s": "hello", "f": 3.0}, {"i": 4, "s": "hi", "f": 3.14}] model = NestedModel(submodels=submodels_list) - for submodel, exp_dict in zip(model.submodels, submodels_list): + for submodel, exp_dict in zip(model.submodels, submodels_list, strict=False): for key, value in exp_dict.items(): assert getattr(submodel, key) == value diff --git a/tests/test_controls.py b/tests/test_controls.py index 61d6833..ec8cb12 100644 --- a/tests/test_controls.py +++ b/tests/test_controls.py @@ -4,7 +4,7 @@ import os import tempfile from pathlib import Path -from typing import Any, Union +from typing import Any import pydantic import pytest @@ -366,7 +366,7 @@ def test_set_non_simplex_properties(self, wrong_property: str, value: Any) -> No ("maxIterations", -50), ], ) - def test_simplex_property_errors(self, control_property: str, value: Union[float, int]) -> None: + def test_simplex_property_errors(self, control_property: str, value: float | int) -> None: """Tests the property errors of Simplex class.""" with pytest.raises(pydantic.ValidationError, match="Input should be greater than 0"): setattr(self.simplex, control_property, value) @@ -538,7 +538,7 @@ def test_de_crossoverProbability_error(self, value: int, msg: str) -> None: def test_de_targetValue_numGenerations_populationSize_error( self, control_property: str, - value: Union[int, float], + value: int | float, ) -> None: """Tests the targetValue, numGenerations, populationSize setter error in DE class.""" with pytest.raises(pydantic.ValidationError, match="Input should be greater than or equal to 1"): @@ -693,7 +693,7 @@ def test_set_non_ns_properties(self, wrong_property: str, value: Any) -> None: ("nLive", -500, 1), ], ) - def test_ns_setter_error(self, control_property: str, value: Union[int, float], bound: int) -> None: + def test_ns_setter_error(self, control_property: str, value: int | float, bound: int) -> None: """Tests the nMCMC, nsTolerance, nLive setter error in NS class.""" with pytest.raises(pydantic.ValidationError, match=f"Input should be greater than or equal to {bound}"): setattr(self.ns, control_property, value) diff --git a/tests/test_enums.py b/tests/test_enums.py index 9984185..3d07c30 100644 --- a/tests/test_enums.py +++ b/tests/test_enums.py @@ -1,6 +1,6 @@ """Tests the enums module.""" -from typing import Callable +from collections.abc import Callable import pytest diff --git a/tests/test_inputs.py b/tests/test_inputs.py index e7a1560..bf9a2e1 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -736,11 +736,12 @@ def check_problem_equal(actual_problem, expected_problem) -> None: # Data field is a numpy array assert [ - actual_data == expected_data for (actual_data, expected_data) in zip(actual_problem.data, expected_problem.data) + actual_data == expected_data + for (actual_data, expected_data) in zip(actual_problem.data, expected_problem.data, strict=False) ] # Need to account for "NaN" entries in layersDetails and contrastCustomFiles field - for actual_layer, expected_layer in zip(actual_problem.layersDetails, expected_problem.layersDetails): + for actual_layer, expected_layer in zip(actual_problem.layersDetails, expected_problem.layersDetails, strict=False): assert (actual_layer == expected_layer) or ["NaN" if np.isnan(el) else el for el in actual_layer] == [ "NaN" if np.isnan(el) else el for el in expected_layer ] diff --git a/tests/test_models.py b/tests/test_models.py index 0b00dea..906a5d6 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -2,7 +2,7 @@ import pathlib import re -from typing import Callable +from collections.abc import Callable import numpy as np import pydantic diff --git a/tests/test_orso_utils.py b/tests/test_orso_utils.py index a8c0791..60f0449 100644 --- a/tests/test_orso_utils.py +++ b/tests/test_orso_utils.py @@ -95,7 +95,7 @@ def test_load_ort_data(test_data): actual_data = ORSOProject(Path(TEST_DIR_PATH, test_data)).data assert len(actual_data) == len(expected_data) - for actual_dataset, expected_dataset in zip(actual_data, expected_data): + for actual_dataset, expected_dataset in zip(actual_data, expected_data, strict=False): np.testing.assert_array_equal(actual_dataset.data, expected_dataset) @@ -118,5 +118,5 @@ def test_load_ort_project(test_data, expected_data): assert sample.parameters == exp_project.parameters[1:] assert sample.layers == exp_project.layers - for data, exp_data in zip(ort_data.data, exp_project.data[1:]): + for data, exp_data in zip(ort_data.data, exp_project.data[1:], strict=False): np.testing.assert_array_equal(data.data, exp_data.data) diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 222d514..710729a 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -194,10 +194,14 @@ def test_plot_ref_sld(mock: MagicMock, input_project, reflectivity_calculation_r assert figure.axes[0].get_subplotspec().get_gridspec().get_geometry() == (1, 2) assert len(figure.axes) == 2 - for reflectivity, reflectivity_results in zip(data.reflectivity, reflectivity_calculation_results.reflectivity): + for reflectivity, reflectivity_results in zip( + data.reflectivity, reflectivity_calculation_results.reflectivity, strict=False + ): assert (reflectivity == reflectivity_results).all() - for sldProfile, result_sld_profile in zip(data.sldProfiles, reflectivity_calculation_results.sldProfiles): - for sld, sld_results in zip(sldProfile, result_sld_profile): + for sldProfile, result_sld_profile in zip( + data.sldProfiles, reflectivity_calculation_results.sldProfiles, strict=False + ): + for sld, sld_results in zip(sldProfile, result_sld_profile, strict=False): assert (sld == sld_results).all() assert data.modelType == input_project.model diff --git a/tests/test_project.py b/tests/test_project.py index 7683dec..31e1376 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -4,13 +4,13 @@ import re import tempfile import warnings +from collections.abc import Callable from pathlib import Path -from typing import Callable +from typing import get_args, get_origin import numpy as np import pydantic import pytest -from typing_extensions import get_args, get_origin import ratapi from ratapi.utils.enums import Calculations, LayerModels, TypeOptions @@ -667,7 +667,7 @@ def test_rename_models(test_project, model: str, fields: list[str]) -> None: getattr(test_project, model).set_fields(-1, name="New Name") model_name_lists = ratapi.project.model_names_used_in[model] - for model_name_list, field in zip(model_name_lists, fields): + for model_name_list, field in zip(model_name_lists, fields, strict=False): attribute = model_name_list.attribute assert getattr(getattr(test_project, attribute)[-1], field) == "New Name" diff --git a/tests/utils.py b/tests/utils.py index 91b5b9a..5387b30 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -48,7 +48,7 @@ def check_results_equal(actual_results, expected_results) -> None: # The first set of fields are either 1D or 2D python lists containing numpy arrays. # Hence, we need to compare them element-wise. for list_field in ratapi.outputs.results_fields["list_fields"]: - for a, b in zip(getattr(actual_results, list_field), getattr(expected_results, list_field)): + for a, b in zip(getattr(actual_results, list_field), getattr(expected_results, list_field), strict=False): assert (a == b).all() for list_field in ratapi.outputs.results_fields["double_list_fields"]: @@ -56,7 +56,7 @@ def check_results_equal(actual_results, expected_results) -> None: expected_list = getattr(expected_results, list_field) assert len(actual_list) == len(expected_list) for i in range(len(actual_list)): - for a, b in zip(actual_list[i], expected_list[i]): + for a, b in zip(actual_list[i], expected_list[i], strict=False): assert (a == b).all() # Compare the final fields @@ -90,7 +90,7 @@ def check_bayes_fields_equal(actual_results, expected_results) -> None: assert getattr(actual_subclass, field) == getattr(expected_subclass, field) for field in ratapi.outputs.bayes_results_fields["list_fields"][subclass]: - for a, b in zip(getattr(actual_subclass, field), getattr(expected_subclass, field)): + for a, b in zip(getattr(actual_subclass, field), getattr(expected_subclass, field), strict=False): assert (a == b).all() for field in ratapi.outputs.bayes_results_fields["double_list_fields"][subclass]: @@ -98,7 +98,7 @@ def check_bayes_fields_equal(actual_results, expected_results) -> None: expected_list = getattr(expected_subclass, field) assert len(actual_list) == len(expected_list) for i in range(len(actual_list)): - for a, b in zip(actual_list[i], expected_list[i]): + for a, b in zip(actual_list[i], expected_list[i], strict=False): assert (a == b).all() # Need to account for the arrays that are initialised as "NaN" in the compiled code