From 51840d59e46c1197f43bc4dc6a8096f9883605c0 Mon Sep 17 00:00:00 2001 From: Oliver Date: Wed, 3 Dec 2025 16:29:11 -0800 Subject: [PATCH 01/15] broken commit --- physicsnemo/core/__init__.py | 3 +- physicsnemo/core/benchmark.py | 394 ++++++++++++++++++ physicsnemo/core/function.py | 32 ++ physicsnemo/core/module.py | 3 +- physicsnemo/core/version_check.py | 15 + physicsnemo/nn/__init__.py | 70 +--- physicsnemo/nn/functional/__init__.py | 5 + .../nn/functional/finite_difference.py | 367 ++++++++++++++++ physicsnemo/nn/module/__init__.py | 66 +++ physicsnemo/nn/{ => module}/activations.py | 0 .../nn/{ => module}/attention_layers.py | 0 physicsnemo/nn/{ => module}/ball_query.py | 2 +- physicsnemo/nn/{ => module}/conv_layers.py | 0 physicsnemo/nn/{ => module}/dgm_layers.py | 0 physicsnemo/nn/{ => module}/drop.py | 0 physicsnemo/nn/{ => module}/fft.py | 0 physicsnemo/nn/module/finite_difference.py | 108 +++++ physicsnemo/nn/{ => module}/fourier_layers.py | 0 .../nn/{ => module}/fully_connected_layers.py | 0 physicsnemo/nn/{ => module}/fused_silu.py | 0 .../nn/{ => module}/gnn_layers/__init__.py | 0 .../nn/{ => module}/gnn_layers/bsms.py | 0 .../gnn_layers/distributed_graph.py | 0 .../nn/{ => module}/gnn_layers/embedder.py | 0 .../nn/{ => module}/gnn_layers/graph.py | 0 .../gnn_layers/mesh_edge_block.py | 0 .../gnn_layers/mesh_graph_decoder.py | 0 .../gnn_layers/mesh_graph_encoder.py | 0 .../{ => module}/gnn_layers/mesh_graph_mlp.py | 0 .../gnn_layers/mesh_node_block.py | 0 .../nn/{ => module}/gnn_layers/utils.py | 0 physicsnemo/nn/{ => module}/interpolation.py | 0 physicsnemo/nn/{ => module}/kan_layers.py | 0 physicsnemo/nn/{ => module}/layer_norm.py | 0 physicsnemo/nn/{ => module}/mlp_layers.py | 0 .../nn/{ => module}/neighbors/__init__.py | 0 .../{ => module}/neighbors/_knn/__init__.py | 0 .../{ => module}/neighbors/_knn/_cuml_impl.py | 0 .../neighbors/_knn/_scipy_impl.py | 0 .../neighbors/_knn/_torch_impl.py | 0 .../nn/{ => module}/neighbors/_knn/knn.py | 0 .../neighbors/_radius_search/__init__.py | 0 .../neighbors/_radius_search/_torch_impl.py | 0 .../neighbors/_radius_search/_warp_impl.py | 0 .../neighbors/_radius_search/kernels.py | 0 .../neighbors/_radius_search/radius_search.py | 0 .../nn/{ => module}/resample_layers.py | 0 physicsnemo/nn/{ => module}/sdf.py | 0 physicsnemo/nn/{ => module}/siren_layers.py | 0 .../nn/{ => module}/spectral_layers.py | 0 .../nn/{ => module}/transformer_decoder.py | 0 .../nn/{ => module}/transformer_layers.py | 0 physicsnemo/nn/module/utils/__init__.py | 35 ++ .../nn/{ => module}/utils/patch_embed.py | 0 .../{ => module}/utils/shift_window_mask.py | 0 physicsnemo/nn/{ => module}/utils/utils.py | 0 .../nn/{ => module}/utils/weight_init.py | 0 physicsnemo/nn/{ => module}/weight_fact.py | 0 physicsnemo/nn/{ => module}/weight_norm.py | 0 physicsnemo/nn/utils/__init__.py | 24 +- pyproject.toml | 3 + 61 files changed, 1041 insertions(+), 86 deletions(-) create mode 100644 physicsnemo/core/benchmark.py create mode 100644 physicsnemo/core/function.py create mode 100644 physicsnemo/nn/functional/__init__.py create mode 100644 physicsnemo/nn/functional/finite_difference.py create mode 100644 physicsnemo/nn/module/__init__.py rename physicsnemo/nn/{ => module}/activations.py (100%) rename physicsnemo/nn/{ => module}/attention_layers.py (100%) rename physicsnemo/nn/{ => module}/ball_query.py (98%) rename physicsnemo/nn/{ => module}/conv_layers.py (100%) rename physicsnemo/nn/{ => module}/dgm_layers.py (100%) rename physicsnemo/nn/{ => module}/drop.py (100%) rename physicsnemo/nn/{ => module}/fft.py (100%) create mode 100644 physicsnemo/nn/module/finite_difference.py rename physicsnemo/nn/{ => module}/fourier_layers.py (100%) rename physicsnemo/nn/{ => module}/fully_connected_layers.py (100%) rename physicsnemo/nn/{ => module}/fused_silu.py (100%) rename physicsnemo/nn/{ => module}/gnn_layers/__init__.py (100%) rename physicsnemo/nn/{ => module}/gnn_layers/bsms.py (100%) rename physicsnemo/nn/{ => module}/gnn_layers/distributed_graph.py (100%) rename physicsnemo/nn/{ => module}/gnn_layers/embedder.py (100%) rename physicsnemo/nn/{ => module}/gnn_layers/graph.py (100%) rename physicsnemo/nn/{ => module}/gnn_layers/mesh_edge_block.py (100%) rename physicsnemo/nn/{ => module}/gnn_layers/mesh_graph_decoder.py (100%) rename physicsnemo/nn/{ => module}/gnn_layers/mesh_graph_encoder.py (100%) rename physicsnemo/nn/{ => module}/gnn_layers/mesh_graph_mlp.py (100%) rename physicsnemo/nn/{ => module}/gnn_layers/mesh_node_block.py (100%) rename physicsnemo/nn/{ => module}/gnn_layers/utils.py (100%) rename physicsnemo/nn/{ => module}/interpolation.py (100%) rename physicsnemo/nn/{ => module}/kan_layers.py (100%) rename physicsnemo/nn/{ => module}/layer_norm.py (100%) rename physicsnemo/nn/{ => module}/mlp_layers.py (100%) rename physicsnemo/nn/{ => module}/neighbors/__init__.py (100%) rename physicsnemo/nn/{ => module}/neighbors/_knn/__init__.py (100%) rename physicsnemo/nn/{ => module}/neighbors/_knn/_cuml_impl.py (100%) rename physicsnemo/nn/{ => module}/neighbors/_knn/_scipy_impl.py (100%) rename physicsnemo/nn/{ => module}/neighbors/_knn/_torch_impl.py (100%) rename physicsnemo/nn/{ => module}/neighbors/_knn/knn.py (100%) rename physicsnemo/nn/{ => module}/neighbors/_radius_search/__init__.py (100%) rename physicsnemo/nn/{ => module}/neighbors/_radius_search/_torch_impl.py (100%) rename physicsnemo/nn/{ => module}/neighbors/_radius_search/_warp_impl.py (100%) rename physicsnemo/nn/{ => module}/neighbors/_radius_search/kernels.py (100%) rename physicsnemo/nn/{ => module}/neighbors/_radius_search/radius_search.py (100%) rename physicsnemo/nn/{ => module}/resample_layers.py (100%) rename physicsnemo/nn/{ => module}/sdf.py (100%) rename physicsnemo/nn/{ => module}/siren_layers.py (100%) rename physicsnemo/nn/{ => module}/spectral_layers.py (100%) rename physicsnemo/nn/{ => module}/transformer_decoder.py (100%) rename physicsnemo/nn/{ => module}/transformer_layers.py (100%) create mode 100644 physicsnemo/nn/module/utils/__init__.py rename physicsnemo/nn/{ => module}/utils/patch_embed.py (100%) rename physicsnemo/nn/{ => module}/utils/shift_window_mask.py (100%) rename physicsnemo/nn/{ => module}/utils/utils.py (100%) rename physicsnemo/nn/{ => module}/utils/weight_init.py (100%) rename physicsnemo/nn/{ => module}/weight_fact.py (100%) rename physicsnemo/nn/{ => module}/weight_norm.py (100%) diff --git a/physicsnemo/core/__init__.py b/physicsnemo/core/__init__.py index 307c186bbe..3ead9e26d1 100644 --- a/physicsnemo/core/__init__.py +++ b/physicsnemo/core/__init__.py @@ -14,8 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .function import Function from .meta import ModelMetaData from .module import Module from .registry import ModelRegistry -__all__ = ["ModelMetaData", "Module", "ModelRegistry"] +__all__ = ["ModelMetaData", "Module", "ModelRegistry", "Function"] diff --git a/physicsnemo/core/benchmark.py b/physicsnemo/core/benchmark.py new file mode 100644 index 0000000000..54d7227514 --- /dev/null +++ b/physicsnemo/core/benchmark.py @@ -0,0 +1,394 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common benchmarking hooks shared by autograd functions and modules.""" + +from __future__ import annotations + +import json +import time +from pathlib import Path +from typing import Any, Callable, Iterable, Tuple + +import torch + +_DEFAULT_RESULTS_DIR = Path("benchmarks/results") + + +class BenchmarkMixin: + """Defines the hooks needed by the benchmark and validation helpers.""" + + @classmethod + def make_inputs(cls) -> Iterable[Tuple[str, dict[str, Any], Tuple[Any, ...]]]: + """Yield (label, init_kwargs, run_args) tuples for benchmarking.""" + + raise NotImplementedError( + f"{cls.__name__}.make_inputs must be implemented by subclasses" + ) + + @classmethod + def reference_impl(cls, *args, **kwargs): + """Compute a reference output for correctness checks.""" + + raise NotImplementedError( + f"{cls.__name__}.reference_impl must be implemented by subclasses" + ) + + @classmethod + def check(cls, *args, **kwargs) -> None: + """Validate outputs against the reference implementation.""" + + raise NotImplementedError( + f"{cls.__name__}.check must be implemented by subclasses" + ) + + @classmethod + def benchmark( + cls, + *, + repeats: int = 10, + warmup: int = 1, + ) -> dict: + """Collect benchmark data for downstream consumers. + + Parameters + ---------- + repeats + Number of timed iterations recorded per benchmark case. + warmup + Runs discarded before timings are collected. Warmups help amortize + one-time setup costs. + + Returns + ------- + dict + Structured benchmark payload that can be consumed by CI, ASV, or + plotting utilities. + """ + + if repeats <= 0: + raise ValueError("repeats must be positive") + if warmup < 0: + raise ValueError("warmup must be non-negative") + + cases = [] + devices = set() + dtypes = set() + for index, case in enumerate(cls.make_inputs()): + label, init_kwargs, run_args = cls._normalize_case(case) + inputs = cls._tuple_inputs(run_args) + tensor = cls._first_tensor(inputs) + if tensor is None: + case_device = torch.device("cpu") + case_dtype = "unknown" + else: + case_device = tensor.device + case_dtype = str(tensor.dtype).replace("torch.", "") + + instance = cls._instantiate_case(init_kwargs) + cls._maybe_move_instance(instance, case_device) + + current_runner = cls._resolve_current_runner(instance) + reference_runner = cls._resolve_reference_runner(instance) + + current_timings = cls._time_callable( + current_runner, inputs, repeats, warmup, case_device + ) + reference_timings = cls._time_callable( + reference_runner, inputs, repeats, warmup, case_device + ) + case_record = { + "label": label, + "index": index, + "metadata": cls._benchmark_case_metadata(inputs), + "device": str(case_device), + "dtype": case_dtype, + "init_kwargs": dict(init_kwargs), + "current": { + "timings_ms": current_timings, + "statistics": cls._summarize_timings(current_timings), + }, + "reference": { + "timings_ms": reference_timings, + "statistics": cls._summarize_timings(reference_timings), + }, + } + cases.append(case_record) + devices.add(str(case_device)) + dtypes.add(case_dtype) + + payload = { + "name": cls.__name__, + "qualified_name": f"{cls.__module__}.{cls.__name__}", + "generated_at": time.time(), + "options": { + "repeats": repeats, + "warmup": warmup, + }, + "devices": sorted(devices), + "dtypes": sorted(dtypes), + "cases": cases, + } + + return payload + + @staticmethod + def save_benchmark(payload: dict, directory: str | Path = _DEFAULT_RESULTS_DIR) -> Path: + """Persist a benchmark payload to disk.""" + + output_dir = Path(directory) + output_dir.mkdir(parents=True, exist_ok=True) + destination = output_dir / f"{payload['qualified_name']}.json" + destination.write_text(json.dumps(payload, indent=2)) + return destination + + @classmethod + def plot_benchmarks( + cls, + payload: dict | None = None, + *, + save: str | Path | None = None, + ): + """Render a bar chart comparing current vs reference timings.""" + + if payload is None: + payload = cls.benchmark() + + labels = [case["label"] for case in payload.get("cases", [])] + if not labels: + raise ValueError("No benchmark cases available to plot") + + current = [case["current"]["statistics"]["mean_ms"] for case in payload["cases"]] + reference = [ + case["reference"]["statistics"]["mean_ms"] for case in payload["cases"] + ] + + import matplotlib.pyplot as plt # Imported lazily to avoid hard dependency + + x = range(len(labels)) + width = 0.35 + fig, ax = plt.subplots(figsize=(max(6, len(labels) * 1.5), 4)) + ax.bar([xi - width / 2 for xi in x], current, width, color="#2ca02c", label="Current") + ax.bar([xi + width / 2 for xi in x], reference, width, color="#888888", label="Reference") + ax.set_xticks(list(x)) + ax.set_xticklabels(labels, rotation=30, ha="right") + ax.set_ylabel("Time (ms)") + ax.set_title(f"{cls.__name__} Benchmark") + ax.legend() + fig.tight_layout() + + if save is not None: + fig.savefig(save, bbox_inches="tight") + else: + plt.show() + + return fig + + @classmethod + def _benchmark_forward(cls, *args, **kwargs): + raise NotImplementedError( + f"{cls.__name__} must implement _benchmark_forward for benchmarking" + ) + + @classmethod + def _benchmark_case_label(cls, index: int, inputs: Tuple[Any, ...]) -> str: + """Return a human readable label for a benchmark case.""" + + return f"case_{index}" + + @classmethod + def _benchmark_case_metadata(cls, inputs: Tuple[Any, ...]) -> dict: + """Summarize the benchmark inputs for downstream reporting.""" + + return { + "arguments": [cls._summarize_value(value) for value in inputs], + } + + @classmethod + def _summarize_value(cls, value: Any) -> Any: + if isinstance(value, torch.Tensor): + return { + "type": "tensor", + "shape": list(value.shape), + "dtype": str(value.dtype).replace("torch.", ""), + "device": str(value.device), + "requires_grad": bool(value.requires_grad), + } + if isinstance(value, (int, float, bool, str)): + return value + if isinstance(value, (list, tuple)): + if all(isinstance(elem, (int, float, bool, str)) for elem in value): + return list(value) + return { + "type": value.__class__.__name__, + "length": len(value), + } + if isinstance(value, dict): + summary = {} + for key, val in value.items(): + summary[str(key)] = cls._summarize_value(val) + return summary + return { + "type": value.__class__.__name__, + "repr": repr(value), + } + + @staticmethod + def _tuple_inputs(raw_inputs: Any) -> Tuple[Any, ...]: + if isinstance(raw_inputs, tuple): + return raw_inputs + if isinstance(raw_inputs, list): + return tuple(raw_inputs) + return (raw_inputs,) + + @classmethod + def _first_tensor(cls, inputs: Tuple[Any, ...]) -> torch.Tensor | None: + for value in inputs: + tensor = cls._extract_tensor(value) + if tensor is not None: + return tensor + return None + + @classmethod + def _extract_tensor(cls, value: Any) -> torch.Tensor | None: + if isinstance(value, torch.Tensor): + return value + if isinstance(value, (list, tuple)): + for elem in value: + tensor = cls._extract_tensor(elem) + if tensor is not None: + return tensor + return None + if isinstance(value, dict): + for elem in value.values(): + tensor = cls._extract_tensor(elem) + if tensor is not None: + return tensor + return None + + @classmethod + def _time_callable( + cls, + runner: Callable[..., Any], + args: Tuple[Any, ...], + repeats: int, + warmup: int, + device: torch.device, + ) -> list[float]: + use_cuda = device.type == "cuda" + if use_cuda: + torch.cuda.synchronize(device) + for _ in range(warmup): + runner(*args) + if use_cuda: + torch.cuda.synchronize(device) + + timings: list[float] = [] + if use_cuda: + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + for _ in range(repeats): + torch.cuda.synchronize(device) + start_event.record() + runner(*args) + end_event.record() + torch.cuda.synchronize(device) + timings.append(start_event.elapsed_time(end_event)) + else: + for _ in range(repeats): + t0 = time.perf_counter() + runner(*args) + timings.append((time.perf_counter() - t0) * 1e3) + return timings + + @classmethod + def _normalize_case( + cls, case: Tuple[Any, ...] + ) -> Tuple[str, dict[str, Any], Tuple[Any, ...]]: + if len(case) != 3: + raise ValueError( + "Benchmark cases must yield (label, init_kwargs, run_args) tuples" + ) + label, init_kwargs, run_args = case + if not isinstance(label, str): + raise TypeError("Benchmark case label must be a string") + if init_kwargs is None: + init_kwargs = {} + if not isinstance(init_kwargs, dict): + raise TypeError("Benchmark case init_kwargs must be a dict") + return label, init_kwargs, run_args + + @classmethod + def _instantiate_case(cls, init_kwargs: dict[str, Any]): + if not init_kwargs and not issubclass(cls, torch.nn.Module): + return None + if issubclass(cls, torch.nn.Module): + instance = cls(**init_kwargs) + instance.eval() + return instance + return None + + @classmethod + def _maybe_move_instance(cls, instance, device: torch.device) -> None: + if instance is None: + return + if isinstance(instance, torch.nn.Module): + try: + instance.to(device) + except Exception: # pragma: no cover - best effort move + pass + + @classmethod + def _resolve_current_runner(cls, instance) -> Callable[..., Any]: + if instance is not None: + return instance + return cls._benchmark_forward + + @classmethod + def _resolve_reference_runner(cls, instance) -> Callable[..., Any]: + if hasattr(cls, "reference_impl"): + return cls.reference_impl + raise NotImplementedError( + f"{cls.__name__} must define reference_impl for benchmarking" + ) + + @staticmethod + def _summarize_timings(timings: list[float]) -> dict: + if not timings: + return { + "mean_ms": 0.0, + "median_ms": 0.0, + "min_ms": 0.0, + "max_ms": 0.0, + } + sorted_times = sorted(timings) + count = len(sorted_times) + total = sum(sorted_times) + if count % 2: + median = sorted_times[count // 2] + else: + median = 0.5 * ( + sorted_times[count // 2 - 1] + sorted_times[count // 2] + ) + return { + "mean_ms": total / count, + "median_ms": median, + "min_ms": sorted_times[0], + "max_ms": sorted_times[-1], + } + + +__all__ = ["BenchmarkMixin"] diff --git a/physicsnemo/core/function.py b/physicsnemo/core/function.py new file mode 100644 index 0000000000..e777bfbda0 --- /dev/null +++ b/physicsnemo/core/function.py @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +from torch.autograd import Function as TorchAutogradFunction + +from physicsnemo.core.benchmark import BenchmarkMixin + +class Function(TorchAutogradFunction, BenchmarkMixin): + """Base class for PhysicsNeMo custom autograd functions.""" + + @classmethod + def _benchmark_forward(cls, values, *rest): + cls.apply(values, *rest) + + +__all__ = ["Function"] diff --git a/physicsnemo/core/module.py b/physicsnemo/core/module.py index ef677204be..9ade12e72e 100644 --- a/physicsnemo/core/module.py +++ b/physicsnemo/core/module.py @@ -33,6 +33,7 @@ import torch +from physicsnemo.core.benchmark import BenchmarkMixin from physicsnemo.core.base import RegisterableModule from physicsnemo.core.filesystem import _download_cached, _get_fs from physicsnemo.core.meta import ModelMetaData @@ -68,7 +69,7 @@ def _load_state_dict_with_logging( return missing_keys, unexpected_keys -class Module(RegisterableModule): +class Module(RegisterableModule, BenchmarkMixin): """The base class for all network models in PhysicsNeMo. This should be used as a direct replacement for torch.nn.module and provides diff --git a/physicsnemo/core/version_check.py b/physicsnemo/core/version_check.py index 250ef29f70..d26ae212c7 100644 --- a/physicsnemo/core/version_check.py +++ b/physicsnemo/core/version_check.py @@ -45,6 +45,21 @@ } +def check_min_version( + distribution_name: str, + min_version: str, + *, + hard_fail: bool = True, +) -> bool: + """Backwards-compatible helper that checks ``package >= min_version``.""" + + return ensure_available( + distribution_name, + spec=f">={min_version}", + hard_fail=hard_fail, + ) + + @functools.lru_cache(maxsize=None) def get_installed_version(distribution_name: str) -> Optional[str]: """ diff --git a/physicsnemo/nn/__init__.py b/physicsnemo/nn/__init__.py index ba5a46f4e5..4c113ac5c8 100644 --- a/physicsnemo/nn/__init__.py +++ b/physicsnemo/nn/__init__.py @@ -1,65 +1,7 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +"""Neural network building blocks for PhysicsNeMo.""" -from .activations import ( - CappedGELU, - CappedLeakyReLU, - Identity, - SquarePlus, - Stan, - get_activation, -) -from .ball_query import BQWarp -from .conv_layers import ConvBlock, CubeEmbedding -from .dgm_layers import DGMLayer -from .fourier_layers import ( - FourierFilter, - FourierLayer, - FourierMLP, - GaborFilter, - fourier_encode, -) -from .fully_connected_layers import ( - Conv1dFCLayer, - Conv2dFCLayer, - Conv3dFCLayer, - ConvNdFCLayer, - ConvNdKernel1Layer, - FCLayer, -) -from .kan_layers import KolmogorovArnoldNetwork -from .mlp_layers import Mlp -from .resample_layers import ( - DownSample2D, - DownSample3D, - UpSample2D, - UpSample3D, -) -from .siren_layers import SirenLayer, SirenLayerType -from .spectral_layers import ( - SpectralConv1d, - SpectralConv2d, - SpectralConv3d, - SpectralConv4d, -) -from .transformer_layers import ( - DecoderLayer, - EncoderLayer, - FuserLayer, - SwinTransformer, -) -from .weight_fact import WeightFactLinear -from .weight_norm import WeightNormLinear +from physicsnemo.core import Module +# from physicsnemo.nn.module import FiniteDifferenceNd +from physicsnemo.nn.module.finite_difference import FiniteDifferenceNd + +__all__ = ["Module", "FiniteDifferenceNd"] diff --git a/physicsnemo/nn/functional/__init__.py b/physicsnemo/nn/functional/__init__.py new file mode 100644 index 0000000000..7d629d82e2 --- /dev/null +++ b/physicsnemo/nn/functional/__init__.py @@ -0,0 +1,5 @@ +"""Functional operators for PhysicsNeMo.""" + +from .finite_difference import FiniteDifference + +__all__ = ["FiniteDifference"] diff --git a/physicsnemo/nn/functional/finite_difference.py b/physicsnemo/nn/functional/finite_difference.py new file mode 100644 index 0000000000..d274de2e2f --- /dev/null +++ b/physicsnemo/nn/functional/finite_difference.py @@ -0,0 +1,367 @@ +"""Finite difference functional backed by Warp custom ops.""" + +from __future__ import annotations + +from typing import Any, Iterable, Sequence, Tuple + +import torch +from torch import Tensor + +try: + import warp as wp + wp.init() +except ImportError as err: # pragma: no cover - ImportError only raised in misconfigured envs. + raise ImportError( + "Warp is required for the finite difference functional. Install it from https://github.com/NVIDIA/warp" + ) from err + +from physicsnemo.core.function import Function + +wp.config.quiet = True + + +def _normalize_spacing(spacing: Sequence[float] | float, dims: int) -> Tuple[float, ...]: + if isinstance(spacing, Iterable) and not isinstance(spacing, (str, bytes)): + spacing_values = tuple(float(s) for s in spacing) + else: + spacing_values = (float(spacing),) + + if len(spacing_values) == 1 and dims > 1: + spacing_values = spacing_values * dims + + if len(spacing_values) != dims: + raise ValueError( + f"Spacing must provide {dims} value(s) for the spatial dimensions, got {spacing_values}" + ) + + if any(s <= 0.0 for s in spacing_values): + raise ValueError(f"Spacing values must be positive, got {spacing_values}") + + return spacing_values + + +def _prepare_input(values: Tensor, has_batch: bool) -> tuple[Tensor, bool]: + if has_batch: + return values, False + return values.unsqueeze(0), True + + +def _check_spatial_shape(spatial_shape: Tuple[int, ...]) -> None: + for size in spatial_shape: + if size < 3: + raise ValueError( + "Finite differences require at least three points per spatial dimension; " + f"got shape {spatial_shape}." + ) + + +def _get_wp_context(tensor: Tensor) -> tuple[wp.stream, str | None]: + if tensor.device.type == "cuda": + stream = wp.stream_from_torch(torch.cuda.current_stream(tensor.device)) + return stream, None + return None, "cpu" + + +def _launch_kernel(kernel, dim: int, inputs: list, stream: wp.stream, device: str | None) -> None: + wp.launch( + kernel, + dim=dim, + inputs=inputs, + stream=stream, + device=device, + ) + + +@wp.func +def _wrap_index(idx: int, size: int): + result = idx % size + if result < 0: + result += size + return result + + +@wp.kernel +def _finite_difference_1d_kernel( + values: wp.array(dtype=wp.float32, ndim=2), + gradients: wp.array(dtype=wp.float32, ndim=2), + batch: int, + dim0: int, + spacing0: float, +): + tid = wp.tid() + total = batch * dim0 + if tid >= total: + return + + b = tid // dim0 + i = tid - b * dim0 + + ip = _wrap_index(i + 1, dim0) + im = _wrap_index(i - 1, dim0) + + grad = (values[b, ip] - values[b, im]) / (2.0 * spacing0) + + gradients[b, i] = grad + + +@wp.kernel +def _finite_difference_2d_kernel( + values: wp.array(dtype=wp.float32, ndim=3), + gradients: wp.array(dtype=wp.vec2f, ndim=3), + batch: int, + dim0: int, + dim1: int, + spacing0: float, + spacing1: float, +): + tid = wp.tid() + total = batch * dim0 * dim1 + if tid >= total: + return + + plane = dim0 * dim1 + b = tid // plane + rem = tid - b * plane + i = rem // dim1 + j = rem - i * dim1 + + # Axis 0 derivative + ip = _wrap_index(i + 1, dim0) + im = _wrap_index(i - 1, dim0) + grad0 = (values[b, ip, j] - values[b, im, j]) / (2.0 * spacing0) + + # Axis 1 derivative + jp = _wrap_index(j + 1, dim1) + jm = _wrap_index(j - 1, dim1) + grad1 = (values[b, i, jp] - values[b, i, jm]) / (2.0 * spacing1) + + gradients[b, i, j] = wp.vec2f(grad0, grad1) + + +@wp.kernel +def _finite_difference_3d_kernel( + values: wp.array(dtype=wp.float32, ndim=4), + gradients: wp.array(dtype=wp.vec3f, ndim=4), + batch: int, + dim0: int, + dim1: int, + dim2: int, + spacing0: float, + spacing1: float, + spacing2: float, +): + tid = wp.tid() + total = batch * dim0 * dim1 * dim2 + if tid >= total: + return + + plane = dim1 * dim2 + volume = dim0 * plane + b = tid // volume + rem = tid - b * volume + i = rem // plane + rem = rem - i * plane + j = rem // dim2 + k = rem - j * dim2 + + # Axis 0 derivative + ip = _wrap_index(i + 1, dim0) + im = _wrap_index(i - 1, dim0) + grad0 = (values[b, ip, j, k] - values[b, im, j, k]) / (2.0 * spacing0) + + # Axis 1 derivative + jp = _wrap_index(j + 1, dim1) + jm = _wrap_index(j - 1, dim1) + grad1 = (values[b, i, jp, k] - values[b, i, jm, k]) / (2.0 * spacing1) + + # Axis 2 derivative + kp = _wrap_index(k + 1, dim2) + km = _wrap_index(k - 1, dim2) + grad2 = (values[b, i, j, kp] - values[b, i, j, km]) / (2.0 * spacing2) + + gradients[b, i, j, k] = wp.vec3f(grad0, grad1, grad2) + + +def _run_finite_difference( + values: Tensor, + spacing: Tuple[float, ...], + has_batch: bool, +) -> Tensor: + if not values.is_floating_point(): + raise TypeError("Finite differences require floating point tensors") + + dims = values.dim() - (1 if has_batch else 0) + if dims not in (1, 2, 3): + raise ValueError( + f"Finite differences support 1D, 2D, or 3D inputs, got tensor with {values.dim()} dims" + ) + + spacing_tuple = _normalize_spacing(spacing, dims) + + values32 = values.to(torch.float32).contiguous() + prepared, added_batch = _prepare_input(values32, has_batch) + batch = prepared.shape[0] + spatial_shape = tuple(int(s) for s in prepared.shape[1:]) + _check_spatial_shape(spatial_shape) + + if dims == 1: + grad_shape = (batch, *spatial_shape) + grad_dtype = wp.float32 + elif dims == 2: + grad_shape = (batch, *spatial_shape, dims) + grad_dtype = wp.vec2f + else: + grad_shape = (batch, *spatial_shape, dims) + grad_dtype = wp.vec3f + + gradients = torch.empty( + grad_shape, + device=values.device, + dtype=torch.float32, + ) + + stream, wp_device = _get_wp_context(prepared) + + with wp.ScopedStream(stream): + # wp_values = wp.from_torch(prepared, dtype=wp.float32, return_ctype=True) + wp_values = wp.from_torch(prepared, dtype=wp.float32) + # wp_grads = wp.from_torch(gradients, dtype=wp.float32, return_ctype=True) + wp_grads = wp.from_torch(gradients, dtype=grad_dtype) + + if dims == 1: + inputs = [wp_values, wp_grads, batch, spatial_shape[0], spacing_tuple[0]] + _launch_kernel( + _finite_difference_1d_kernel, + batch * spatial_shape[0], + inputs, + stream, + wp_device, + ) + elif dims == 2: + inputs = [ + wp_values, + wp_grads, + batch, + spatial_shape[0], + spatial_shape[1], + spacing_tuple[0], + spacing_tuple[1], + ] + _launch_kernel( + _finite_difference_2d_kernel, + batch * spatial_shape[0] * spatial_shape[1], + inputs, + stream, + wp_device, + ) + else: + inputs = [ + wp_values, + wp_grads, + batch, + spatial_shape[0], + spatial_shape[1], + spatial_shape[2], + spacing_tuple[0], + spacing_tuple[1], + spacing_tuple[2], + ] + _launch_kernel( + _finite_difference_3d_kernel, + batch * spatial_shape[0] * spatial_shape[1] * spatial_shape[2], + inputs, + stream, + wp_device, + ) + + if dims == 1: + gradients = gradients.unsqueeze(1) + else: + gradients = gradients.movedim(-1, 1) + + if added_batch: + gradients = gradients.squeeze(0) + + if values.dtype != torch.float32: + gradients = gradients.to(values.dtype) + + return gradients + + +@torch.library.custom_op("physicsnemo::finite_difference_nd", mutates_args=()) +def finite_difference_op( + values: Tensor, + spacing: Sequence[float], + has_batch: bool = True, +) -> Tensor: + return _run_finite_difference(values, spacing, has_batch) + + +@finite_difference_op.register_fake +def _finite_difference_fake( + values: Tensor, + spacing: Sequence[float], + has_batch: bool = True, +) -> Tensor: + dims = values.dim() - (1 if has_batch else 0) + if dims not in (1, 2, 3): + raise RuntimeError( + "Finite differences support only 1D, 2D, or 3D inputs in fake tensor mode" + ) + + spatial_shape = values.shape[1:] if has_batch else values.shape + if has_batch: + shape = (values.shape[0], dims, *spatial_shape) + else: + shape = (dims, *spatial_shape) + + return values.new_empty(shape) + + +class FiniteDifference(Function): + """Autograd wrapper around the custom finite difference op.""" + + @staticmethod + def forward(ctx, values: Tensor, spacing: Sequence[float], has_batch: bool = True): + output = torch.ops.physicsnemo.finite_difference_nd(values, spacing, has_batch) + ctx.mark_non_differentiable(output) + return output + + @staticmethod + def backward(ctx, *grad_outputs): # pragma: no cover - no autograd yet + return None, None, None + + @classmethod + def make_inputs( + cls, + ) -> Iterable[tuple[str, dict[str, Any], tuple[Tensor, Sequence[float], bool]]]: + configs = { + "1D": ((1, 1024*8), (1.0,)), + "2D": ((1, 256, 256), (1.0, 1.2)), + "3D": ((1, 64, 64, 64), (0.8, 1.0, 1.2)), + } + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + for label, (shape, spacing) in configs.items(): + values = torch.randn(shape, device=device, dtype=torch.float32) + yield label, {}, (values, spacing, True) + + @classmethod + def reference_impl( + cls, values: Tensor, spacing: Sequence[float], has_batch: bool = True + ) -> Tensor: + dims = values.dim() - (1 if has_batch else 0) + tensor = values if has_batch else values.unsqueeze(0) + grads = [] + for axis in range(1, dims + 1): + forward = torch.roll(tensor, shifts=-1, dims=axis) + backward = torch.roll(tensor, shifts=1, dims=axis) + grads.append((forward - backward) / (2.0 * spacing[axis - 1])) + stacked = torch.stack(grads, dim=1) + return stacked if has_batch else stacked.squeeze(0) + + @classmethod + def check(cls, actual: Tensor, expected: Tensor) -> None: + torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-4) + + +__all__ = ["FiniteDifference"] diff --git a/physicsnemo/nn/module/__init__.py b/physicsnemo/nn/module/__init__.py new file mode 100644 index 0000000000..2bd0044494 --- /dev/null +++ b/physicsnemo/nn/module/__init__.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .activations import ( + CappedGELU, + CappedLeakyReLU, + Identity, + SquarePlus, + Stan, + get_activation, +) +from .ball_query import BQWarp +from .conv_layers import ConvBlock, CubeEmbedding +from .dgm_layers import DGMLayer +from .finite_difference import FiniteDifferenceNd +from .fourier_layers import ( + FourierFilter, + FourierLayer, + FourierMLP, + GaborFilter, + fourier_encode, +) +from .fully_connected_layers import ( + Conv1dFCLayer, + Conv2dFCLayer, + Conv3dFCLayer, + ConvNdFCLayer, + ConvNdKernel1Layer, + FCLayer, +) +from .kan_layers import KolmogorovArnoldNetwork +from .mlp_layers import Mlp +from .resample_layers import ( + DownSample2D, + DownSample3D, + UpSample2D, + UpSample3D, +) +from .siren_layers import SirenLayer, SirenLayerType +from .spectral_layers import ( + SpectralConv1d, + SpectralConv2d, + SpectralConv3d, + SpectralConv4d, +) +from .transformer_layers import ( + DecoderLayer, + EncoderLayer, + FuserLayer, + SwinTransformer, +) +from .weight_fact import WeightFactLinear +from .weight_norm import WeightNormLinear diff --git a/physicsnemo/nn/activations.py b/physicsnemo/nn/module/activations.py similarity index 100% rename from physicsnemo/nn/activations.py rename to physicsnemo/nn/module/activations.py diff --git a/physicsnemo/nn/attention_layers.py b/physicsnemo/nn/module/attention_layers.py similarity index 100% rename from physicsnemo/nn/attention_layers.py rename to physicsnemo/nn/module/attention_layers.py diff --git a/physicsnemo/nn/ball_query.py b/physicsnemo/nn/module/ball_query.py similarity index 98% rename from physicsnemo/nn/ball_query.py rename to physicsnemo/nn/module/ball_query.py index bb759998d5..f9701d595b 100644 --- a/physicsnemo/nn/ball_query.py +++ b/physicsnemo/nn/module/ball_query.py @@ -26,7 +26,7 @@ import torch.nn as nn from einops import rearrange -from physicsnemo.nn.neighbors import radius_search +from physicsnemo.nn.module.neighbors import radius_search class BQWarp(nn.Module): diff --git a/physicsnemo/nn/conv_layers.py b/physicsnemo/nn/module/conv_layers.py similarity index 100% rename from physicsnemo/nn/conv_layers.py rename to physicsnemo/nn/module/conv_layers.py diff --git a/physicsnemo/nn/dgm_layers.py b/physicsnemo/nn/module/dgm_layers.py similarity index 100% rename from physicsnemo/nn/dgm_layers.py rename to physicsnemo/nn/module/dgm_layers.py diff --git a/physicsnemo/nn/drop.py b/physicsnemo/nn/module/drop.py similarity index 100% rename from physicsnemo/nn/drop.py rename to physicsnemo/nn/module/drop.py diff --git a/physicsnemo/nn/fft.py b/physicsnemo/nn/module/fft.py similarity index 100% rename from physicsnemo/nn/fft.py rename to physicsnemo/nn/module/fft.py diff --git a/physicsnemo/nn/module/finite_difference.py b/physicsnemo/nn/module/finite_difference.py new file mode 100644 index 0000000000..5dffdce303 --- /dev/null +++ b/physicsnemo/nn/module/finite_difference.py @@ -0,0 +1,108 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any, Iterable, Sequence, Tuple + +import torch +from torch import Tensor + +from physicsnemo.core import Module +from physicsnemo.nn.functional.finite_difference import ( + FiniteDifference, + _normalize_spacing, +) + + +class FiniteDifferenceNd(Module): + """Finite-difference stencil implemented with Warp-backed functionals. + + Parameters + ---------- + spacing: + Grid spacing for each spatial dimension. Provide a single value to + reuse across all axes. + has_batch: + Whether the input tensors include a batch dimension as the first axis. + """ + + def __init__( + self, + spacing: Sequence[float] | float, + has_batch: bool = True, + ) -> None: + super().__init__() + if isinstance(spacing, Sequence) and not isinstance(spacing, (str, bytes)): + self.spacing: Sequence[float] | float = tuple(float(s) for s in spacing) + else: + self.spacing = float(spacing) + self.has_batch = has_batch + + def forward(self, values: Tensor) -> Tensor: + """Apply the finite difference stencil.""" + + if not torch.is_tensor(values): + raise TypeError("values must be a torch.Tensor") + dims = values.dim() - (1 if self.has_batch else 0) + spacing_tuple = _normalize_spacing(self.spacing, dims) + return FiniteDifference.apply(values, spacing_tuple, self.has_batch) + + @classmethod + def make_inputs( + cls, + ) -> Iterable[tuple[str, dict[str, Any], Tuple[Tensor]]]: + configs = { + "1D": ((1, 64), (1.0,)), + "2D": ((1, 32, 32), (1.0, 1.2)), + "3D": ((1, 16, 16, 16), (0.8, 1.0, 1.2)), + } + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + for label, (shape, spacing) in configs.items(): + values = torch.randn(shape, device=device, dtype=torch.float32) + init_kwargs: dict[str, Any] = { + "spacing": spacing, + "has_batch": True, + } + yield label, init_kwargs, (values,) + + @classmethod + def reference_impl( + cls, values: Tensor, spacing: Sequence[float] | float, has_batch: bool = True + ) -> Tensor: + dims = values.dim() - (1 if has_batch else 0) + spacing_tuple = _normalize_spacing(spacing, dims) + return FiniteDifference.reference_impl(values, spacing_tuple, has_batch) + + @classmethod + def _resolve_reference_runner(cls, instance): + if instance is None: + return super()._resolve_reference_runner(instance) + + spacing = instance.spacing + has_batch = instance.has_batch + + def runner(values: Tensor) -> Tensor: + return cls.reference_impl(values, spacing, has_batch) + + return runner + + @classmethod + def check(cls, actual: Tensor, expected: Tensor) -> None: + FiniteDifference.check(actual, expected) + + +__all__ = ["FiniteDifferenceNd"] diff --git a/physicsnemo/nn/fourier_layers.py b/physicsnemo/nn/module/fourier_layers.py similarity index 100% rename from physicsnemo/nn/fourier_layers.py rename to physicsnemo/nn/module/fourier_layers.py diff --git a/physicsnemo/nn/fully_connected_layers.py b/physicsnemo/nn/module/fully_connected_layers.py similarity index 100% rename from physicsnemo/nn/fully_connected_layers.py rename to physicsnemo/nn/module/fully_connected_layers.py diff --git a/physicsnemo/nn/fused_silu.py b/physicsnemo/nn/module/fused_silu.py similarity index 100% rename from physicsnemo/nn/fused_silu.py rename to physicsnemo/nn/module/fused_silu.py diff --git a/physicsnemo/nn/gnn_layers/__init__.py b/physicsnemo/nn/module/gnn_layers/__init__.py similarity index 100% rename from physicsnemo/nn/gnn_layers/__init__.py rename to physicsnemo/nn/module/gnn_layers/__init__.py diff --git a/physicsnemo/nn/gnn_layers/bsms.py b/physicsnemo/nn/module/gnn_layers/bsms.py similarity index 100% rename from physicsnemo/nn/gnn_layers/bsms.py rename to physicsnemo/nn/module/gnn_layers/bsms.py diff --git a/physicsnemo/nn/gnn_layers/distributed_graph.py b/physicsnemo/nn/module/gnn_layers/distributed_graph.py similarity index 100% rename from physicsnemo/nn/gnn_layers/distributed_graph.py rename to physicsnemo/nn/module/gnn_layers/distributed_graph.py diff --git a/physicsnemo/nn/gnn_layers/embedder.py b/physicsnemo/nn/module/gnn_layers/embedder.py similarity index 100% rename from physicsnemo/nn/gnn_layers/embedder.py rename to physicsnemo/nn/module/gnn_layers/embedder.py diff --git a/physicsnemo/nn/gnn_layers/graph.py b/physicsnemo/nn/module/gnn_layers/graph.py similarity index 100% rename from physicsnemo/nn/gnn_layers/graph.py rename to physicsnemo/nn/module/gnn_layers/graph.py diff --git a/physicsnemo/nn/gnn_layers/mesh_edge_block.py b/physicsnemo/nn/module/gnn_layers/mesh_edge_block.py similarity index 100% rename from physicsnemo/nn/gnn_layers/mesh_edge_block.py rename to physicsnemo/nn/module/gnn_layers/mesh_edge_block.py diff --git a/physicsnemo/nn/gnn_layers/mesh_graph_decoder.py b/physicsnemo/nn/module/gnn_layers/mesh_graph_decoder.py similarity index 100% rename from physicsnemo/nn/gnn_layers/mesh_graph_decoder.py rename to physicsnemo/nn/module/gnn_layers/mesh_graph_decoder.py diff --git a/physicsnemo/nn/gnn_layers/mesh_graph_encoder.py b/physicsnemo/nn/module/gnn_layers/mesh_graph_encoder.py similarity index 100% rename from physicsnemo/nn/gnn_layers/mesh_graph_encoder.py rename to physicsnemo/nn/module/gnn_layers/mesh_graph_encoder.py diff --git a/physicsnemo/nn/gnn_layers/mesh_graph_mlp.py b/physicsnemo/nn/module/gnn_layers/mesh_graph_mlp.py similarity index 100% rename from physicsnemo/nn/gnn_layers/mesh_graph_mlp.py rename to physicsnemo/nn/module/gnn_layers/mesh_graph_mlp.py diff --git a/physicsnemo/nn/gnn_layers/mesh_node_block.py b/physicsnemo/nn/module/gnn_layers/mesh_node_block.py similarity index 100% rename from physicsnemo/nn/gnn_layers/mesh_node_block.py rename to physicsnemo/nn/module/gnn_layers/mesh_node_block.py diff --git a/physicsnemo/nn/gnn_layers/utils.py b/physicsnemo/nn/module/gnn_layers/utils.py similarity index 100% rename from physicsnemo/nn/gnn_layers/utils.py rename to physicsnemo/nn/module/gnn_layers/utils.py diff --git a/physicsnemo/nn/interpolation.py b/physicsnemo/nn/module/interpolation.py similarity index 100% rename from physicsnemo/nn/interpolation.py rename to physicsnemo/nn/module/interpolation.py diff --git a/physicsnemo/nn/kan_layers.py b/physicsnemo/nn/module/kan_layers.py similarity index 100% rename from physicsnemo/nn/kan_layers.py rename to physicsnemo/nn/module/kan_layers.py diff --git a/physicsnemo/nn/layer_norm.py b/physicsnemo/nn/module/layer_norm.py similarity index 100% rename from physicsnemo/nn/layer_norm.py rename to physicsnemo/nn/module/layer_norm.py diff --git a/physicsnemo/nn/mlp_layers.py b/physicsnemo/nn/module/mlp_layers.py similarity index 100% rename from physicsnemo/nn/mlp_layers.py rename to physicsnemo/nn/module/mlp_layers.py diff --git a/physicsnemo/nn/neighbors/__init__.py b/physicsnemo/nn/module/neighbors/__init__.py similarity index 100% rename from physicsnemo/nn/neighbors/__init__.py rename to physicsnemo/nn/module/neighbors/__init__.py diff --git a/physicsnemo/nn/neighbors/_knn/__init__.py b/physicsnemo/nn/module/neighbors/_knn/__init__.py similarity index 100% rename from physicsnemo/nn/neighbors/_knn/__init__.py rename to physicsnemo/nn/module/neighbors/_knn/__init__.py diff --git a/physicsnemo/nn/neighbors/_knn/_cuml_impl.py b/physicsnemo/nn/module/neighbors/_knn/_cuml_impl.py similarity index 100% rename from physicsnemo/nn/neighbors/_knn/_cuml_impl.py rename to physicsnemo/nn/module/neighbors/_knn/_cuml_impl.py diff --git a/physicsnemo/nn/neighbors/_knn/_scipy_impl.py b/physicsnemo/nn/module/neighbors/_knn/_scipy_impl.py similarity index 100% rename from physicsnemo/nn/neighbors/_knn/_scipy_impl.py rename to physicsnemo/nn/module/neighbors/_knn/_scipy_impl.py diff --git a/physicsnemo/nn/neighbors/_knn/_torch_impl.py b/physicsnemo/nn/module/neighbors/_knn/_torch_impl.py similarity index 100% rename from physicsnemo/nn/neighbors/_knn/_torch_impl.py rename to physicsnemo/nn/module/neighbors/_knn/_torch_impl.py diff --git a/physicsnemo/nn/neighbors/_knn/knn.py b/physicsnemo/nn/module/neighbors/_knn/knn.py similarity index 100% rename from physicsnemo/nn/neighbors/_knn/knn.py rename to physicsnemo/nn/module/neighbors/_knn/knn.py diff --git a/physicsnemo/nn/neighbors/_radius_search/__init__.py b/physicsnemo/nn/module/neighbors/_radius_search/__init__.py similarity index 100% rename from physicsnemo/nn/neighbors/_radius_search/__init__.py rename to physicsnemo/nn/module/neighbors/_radius_search/__init__.py diff --git a/physicsnemo/nn/neighbors/_radius_search/_torch_impl.py b/physicsnemo/nn/module/neighbors/_radius_search/_torch_impl.py similarity index 100% rename from physicsnemo/nn/neighbors/_radius_search/_torch_impl.py rename to physicsnemo/nn/module/neighbors/_radius_search/_torch_impl.py diff --git a/physicsnemo/nn/neighbors/_radius_search/_warp_impl.py b/physicsnemo/nn/module/neighbors/_radius_search/_warp_impl.py similarity index 100% rename from physicsnemo/nn/neighbors/_radius_search/_warp_impl.py rename to physicsnemo/nn/module/neighbors/_radius_search/_warp_impl.py diff --git a/physicsnemo/nn/neighbors/_radius_search/kernels.py b/physicsnemo/nn/module/neighbors/_radius_search/kernels.py similarity index 100% rename from physicsnemo/nn/neighbors/_radius_search/kernels.py rename to physicsnemo/nn/module/neighbors/_radius_search/kernels.py diff --git a/physicsnemo/nn/neighbors/_radius_search/radius_search.py b/physicsnemo/nn/module/neighbors/_radius_search/radius_search.py similarity index 100% rename from physicsnemo/nn/neighbors/_radius_search/radius_search.py rename to physicsnemo/nn/module/neighbors/_radius_search/radius_search.py diff --git a/physicsnemo/nn/resample_layers.py b/physicsnemo/nn/module/resample_layers.py similarity index 100% rename from physicsnemo/nn/resample_layers.py rename to physicsnemo/nn/module/resample_layers.py diff --git a/physicsnemo/nn/sdf.py b/physicsnemo/nn/module/sdf.py similarity index 100% rename from physicsnemo/nn/sdf.py rename to physicsnemo/nn/module/sdf.py diff --git a/physicsnemo/nn/siren_layers.py b/physicsnemo/nn/module/siren_layers.py similarity index 100% rename from physicsnemo/nn/siren_layers.py rename to physicsnemo/nn/module/siren_layers.py diff --git a/physicsnemo/nn/spectral_layers.py b/physicsnemo/nn/module/spectral_layers.py similarity index 100% rename from physicsnemo/nn/spectral_layers.py rename to physicsnemo/nn/module/spectral_layers.py diff --git a/physicsnemo/nn/transformer_decoder.py b/physicsnemo/nn/module/transformer_decoder.py similarity index 100% rename from physicsnemo/nn/transformer_decoder.py rename to physicsnemo/nn/module/transformer_decoder.py diff --git a/physicsnemo/nn/transformer_layers.py b/physicsnemo/nn/module/transformer_layers.py similarity index 100% rename from physicsnemo/nn/transformer_layers.py rename to physicsnemo/nn/module/transformer_layers.py diff --git a/physicsnemo/nn/module/utils/__init__.py b/physicsnemo/nn/module/utils/__init__.py new file mode 100644 index 0000000000..538d95cf4b --- /dev/null +++ b/physicsnemo/nn/module/utils/__init__.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .patch_embed import ( + PatchEmbed2D, + PatchEmbed3D, + PatchRecovery2D, + PatchRecovery3D, +) +from .shift_window_mask import ( + get_shift_window_mask, + window_partition, + window_reverse, +) +from .utils import ( + crop2d, + crop3d, + get_earth_position_index, + get_pad2d, + get_pad3d, +) +from .weight_init import trunc_normal_ diff --git a/physicsnemo/nn/utils/patch_embed.py b/physicsnemo/nn/module/utils/patch_embed.py similarity index 100% rename from physicsnemo/nn/utils/patch_embed.py rename to physicsnemo/nn/module/utils/patch_embed.py diff --git a/physicsnemo/nn/utils/shift_window_mask.py b/physicsnemo/nn/module/utils/shift_window_mask.py similarity index 100% rename from physicsnemo/nn/utils/shift_window_mask.py rename to physicsnemo/nn/module/utils/shift_window_mask.py diff --git a/physicsnemo/nn/utils/utils.py b/physicsnemo/nn/module/utils/utils.py similarity index 100% rename from physicsnemo/nn/utils/utils.py rename to physicsnemo/nn/module/utils/utils.py diff --git a/physicsnemo/nn/utils/weight_init.py b/physicsnemo/nn/module/utils/weight_init.py similarity index 100% rename from physicsnemo/nn/utils/weight_init.py rename to physicsnemo/nn/module/utils/weight_init.py diff --git a/physicsnemo/nn/weight_fact.py b/physicsnemo/nn/module/weight_fact.py similarity index 100% rename from physicsnemo/nn/weight_fact.py rename to physicsnemo/nn/module/weight_fact.py diff --git a/physicsnemo/nn/weight_norm.py b/physicsnemo/nn/module/weight_norm.py similarity index 100% rename from physicsnemo/nn/weight_norm.py rename to physicsnemo/nn/module/weight_norm.py diff --git a/physicsnemo/nn/utils/__init__.py b/physicsnemo/nn/utils/__init__.py index 538d95cf4b..343d286e6e 100644 --- a/physicsnemo/nn/utils/__init__.py +++ b/physicsnemo/nn/utils/__init__.py @@ -14,22 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .patch_embed import ( - PatchEmbed2D, - PatchEmbed3D, - PatchRecovery2D, - PatchRecovery3D, -) -from .shift_window_mask import ( - get_shift_window_mask, - window_partition, - window_reverse, -) -from .utils import ( - crop2d, - crop3d, - get_earth_position_index, - get_pad2d, - get_pad3d, -) -from .weight_init import trunc_normal_ +"""Backward-compatible proxy to the reorganized nn.module.utils package.""" + +from physicsnemo.nn.module.utils import * # noqa: F401,F403 + +__all__ = [name for name in globals().keys() if not name.startswith("_")] diff --git a/pyproject.toml b/pyproject.toml index ce75a83248..14c4bc1e43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,9 @@ Documentation = "https://docs.nvidia.com/physicsnemo/index.html#core" Issues = "https://github.com/NVIDIA/physicsnemo/issues" Changelog = "https://github.com/NVIDIA/physicsnemo/blob/main/CHANGELOG.md" +[tool.setuptools] +packages = ["physicsnemo"] + [tool.ruff] # Enable flake8/pycodestyle (`E`), Pyflakes (`F`), flake8-bandit (`S`), From 34f26aaaf9f3b1d3717c4423daae4d93388feea2 Mon Sep 17 00:00:00 2001 From: Oliver Date: Wed, 3 Dec 2025 16:45:27 -0800 Subject: [PATCH 02/15] removed benchmarking stuff --- physicsnemo/core/__init__.py | 2 +- physicsnemo/core/benchmark.py | 327 +--------------------------------- 2 files changed, 4 insertions(+), 325 deletions(-) diff --git a/physicsnemo/core/__init__.py b/physicsnemo/core/__init__.py index 9a421bd657..baedbeebc0 100644 --- a/physicsnemo/core/__init__.py +++ b/physicsnemo/core/__init__.py @@ -14,10 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .function import Function from .meta import ModelMetaData from .module import Module from .registry import ModelRegistry +from .function import Function from .version_check import check_version_spec __all__ = ["ModelMetaData", "Module", "ModelRegistry", "Function"] diff --git a/physicsnemo/core/benchmark.py b/physicsnemo/core/benchmark.py index 54d7227514..1189426adf 100644 --- a/physicsnemo/core/benchmark.py +++ b/physicsnemo/core/benchmark.py @@ -14,8 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Common benchmarking hooks shared by autograd functions and modules.""" - from __future__ import annotations import json @@ -25,9 +23,6 @@ import torch -_DEFAULT_RESULTS_DIR = Path("benchmarks/results") - - class BenchmarkMixin: """Defines the hooks needed by the benchmark and validation helpers.""" @@ -61,7 +56,7 @@ def benchmark( *, repeats: int = 10, warmup: int = 1, - ) -> dict: + ) -> None: """Collect benchmark data for downstream consumers. Parameters @@ -71,324 +66,8 @@ def benchmark( warmup Runs discarded before timings are collected. Warmups help amortize one-time setup costs. - - Returns - ------- - dict - Structured benchmark payload that can be consumed by CI, ASV, or - plotting utilities. """ - if repeats <= 0: - raise ValueError("repeats must be positive") - if warmup < 0: - raise ValueError("warmup must be non-negative") - - cases = [] - devices = set() - dtypes = set() - for index, case in enumerate(cls.make_inputs()): - label, init_kwargs, run_args = cls._normalize_case(case) - inputs = cls._tuple_inputs(run_args) - tensor = cls._first_tensor(inputs) - if tensor is None: - case_device = torch.device("cpu") - case_dtype = "unknown" - else: - case_device = tensor.device - case_dtype = str(tensor.dtype).replace("torch.", "") - - instance = cls._instantiate_case(init_kwargs) - cls._maybe_move_instance(instance, case_device) - - current_runner = cls._resolve_current_runner(instance) - reference_runner = cls._resolve_reference_runner(instance) - - current_timings = cls._time_callable( - current_runner, inputs, repeats, warmup, case_device - ) - reference_timings = cls._time_callable( - reference_runner, inputs, repeats, warmup, case_device - ) - case_record = { - "label": label, - "index": index, - "metadata": cls._benchmark_case_metadata(inputs), - "device": str(case_device), - "dtype": case_dtype, - "init_kwargs": dict(init_kwargs), - "current": { - "timings_ms": current_timings, - "statistics": cls._summarize_timings(current_timings), - }, - "reference": { - "timings_ms": reference_timings, - "statistics": cls._summarize_timings(reference_timings), - }, - } - cases.append(case_record) - devices.add(str(case_device)) - dtypes.add(case_dtype) - - payload = { - "name": cls.__name__, - "qualified_name": f"{cls.__module__}.{cls.__name__}", - "generated_at": time.time(), - "options": { - "repeats": repeats, - "warmup": warmup, - }, - "devices": sorted(devices), - "dtypes": sorted(dtypes), - "cases": cases, - } - - return payload - - @staticmethod - def save_benchmark(payload: dict, directory: str | Path = _DEFAULT_RESULTS_DIR) -> Path: - """Persist a benchmark payload to disk.""" - - output_dir = Path(directory) - output_dir.mkdir(parents=True, exist_ok=True) - destination = output_dir / f"{payload['qualified_name']}.json" - destination.write_text(json.dumps(payload, indent=2)) - return destination - - @classmethod - def plot_benchmarks( - cls, - payload: dict | None = None, - *, - save: str | Path | None = None, - ): - """Render a bar chart comparing current vs reference timings.""" - - if payload is None: - payload = cls.benchmark() - - labels = [case["label"] for case in payload.get("cases", [])] - if not labels: - raise ValueError("No benchmark cases available to plot") - - current = [case["current"]["statistics"]["mean_ms"] for case in payload["cases"]] - reference = [ - case["reference"]["statistics"]["mean_ms"] for case in payload["cases"] - ] - - import matplotlib.pyplot as plt # Imported lazily to avoid hard dependency - - x = range(len(labels)) - width = 0.35 - fig, ax = plt.subplots(figsize=(max(6, len(labels) * 1.5), 4)) - ax.bar([xi - width / 2 for xi in x], current, width, color="#2ca02c", label="Current") - ax.bar([xi + width / 2 for xi in x], reference, width, color="#888888", label="Reference") - ax.set_xticks(list(x)) - ax.set_xticklabels(labels, rotation=30, ha="right") - ax.set_ylabel("Time (ms)") - ax.set_title(f"{cls.__name__} Benchmark") - ax.legend() - fig.tight_layout() - - if save is not None: - fig.savefig(save, bbox_inches="tight") - else: - plt.show() - - return fig - - @classmethod - def _benchmark_forward(cls, *args, **kwargs): - raise NotImplementedError( - f"{cls.__name__} must implement _benchmark_forward for benchmarking" - ) - - @classmethod - def _benchmark_case_label(cls, index: int, inputs: Tuple[Any, ...]) -> str: - """Return a human readable label for a benchmark case.""" - - return f"case_{index}" - - @classmethod - def _benchmark_case_metadata(cls, inputs: Tuple[Any, ...]) -> dict: - """Summarize the benchmark inputs for downstream reporting.""" - - return { - "arguments": [cls._summarize_value(value) for value in inputs], - } - - @classmethod - def _summarize_value(cls, value: Any) -> Any: - if isinstance(value, torch.Tensor): - return { - "type": "tensor", - "shape": list(value.shape), - "dtype": str(value.dtype).replace("torch.", ""), - "device": str(value.device), - "requires_grad": bool(value.requires_grad), - } - if isinstance(value, (int, float, bool, str)): - return value - if isinstance(value, (list, tuple)): - if all(isinstance(elem, (int, float, bool, str)) for elem in value): - return list(value) - return { - "type": value.__class__.__name__, - "length": len(value), - } - if isinstance(value, dict): - summary = {} - for key, val in value.items(): - summary[str(key)] = cls._summarize_value(val) - return summary - return { - "type": value.__class__.__name__, - "repr": repr(value), - } - - @staticmethod - def _tuple_inputs(raw_inputs: Any) -> Tuple[Any, ...]: - if isinstance(raw_inputs, tuple): - return raw_inputs - if isinstance(raw_inputs, list): - return tuple(raw_inputs) - return (raw_inputs,) - - @classmethod - def _first_tensor(cls, inputs: Tuple[Any, ...]) -> torch.Tensor | None: - for value in inputs: - tensor = cls._extract_tensor(value) - if tensor is not None: - return tensor - return None - - @classmethod - def _extract_tensor(cls, value: Any) -> torch.Tensor | None: - if isinstance(value, torch.Tensor): - return value - if isinstance(value, (list, tuple)): - for elem in value: - tensor = cls._extract_tensor(elem) - if tensor is not None: - return tensor - return None - if isinstance(value, dict): - for elem in value.values(): - tensor = cls._extract_tensor(elem) - if tensor is not None: - return tensor - return None - - @classmethod - def _time_callable( - cls, - runner: Callable[..., Any], - args: Tuple[Any, ...], - repeats: int, - warmup: int, - device: torch.device, - ) -> list[float]: - use_cuda = device.type == "cuda" - if use_cuda: - torch.cuda.synchronize(device) - for _ in range(warmup): - runner(*args) - if use_cuda: - torch.cuda.synchronize(device) - - timings: list[float] = [] - if use_cuda: - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - for _ in range(repeats): - torch.cuda.synchronize(device) - start_event.record() - runner(*args) - end_event.record() - torch.cuda.synchronize(device) - timings.append(start_event.elapsed_time(end_event)) - else: - for _ in range(repeats): - t0 = time.perf_counter() - runner(*args) - timings.append((time.perf_counter() - t0) * 1e3) - return timings - - @classmethod - def _normalize_case( - cls, case: Tuple[Any, ...] - ) -> Tuple[str, dict[str, Any], Tuple[Any, ...]]: - if len(case) != 3: - raise ValueError( - "Benchmark cases must yield (label, init_kwargs, run_args) tuples" - ) - label, init_kwargs, run_args = case - if not isinstance(label, str): - raise TypeError("Benchmark case label must be a string") - if init_kwargs is None: - init_kwargs = {} - if not isinstance(init_kwargs, dict): - raise TypeError("Benchmark case init_kwargs must be a dict") - return label, init_kwargs, run_args - - @classmethod - def _instantiate_case(cls, init_kwargs: dict[str, Any]): - if not init_kwargs and not issubclass(cls, torch.nn.Module): - return None - if issubclass(cls, torch.nn.Module): - instance = cls(**init_kwargs) - instance.eval() - return instance - return None - - @classmethod - def _maybe_move_instance(cls, instance, device: torch.device) -> None: - if instance is None: - return - if isinstance(instance, torch.nn.Module): - try: - instance.to(device) - except Exception: # pragma: no cover - best effort move - pass - - @classmethod - def _resolve_current_runner(cls, instance) -> Callable[..., Any]: - if instance is not None: - return instance - return cls._benchmark_forward - - @classmethod - def _resolve_reference_runner(cls, instance) -> Callable[..., Any]: - if hasattr(cls, "reference_impl"): - return cls.reference_impl raise NotImplementedError( - f"{cls.__name__} must define reference_impl for benchmarking" - ) - - @staticmethod - def _summarize_timings(timings: list[float]) -> dict: - if not timings: - return { - "mean_ms": 0.0, - "median_ms": 0.0, - "min_ms": 0.0, - "max_ms": 0.0, - } - sorted_times = sorted(timings) - count = len(sorted_times) - total = sum(sorted_times) - if count % 2: - median = sorted_times[count // 2] - else: - median = 0.5 * ( - sorted_times[count // 2 - 1] + sorted_times[count // 2] - ) - return { - "mean_ms": total / count, - "median_ms": median, - "min_ms": sorted_times[0], - "max_ms": sorted_times[-1], - } - - -__all__ = ["BenchmarkMixin"] + f"Benchmarking not supported yet" + ) \ No newline at end of file From ec1bf4ee3be03fb3994252b1d8e25ee2b41a8a27 Mon Sep 17 00:00:00 2001 From: Oliver Date: Wed, 3 Dec 2025 17:12:31 -0800 Subject: [PATCH 03/15] imports --- physicsnemo/core/module.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/physicsnemo/core/module.py b/physicsnemo/core/module.py index 2b20d0ea43..8a2035b099 100644 --- a/physicsnemo/core/module.py +++ b/physicsnemo/core/module.py @@ -34,7 +34,6 @@ import torch from physicsnemo.core.benchmark import BenchmarkMixin -from physicsnemo.core.base import RegisterableModule from physicsnemo.core.filesystem import _download_cached, _get_fs from physicsnemo.core.meta import ModelMetaData from physicsnemo.core.registry import ModelRegistry @@ -69,7 +68,7 @@ def _load_state_dict_with_logging( return missing_keys, unexpected_keys -class Module(RegisterableModule, BenchmarkMixin): +class Module(torch.nn.Module, BenchmarkMixin): """The base class for all network models in PhysicsNeMo. This should be used as a direct replacement for torch.nn.module and provides From 70e92f9fc858da91cfad6854bb18bcd92f9ed720 Mon Sep 17 00:00:00 2001 From: Oliver Date: Wed, 3 Dec 2025 17:13:17 -0800 Subject: [PATCH 04/15] imports --- physicsnemo/core/version_check.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/physicsnemo/core/version_check.py b/physicsnemo/core/version_check.py index 46307eb03f..cf00b7212f 100644 --- a/physicsnemo/core/version_check.py +++ b/physicsnemo/core/version_check.py @@ -32,21 +32,6 @@ from packaging.version import parse -def check_min_version( - distribution_name: str, - min_version: str, - *, - hard_fail: bool = True, -) -> bool: - """Backwards-compatible helper that checks ``package >= min_version``.""" - - return ensure_available( - distribution_name, - spec=f">={min_version}", - hard_fail=hard_fail, - ) - - @functools.lru_cache(maxsize=None) def get_installed_version(distribution_name: str) -> Optional[str]: """ From a37a65564398fda0bf0364d95c3f1b4549716c6a Mon Sep 17 00:00:00 2001 From: Oliver Date: Wed, 3 Dec 2025 17:17:13 -0800 Subject: [PATCH 05/15] imports --- physicsnemo/nn/__init__.py | 70 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 67 insertions(+), 3 deletions(-) diff --git a/physicsnemo/nn/__init__.py b/physicsnemo/nn/__init__.py index 4c113ac5c8..ea97824557 100644 --- a/physicsnemo/nn/__init__.py +++ b/physicsnemo/nn/__init__.py @@ -1,7 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Neural network building blocks for PhysicsNeMo.""" +# Make physicsnemo.nn.Module an available import like torch.nn.Module from physicsnemo.core import Module -# from physicsnemo.nn.module import FiniteDifferenceNd -from physicsnemo.nn.module.finite_difference import FiniteDifferenceNd -__all__ = ["Module", "FiniteDifferenceNd"] +from .module.activations import ( + CappedGELU, + CappedLeakyReLU, + Identity, + SquarePlus, + Stan, + get_activation, +) +from .module.ball_query import BQWarp +from .module.conv_layers import ConvBlock, CubeEmbedding +from .module.dgm_layers import DGMLayer +from .module.fourier_layers import ( + FourierFilter, + FourierLayer, + FourierMLP, + GaborFilter, + fourier_encode, +) +from .module.fully_connected_layers import ( + Conv1dFCLayer, + Conv2dFCLayer, + Conv3dFCLayer, + ConvNdFCLayer, + ConvNdKernel1Layer, + FCLayer, +) +from .module.kan_layers import KolmogorovArnoldNetwork +from .module.mlp_layers import Mlp +from .module.resample_layers import ( + DownSample2D, + DownSample3D, + UpSample2D, + UpSample3D, +) +from .module.siren_layers import SirenLayer, SirenLayerType +from .module.spectral_layers import ( + SpectralConv1d, + SpectralConv2d, + SpectralConv3d, + SpectralConv4d, +) +from .module.transformer_layers import ( + DecoderLayer, + EncoderLayer, + FuserLayer, + SwinTransformer, +) +from .module.weight_fact import WeightFactLinear +from .module.weight_norm import WeightNormLinear +# from physicsnemo.nn.module import FiniteDifferenceNd +from physicsnemo.nn.module.finite_difference import FiniteDifferenceNd \ No newline at end of file From b771f7dd91f43bf33d1ea903e286270f5bf803a2 Mon Sep 17 00:00:00 2001 From: Oliver Date: Wed, 3 Dec 2025 17:17:42 -0800 Subject: [PATCH 06/15] imports --- physicsnemo/nn/__init__.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/physicsnemo/nn/__init__.py b/physicsnemo/nn/__init__.py index ea97824557..1d1fcec859 100644 --- a/physicsnemo/nn/__init__.py +++ b/physicsnemo/nn/__init__.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Neural network building blocks for PhysicsNeMo.""" # Make physicsnemo.nn.Module an available import like torch.nn.Module from physicsnemo.core import Module @@ -66,6 +65,4 @@ SwinTransformer, ) from .module.weight_fact import WeightFactLinear -from .module.weight_norm import WeightNormLinear -# from physicsnemo.nn.module import FiniteDifferenceNd -from physicsnemo.nn.module.finite_difference import FiniteDifferenceNd \ No newline at end of file +from .module.weight_norm import WeightNormLinear \ No newline at end of file From af34e1d015d8bf5ed5cca5a0a922ebbda04ed19c Mon Sep 17 00:00:00 2001 From: Oliver Date: Wed, 3 Dec 2025 17:20:04 -0800 Subject: [PATCH 07/15] imports --- physicsnemo/core/function.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/physicsnemo/core/function.py b/physicsnemo/core/function.py index e777bfbda0..84bf189106 100644 --- a/physicsnemo/core/function.py +++ b/physicsnemo/core/function.py @@ -23,10 +23,4 @@ class Function(TorchAutogradFunction, BenchmarkMixin): """Base class for PhysicsNeMo custom autograd functions.""" - - @classmethod - def _benchmark_forward(cls, values, *rest): - cls.apply(values, *rest) - - -__all__ = ["Function"] + # Placeholder for utilities to bring in warp, fuser, etc. \ No newline at end of file From 13383643ca2b167683b772352a5178b7e1d27400 Mon Sep 17 00:00:00 2001 From: Oliver Date: Wed, 3 Dec 2025 17:20:41 -0800 Subject: [PATCH 08/15] imports --- physicsnemo/nn/functional/__init__.py | 20 +- .../nn/functional/finite_difference.py | 367 ------------------ 2 files changed, 15 insertions(+), 372 deletions(-) delete mode 100644 physicsnemo/nn/functional/finite_difference.py diff --git a/physicsnemo/nn/functional/__init__.py b/physicsnemo/nn/functional/__init__.py index 7d629d82e2..f48d21e283 100644 --- a/physicsnemo/nn/functional/__init__.py +++ b/physicsnemo/nn/functional/__init__.py @@ -1,5 +1,15 @@ -"""Functional operators for PhysicsNeMo.""" - -from .finite_difference import FiniteDifference - -__all__ = ["FiniteDifference"] +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/physicsnemo/nn/functional/finite_difference.py b/physicsnemo/nn/functional/finite_difference.py deleted file mode 100644 index d274de2e2f..0000000000 --- a/physicsnemo/nn/functional/finite_difference.py +++ /dev/null @@ -1,367 +0,0 @@ -"""Finite difference functional backed by Warp custom ops.""" - -from __future__ import annotations - -from typing import Any, Iterable, Sequence, Tuple - -import torch -from torch import Tensor - -try: - import warp as wp - wp.init() -except ImportError as err: # pragma: no cover - ImportError only raised in misconfigured envs. - raise ImportError( - "Warp is required for the finite difference functional. Install it from https://github.com/NVIDIA/warp" - ) from err - -from physicsnemo.core.function import Function - -wp.config.quiet = True - - -def _normalize_spacing(spacing: Sequence[float] | float, dims: int) -> Tuple[float, ...]: - if isinstance(spacing, Iterable) and not isinstance(spacing, (str, bytes)): - spacing_values = tuple(float(s) for s in spacing) - else: - spacing_values = (float(spacing),) - - if len(spacing_values) == 1 and dims > 1: - spacing_values = spacing_values * dims - - if len(spacing_values) != dims: - raise ValueError( - f"Spacing must provide {dims} value(s) for the spatial dimensions, got {spacing_values}" - ) - - if any(s <= 0.0 for s in spacing_values): - raise ValueError(f"Spacing values must be positive, got {spacing_values}") - - return spacing_values - - -def _prepare_input(values: Tensor, has_batch: bool) -> tuple[Tensor, bool]: - if has_batch: - return values, False - return values.unsqueeze(0), True - - -def _check_spatial_shape(spatial_shape: Tuple[int, ...]) -> None: - for size in spatial_shape: - if size < 3: - raise ValueError( - "Finite differences require at least three points per spatial dimension; " - f"got shape {spatial_shape}." - ) - - -def _get_wp_context(tensor: Tensor) -> tuple[wp.stream, str | None]: - if tensor.device.type == "cuda": - stream = wp.stream_from_torch(torch.cuda.current_stream(tensor.device)) - return stream, None - return None, "cpu" - - -def _launch_kernel(kernel, dim: int, inputs: list, stream: wp.stream, device: str | None) -> None: - wp.launch( - kernel, - dim=dim, - inputs=inputs, - stream=stream, - device=device, - ) - - -@wp.func -def _wrap_index(idx: int, size: int): - result = idx % size - if result < 0: - result += size - return result - - -@wp.kernel -def _finite_difference_1d_kernel( - values: wp.array(dtype=wp.float32, ndim=2), - gradients: wp.array(dtype=wp.float32, ndim=2), - batch: int, - dim0: int, - spacing0: float, -): - tid = wp.tid() - total = batch * dim0 - if tid >= total: - return - - b = tid // dim0 - i = tid - b * dim0 - - ip = _wrap_index(i + 1, dim0) - im = _wrap_index(i - 1, dim0) - - grad = (values[b, ip] - values[b, im]) / (2.0 * spacing0) - - gradients[b, i] = grad - - -@wp.kernel -def _finite_difference_2d_kernel( - values: wp.array(dtype=wp.float32, ndim=3), - gradients: wp.array(dtype=wp.vec2f, ndim=3), - batch: int, - dim0: int, - dim1: int, - spacing0: float, - spacing1: float, -): - tid = wp.tid() - total = batch * dim0 * dim1 - if tid >= total: - return - - plane = dim0 * dim1 - b = tid // plane - rem = tid - b * plane - i = rem // dim1 - j = rem - i * dim1 - - # Axis 0 derivative - ip = _wrap_index(i + 1, dim0) - im = _wrap_index(i - 1, dim0) - grad0 = (values[b, ip, j] - values[b, im, j]) / (2.0 * spacing0) - - # Axis 1 derivative - jp = _wrap_index(j + 1, dim1) - jm = _wrap_index(j - 1, dim1) - grad1 = (values[b, i, jp] - values[b, i, jm]) / (2.0 * spacing1) - - gradients[b, i, j] = wp.vec2f(grad0, grad1) - - -@wp.kernel -def _finite_difference_3d_kernel( - values: wp.array(dtype=wp.float32, ndim=4), - gradients: wp.array(dtype=wp.vec3f, ndim=4), - batch: int, - dim0: int, - dim1: int, - dim2: int, - spacing0: float, - spacing1: float, - spacing2: float, -): - tid = wp.tid() - total = batch * dim0 * dim1 * dim2 - if tid >= total: - return - - plane = dim1 * dim2 - volume = dim0 * plane - b = tid // volume - rem = tid - b * volume - i = rem // plane - rem = rem - i * plane - j = rem // dim2 - k = rem - j * dim2 - - # Axis 0 derivative - ip = _wrap_index(i + 1, dim0) - im = _wrap_index(i - 1, dim0) - grad0 = (values[b, ip, j, k] - values[b, im, j, k]) / (2.0 * spacing0) - - # Axis 1 derivative - jp = _wrap_index(j + 1, dim1) - jm = _wrap_index(j - 1, dim1) - grad1 = (values[b, i, jp, k] - values[b, i, jm, k]) / (2.0 * spacing1) - - # Axis 2 derivative - kp = _wrap_index(k + 1, dim2) - km = _wrap_index(k - 1, dim2) - grad2 = (values[b, i, j, kp] - values[b, i, j, km]) / (2.0 * spacing2) - - gradients[b, i, j, k] = wp.vec3f(grad0, grad1, grad2) - - -def _run_finite_difference( - values: Tensor, - spacing: Tuple[float, ...], - has_batch: bool, -) -> Tensor: - if not values.is_floating_point(): - raise TypeError("Finite differences require floating point tensors") - - dims = values.dim() - (1 if has_batch else 0) - if dims not in (1, 2, 3): - raise ValueError( - f"Finite differences support 1D, 2D, or 3D inputs, got tensor with {values.dim()} dims" - ) - - spacing_tuple = _normalize_spacing(spacing, dims) - - values32 = values.to(torch.float32).contiguous() - prepared, added_batch = _prepare_input(values32, has_batch) - batch = prepared.shape[0] - spatial_shape = tuple(int(s) for s in prepared.shape[1:]) - _check_spatial_shape(spatial_shape) - - if dims == 1: - grad_shape = (batch, *spatial_shape) - grad_dtype = wp.float32 - elif dims == 2: - grad_shape = (batch, *spatial_shape, dims) - grad_dtype = wp.vec2f - else: - grad_shape = (batch, *spatial_shape, dims) - grad_dtype = wp.vec3f - - gradients = torch.empty( - grad_shape, - device=values.device, - dtype=torch.float32, - ) - - stream, wp_device = _get_wp_context(prepared) - - with wp.ScopedStream(stream): - # wp_values = wp.from_torch(prepared, dtype=wp.float32, return_ctype=True) - wp_values = wp.from_torch(prepared, dtype=wp.float32) - # wp_grads = wp.from_torch(gradients, dtype=wp.float32, return_ctype=True) - wp_grads = wp.from_torch(gradients, dtype=grad_dtype) - - if dims == 1: - inputs = [wp_values, wp_grads, batch, spatial_shape[0], spacing_tuple[0]] - _launch_kernel( - _finite_difference_1d_kernel, - batch * spatial_shape[0], - inputs, - stream, - wp_device, - ) - elif dims == 2: - inputs = [ - wp_values, - wp_grads, - batch, - spatial_shape[0], - spatial_shape[1], - spacing_tuple[0], - spacing_tuple[1], - ] - _launch_kernel( - _finite_difference_2d_kernel, - batch * spatial_shape[0] * spatial_shape[1], - inputs, - stream, - wp_device, - ) - else: - inputs = [ - wp_values, - wp_grads, - batch, - spatial_shape[0], - spatial_shape[1], - spatial_shape[2], - spacing_tuple[0], - spacing_tuple[1], - spacing_tuple[2], - ] - _launch_kernel( - _finite_difference_3d_kernel, - batch * spatial_shape[0] * spatial_shape[1] * spatial_shape[2], - inputs, - stream, - wp_device, - ) - - if dims == 1: - gradients = gradients.unsqueeze(1) - else: - gradients = gradients.movedim(-1, 1) - - if added_batch: - gradients = gradients.squeeze(0) - - if values.dtype != torch.float32: - gradients = gradients.to(values.dtype) - - return gradients - - -@torch.library.custom_op("physicsnemo::finite_difference_nd", mutates_args=()) -def finite_difference_op( - values: Tensor, - spacing: Sequence[float], - has_batch: bool = True, -) -> Tensor: - return _run_finite_difference(values, spacing, has_batch) - - -@finite_difference_op.register_fake -def _finite_difference_fake( - values: Tensor, - spacing: Sequence[float], - has_batch: bool = True, -) -> Tensor: - dims = values.dim() - (1 if has_batch else 0) - if dims not in (1, 2, 3): - raise RuntimeError( - "Finite differences support only 1D, 2D, or 3D inputs in fake tensor mode" - ) - - spatial_shape = values.shape[1:] if has_batch else values.shape - if has_batch: - shape = (values.shape[0], dims, *spatial_shape) - else: - shape = (dims, *spatial_shape) - - return values.new_empty(shape) - - -class FiniteDifference(Function): - """Autograd wrapper around the custom finite difference op.""" - - @staticmethod - def forward(ctx, values: Tensor, spacing: Sequence[float], has_batch: bool = True): - output = torch.ops.physicsnemo.finite_difference_nd(values, spacing, has_batch) - ctx.mark_non_differentiable(output) - return output - - @staticmethod - def backward(ctx, *grad_outputs): # pragma: no cover - no autograd yet - return None, None, None - - @classmethod - def make_inputs( - cls, - ) -> Iterable[tuple[str, dict[str, Any], tuple[Tensor, Sequence[float], bool]]]: - configs = { - "1D": ((1, 1024*8), (1.0,)), - "2D": ((1, 256, 256), (1.0, 1.2)), - "3D": ((1, 64, 64, 64), (0.8, 1.0, 1.2)), - } - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - for label, (shape, spacing) in configs.items(): - values = torch.randn(shape, device=device, dtype=torch.float32) - yield label, {}, (values, spacing, True) - - @classmethod - def reference_impl( - cls, values: Tensor, spacing: Sequence[float], has_batch: bool = True - ) -> Tensor: - dims = values.dim() - (1 if has_batch else 0) - tensor = values if has_batch else values.unsqueeze(0) - grads = [] - for axis in range(1, dims + 1): - forward = torch.roll(tensor, shifts=-1, dims=axis) - backward = torch.roll(tensor, shifts=1, dims=axis) - grads.append((forward - backward) / (2.0 * spacing[axis - 1])) - stacked = torch.stack(grads, dim=1) - return stacked if has_batch else stacked.squeeze(0) - - @classmethod - def check(cls, actual: Tensor, expected: Tensor) -> None: - torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-4) - - -__all__ = ["FiniteDifference"] From 6b117ddcf48623d05970ef7d95da8127a2e1d62a Mon Sep 17 00:00:00 2001 From: Oliver Date: Wed, 3 Dec 2025 17:21:31 -0800 Subject: [PATCH 09/15] imports --- physicsnemo/nn/module/__init__.py | 53 +------------------------------ 1 file changed, 1 insertion(+), 52 deletions(-) diff --git a/physicsnemo/nn/module/__init__.py b/physicsnemo/nn/module/__init__.py index 2bd0044494..f48d21e283 100644 --- a/physicsnemo/nn/module/__init__.py +++ b/physicsnemo/nn/module/__init__.py @@ -12,55 +12,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. - -from .activations import ( - CappedGELU, - CappedLeakyReLU, - Identity, - SquarePlus, - Stan, - get_activation, -) -from .ball_query import BQWarp -from .conv_layers import ConvBlock, CubeEmbedding -from .dgm_layers import DGMLayer -from .finite_difference import FiniteDifferenceNd -from .fourier_layers import ( - FourierFilter, - FourierLayer, - FourierMLP, - GaborFilter, - fourier_encode, -) -from .fully_connected_layers import ( - Conv1dFCLayer, - Conv2dFCLayer, - Conv3dFCLayer, - ConvNdFCLayer, - ConvNdKernel1Layer, - FCLayer, -) -from .kan_layers import KolmogorovArnoldNetwork -from .mlp_layers import Mlp -from .resample_layers import ( - DownSample2D, - DownSample3D, - UpSample2D, - UpSample3D, -) -from .siren_layers import SirenLayer, SirenLayerType -from .spectral_layers import ( - SpectralConv1d, - SpectralConv2d, - SpectralConv3d, - SpectralConv4d, -) -from .transformer_layers import ( - DecoderLayer, - EncoderLayer, - FuserLayer, - SwinTransformer, -) -from .weight_fact import WeightFactLinear -from .weight_norm import WeightNormLinear +# limitations under the License. \ No newline at end of file From 76b71aa66ae043dc572cc87d7b96183ff0082448 Mon Sep 17 00:00:00 2001 From: Oliver Date: Wed, 3 Dec 2025 17:22:33 -0800 Subject: [PATCH 10/15] imports --- physicsnemo/nn/module/finite_difference.py | 108 --------------------- 1 file changed, 108 deletions(-) delete mode 100644 physicsnemo/nn/module/finite_difference.py diff --git a/physicsnemo/nn/module/finite_difference.py b/physicsnemo/nn/module/finite_difference.py deleted file mode 100644 index 5dffdce303..0000000000 --- a/physicsnemo/nn/module/finite_difference.py +++ /dev/null @@ -1,108 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import Any, Iterable, Sequence, Tuple - -import torch -from torch import Tensor - -from physicsnemo.core import Module -from physicsnemo.nn.functional.finite_difference import ( - FiniteDifference, - _normalize_spacing, -) - - -class FiniteDifferenceNd(Module): - """Finite-difference stencil implemented with Warp-backed functionals. - - Parameters - ---------- - spacing: - Grid spacing for each spatial dimension. Provide a single value to - reuse across all axes. - has_batch: - Whether the input tensors include a batch dimension as the first axis. - """ - - def __init__( - self, - spacing: Sequence[float] | float, - has_batch: bool = True, - ) -> None: - super().__init__() - if isinstance(spacing, Sequence) and not isinstance(spacing, (str, bytes)): - self.spacing: Sequence[float] | float = tuple(float(s) for s in spacing) - else: - self.spacing = float(spacing) - self.has_batch = has_batch - - def forward(self, values: Tensor) -> Tensor: - """Apply the finite difference stencil.""" - - if not torch.is_tensor(values): - raise TypeError("values must be a torch.Tensor") - dims = values.dim() - (1 if self.has_batch else 0) - spacing_tuple = _normalize_spacing(self.spacing, dims) - return FiniteDifference.apply(values, spacing_tuple, self.has_batch) - - @classmethod - def make_inputs( - cls, - ) -> Iterable[tuple[str, dict[str, Any], Tuple[Tensor]]]: - configs = { - "1D": ((1, 64), (1.0,)), - "2D": ((1, 32, 32), (1.0, 1.2)), - "3D": ((1, 16, 16, 16), (0.8, 1.0, 1.2)), - } - device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - for label, (shape, spacing) in configs.items(): - values = torch.randn(shape, device=device, dtype=torch.float32) - init_kwargs: dict[str, Any] = { - "spacing": spacing, - "has_batch": True, - } - yield label, init_kwargs, (values,) - - @classmethod - def reference_impl( - cls, values: Tensor, spacing: Sequence[float] | float, has_batch: bool = True - ) -> Tensor: - dims = values.dim() - (1 if has_batch else 0) - spacing_tuple = _normalize_spacing(spacing, dims) - return FiniteDifference.reference_impl(values, spacing_tuple, has_batch) - - @classmethod - def _resolve_reference_runner(cls, instance): - if instance is None: - return super()._resolve_reference_runner(instance) - - spacing = instance.spacing - has_batch = instance.has_batch - - def runner(values: Tensor) -> Tensor: - return cls.reference_impl(values, spacing, has_batch) - - return runner - - @classmethod - def check(cls, actual: Tensor, expected: Tensor) -> None: - FiniteDifference.check(actual, expected) - - -__all__ = ["FiniteDifferenceNd"] From c045c5828d3c72f5ab12fc0beb24a26308497e54 Mon Sep 17 00:00:00 2001 From: Oliver Date: Wed, 3 Dec 2025 17:23:42 -0800 Subject: [PATCH 11/15] imports --- physicsnemo/nn/utils/__init__.py | 21 --------------------- 1 file changed, 21 deletions(-) delete mode 100644 physicsnemo/nn/utils/__init__.py diff --git a/physicsnemo/nn/utils/__init__.py b/physicsnemo/nn/utils/__init__.py deleted file mode 100644 index 343d286e6e..0000000000 --- a/physicsnemo/nn/utils/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Backward-compatible proxy to the reorganized nn.module.utils package.""" - -from physicsnemo.nn.module.utils import * # noqa: F401,F403 - -__all__ = [name for name in globals().keys() if not name.startswith("_")] From f18617db7b626fee2e2c7f58a1e307f73d620fd8 Mon Sep 17 00:00:00 2001 From: Oliver Date: Wed, 3 Dec 2025 17:24:22 -0800 Subject: [PATCH 12/15] imports --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 903f2675a8..cd1c075fc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -193,7 +193,6 @@ gnns = [ healpix = [ "earth2grid", ] -<<<<<<< HEAD dynamic = ["version"] [project.urls] From 530a2273357052d7f47c88846c1cae3b0e30ca36 Mon Sep 17 00:00:00 2001 From: Oliver Date: Wed, 3 Dec 2025 17:24:46 -0800 Subject: [PATCH 13/15] imports --- pyproject.toml | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cd1c075fc6..cb0178f3b0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -193,17 +193,6 @@ gnns = [ healpix = [ "earth2grid", ] -dynamic = ["version"] - -[project.urls] -Homepage = "https://github.com/NVIDIA/physicsnemo" -Documentation = "https://docs.nvidia.com/physicsnemo/index.html#core" -Issues = "https://github.com/NVIDIA/physicsnemo/issues" -Changelog = "https://github.com/NVIDIA/physicsnemo/blob/main/CHANGELOG.md" - -[tool.setuptools] -packages = ["physicsnemo"] - [tool.ruff] # Enable flake8/pycodestyle (`E`), Pyflakes (`F`), flake8-bandit (`S`), From de20bd9ac4f25074ffa21c083417d3a7455b16ff Mon Sep 17 00:00:00 2001 From: Oliver Date: Thu, 4 Dec 2025 11:47:12 -0800 Subject: [PATCH 14/15] update doc: --- CODING_STANDARDS/MODELS_IMPLEMENTATION.md | 62 ++++++++++++----------- 1 file changed, 32 insertions(+), 30 deletions(-) diff --git a/CODING_STANDARDS/MODELS_IMPLEMENTATION.md b/CODING_STANDARDS/MODELS_IMPLEMENTATION.md index 0000698f4e..0d40c54c9a 100644 --- a/CODING_STANDARDS/MODELS_IMPLEMENTATION.md +++ b/CODING_STANDARDS/MODELS_IMPLEMENTATION.md @@ -102,11 +102,10 @@ This document is structured in two main sections: **Description:** Reusable layers that are the building blocks of more complex architectures -should go into `physicsnemo/nn`. Those include for instance `FullyConnected`, +should go into `physicsnemo/nn/module`. Those include for instance `FullyConnected`, various variants of attention layers, `UNetBlock` (a block of a U-Net), etc. - -All layers that are directly exposed to the user should be imported in -`physicsnemo/nn/__init__.py`, such that they can be used as follows: +Implementations live in the `module/` subpackage but are re-exported from +`physicsnemo/nn/__init__.py` so the public import path stays stable: ```python from physicsnemo.nn import MyLayer @@ -125,13 +124,13 @@ and promotes code reuse across different models. **Example:** ```python -# Good: Reusable layer in physicsnemo/nn/attention.py +# Good: Reusable layer in physicsnemo/nn/module/attention_layers.py class MultiHeadAttention(Module): """A reusable attention layer that can be used in various architectures.""" pass # Good: Import in physicsnemo/nn/__init__.py -from physicsnemo.nn.attention import MultiHeadAttention +from physicsnemo.nn.module.attention_layers import MultiHeadAttention # Good: Example-specific layer in examples/weather/utils/nn.py class WeatherSpecificLayer(Module): @@ -145,7 +144,7 @@ class WeatherSpecificLayer(Module): # WRONG: Reusable layer placed in physicsnemo/models/ # File: physicsnemo/models/attention.py class MultiHeadAttention(Module): - """Should be in physicsnemo/nn/ not physicsnemo/models/""" + """Should be in physicsnemo/nn/module/ not physicsnemo/models/""" pass ``` @@ -192,10 +191,10 @@ from physicsnemo.models.transformer import TransformerModel **Anti-pattern:** ```python -# WRONG: Complete model placed in physicsnemo/nn/ -# File: physicsnemo/nn/transformer.py +# WRONG: Complete model placed in physicsnemo/nn/module/ +# File: physicsnemo/nn/module/transformer.py class TransformerModel(Module): - """Should be in physicsnemo/models/ not physicsnemo/nn/""" + """Should be in physicsnemo/models/ not physicsnemo/nn/module/""" pass ``` @@ -248,15 +247,16 @@ class MyModel(nn.Module): **Description:** For the vast majority of models, new classes are created either in -`physicsnemo/experimental/nn` for reusable layers, or in +`physicsnemo/experimental/nn/module` for reusable layers, or in `physicsnemo/experimental/models` for more complete models. The `experimental` folder is used to store models that are still under development (beta or alpha releases), where backward compatibility is not guaranteed. One exception is when the developer is highly confident that the model is sufficiently mature and applicable to many domains or use cases. In this case -the model class can be created in the `physicsnemo/nn` or `physicsnemo/models` -folders directly, and backward compatibility is guaranteed. +the model class can be created in the `physicsnemo/nn/module` (exposed through +`physicsnemo.nn`) or `physicsnemo/models` folders directly, and backward +compatibility is guaranteed. Another exception is when the model class is highly specific to a single example. In this case, it may be acceptable to place it in a module specific to @@ -264,9 +264,9 @@ the example code, such as `examples//utils/nn.py`. After staying in experimental for a sufficient amount of time (typically at least 1 release cycle), the model class can be promoted to production. It is -then moved to the `physicsnemo/nn` or `physicsnemo/models` folders, based on -whether it's a reusable layer (MOD-000a) or complete model (MOD-000b). During -the production stage, backward compatibility is guaranteed. +then moved to the `physicsnemo/nn/module` or `physicsnemo/models` folders, +based on whether it's a reusable layer (MOD-000a) or complete model (MOD-000b). +During the production stage, backward compatibility is guaranteed. **Note:** Per MOD-008a, MOD-008b, and MOD-008c, it is forbidden to move a model out of the experimental stage/directory without the required CI tests. @@ -309,9 +309,10 @@ class BrandNewModel(Module): **Description:** -For a model class being deprecated in `physicsnemo/nn` or `physicsnemo/models`, -the developer must add warning messages indicating that the model class is -deprecated and will be removed in a future release. +For a model class being deprecated in `physicsnemo/nn/module` (exposed via +`physicsnemo.nn`) or `physicsnemo/models`, the developer must add warning +messages indicating that the model class is deprecated and will be removed in a +future release. The warning message should be clear and concise, explaining why the model class is being deprecated and what the user should do instead. The deprecation message @@ -1306,9 +1307,9 @@ def forward( **Description:** -For any model in `physicsnemo/nn` or `physicsnemo/models`, adding new required -parameters (parameters without default values) to `__init__` or any public -method is strictly forbidden. This breaks backward compatibility. +For any model in `physicsnemo/nn/module` or `physicsnemo/models`, adding new +required parameters (parameters without default values) to `__init__` or any +public method is strictly forbidden. This breaks backward compatibility. New parameters must have default values to ensure existing code and checkpoints continue to work. If a new parameter is truly required, increment the model @@ -1363,8 +1364,9 @@ class MyModel(Module): **Description:** -For any model in `physicsnemo/nn` or `physicsnemo/models`, removing or renaming -parameters is strictly forbidden without proper backward compatibility support. +For any model in `physicsnemo/nn/module` or `physicsnemo/models`, removing or +renaming parameters is strictly forbidden without proper backward compatibility +support. If a parameter must be renamed or removed, the developer must: 1. Increment `__model_checkpoint_version__` @@ -1447,9 +1449,9 @@ class MyModel(Module): **Description:** -For any model in `physicsnemo/nn` or `physicsnemo/models`, changing the return -type of any public method (including `forward`) is strictly forbidden. This -includes: +For any model in `physicsnemo/nn/module` or `physicsnemo/models`, changing the +return type of any public method (including `forward`) is strictly forbidden. +This includes: - Changing from returning a single value to returning a tuple - Changing from a tuple to a single value - Changing the number of elements in a returned tuple @@ -1504,9 +1506,9 @@ class MyModel(Module): **Description:** -Every model in `physicsnemo/nn` or `physicsnemo/models` must have tests that -verify model instantiation and all public attributes (excluding buffers and -parameters). +Every model in `physicsnemo/nn/module` or `physicsnemo/models` must have tests +that verify model instantiation and all public attributes (excluding buffers +and parameters). These tests should: - Use `pytest` parameterization to test at least 2 configurations From 2cd8920a38f3a267a14bfa846c2e74c5ba3ae79e Mon Sep 17 00:00:00 2001 From: Oliver Date: Thu, 4 Dec 2025 16:10:20 -0800 Subject: [PATCH 15/15] updated cursor rules --- .../mod-000a-reusable-layers-belong-in-nn.mdc | 17 ++++++++--------- ...od-000b-complete-models-belong-in-models.mdc | 8 ++++---- ...perimental-models-belong-in-experimental.mdc | 14 ++++++++------ ...d-002b-add-deprecation-warnings-to-model.mdc | 9 +++++---- ...add-required-parameters-without-defaults.mdc | 6 +++--- ...-rename-parameters-without-compat-mapper.mdc | 5 +++-- ...ot-change-return-types-of-public-methods.mdc | 6 +++--- ...del-missing-constructor-attributes-tests.mdc | 6 +++--- 8 files changed, 37 insertions(+), 34 deletions(-) diff --git a/.cursor/rules/mod-000a-reusable-layers-belong-in-nn.mdc b/.cursor/rules/mod-000a-reusable-layers-belong-in-nn.mdc index d1f5d544bb..3e9142cdd2 100644 --- a/.cursor/rules/mod-000a-reusable-layers-belong-in-nn.mdc +++ b/.cursor/rules/mod-000a-reusable-layers-belong-in-nn.mdc @@ -1,20 +1,19 @@ --- -description: Reusable layers and building blocks should be placed in physicsnemo/nn, not physicsnemo/models. Examples include FullyConnected, attention layers, and UNetBlock. +description: Reusable layers and building blocks should be placed in physicsnemo/nn/module (exposed via physicsnemo.nn), not physicsnemo/models. Examples include FullyConnected, attention layers, and UNetBlock. alwaysApply: false --- -When creating or refactoring reusable layer code, rule MOD-000a must be followed. Explicitly reference "Following rule MOD-000a, which states that reusable layers should go in physicsnemo/nn..." when explaining placement decisions. +When creating or refactoring reusable layer code, rule MOD-000a must be followed. Explicitly reference "Following rule MOD-000a, which states that reusable layers should go in physicsnemo/nn/module (and be re-exported through physicsnemo.nn)..." when explaining placement decisions. ## MOD-000a: Reusable layers/blocks belong in physicsnemo.nn **Description:** Reusable layers that are the building blocks of more complex architectures -should go into `physicsnemo/nn`. Those include for instance `FullyConnected`, +should go into `physicsnemo/nn/module`. Those include for instance `FullyConnected`, various variants of attention layers, `UNetBlock` (a block of a U-Net), etc. - -All layers that are directly exposed to the user should be imported in -`physicsnemo/nn/__init__.py`, such that they can be used as follows: +Implementations live in the `module/` subpackage but are re-exported from +`physicsnemo/nn/__init__.py` so the public import path stays stable: ```python from physicsnemo.nn import MyLayer @@ -33,13 +32,13 @@ and promotes code reuse across different models. **Example:** ```python -# Good: Reusable layer in physicsnemo/nn/attention.py +# Good: Reusable layer in physicsnemo/nn/module/attention_layers.py class MultiHeadAttention(Module): """A reusable attention layer that can be used in various architectures.""" pass # Good: Import in physicsnemo/nn/__init__.py -from physicsnemo.nn.attention import MultiHeadAttention +from physicsnemo.nn.module.attention_layers import MultiHeadAttention # Good: Example-specific layer in examples/weather/utils/nn.py class WeatherSpecificLayer(Module): @@ -53,6 +52,6 @@ class WeatherSpecificLayer(Module): # WRONG: Reusable layer placed in physicsnemo/models/ # File: physicsnemo/models/attention.py class MultiHeadAttention(Module): - """Should be in physicsnemo/nn/ not physicsnemo/models/""" + """Should be in physicsnemo/nn/module/ not physicsnemo/models/""" pass ``` diff --git a/.cursor/rules/mod-000b-complete-models-belong-in-models.mdc b/.cursor/rules/mod-000b-complete-models-belong-in-models.mdc index 889bb4aae3..a817b57756 100644 --- a/.cursor/rules/mod-000b-complete-models-belong-in-models.mdc +++ b/.cursor/rules/mod-000b-complete-models-belong-in-models.mdc @@ -1,5 +1,5 @@ --- -description: Complete models composed of multiple layers should be placed in physicsnemo/models, not physicsnemo/nn. These are domain-specific or modality-specific models. +description: Complete models composed of multiple layers should be placed in physicsnemo/models, not physicsnemo/nn/module (exposed via physicsnemo.nn). These are domain-specific or modality-specific models. alwaysApply: false --- @@ -46,9 +46,9 @@ from physicsnemo.models.transformer import TransformerModel **Anti-pattern:** ```python -# WRONG: Complete model placed in physicsnemo/nn/ -# File: physicsnemo/nn/transformer.py +# WRONG: Complete model placed in physicsnemo/nn/module/ +# File: physicsnemo/nn/module/transformer.py class TransformerModel(Module): - """Should be in physicsnemo/models/ not physicsnemo/nn/""" + """Should be in physicsnemo/models/ not physicsnemo/nn/module/""" pass ``` diff --git a/.cursor/rules/mod-002a-experimental-models-belong-in-experimental.mdc b/.cursor/rules/mod-002a-experimental-models-belong-in-experimental.mdc index 44076619e5..0e2bacc571 100644 --- a/.cursor/rules/mod-002a-experimental-models-belong-in-experimental.mdc +++ b/.cursor/rules/mod-002a-experimental-models-belong-in-experimental.mdc @@ -1,5 +1,5 @@ --- -description: New model classes should start in physicsnemo/experimental/nn or physicsnemo/experimental/models during development, where backward compatibility is not guaranteed. +description: New model classes should start in physicsnemo/experimental/nn/module or physicsnemo/experimental/models during development, where backward compatibility is not guaranteed. alwaysApply: false --- @@ -10,15 +10,16 @@ When creating new model or layer classes, rule MOD-002a must be followed. Explic **Description:** For the vast majority of models, new classes are created either in -`physicsnemo/experimental/nn` for reusable layers, or in +`physicsnemo/experimental/nn/module` for reusable layers, or in `physicsnemo/experimental/models` for more complete models. The `experimental` folder is used to store models that are still under development (beta or alpha releases) during this stage, backward compatibility is not guaranteed. One exception is when the developer is highly confident that the model is sufficiently mature and applicable to many domains or use cases. In this case -the model class can be created in the `physicsnemo/nn` or `physicsnemo/models` -folders directly, and backward compatibility is guaranteed. +the model class can be created in the `physicsnemo/nn/module` (exposed through +`physicsnemo.nn`) or `physicsnemo/models` folders directly, and backward +compatibility is guaranteed. Another exception is when the model class is highly specific to a single example. In this case, it may be acceptable to place it in a module specific to @@ -26,8 +27,9 @@ the example code, such as `examples//utils/nn.py`. After staying in experimental for a sufficient amount of time (typically at least 1 release cycle), the model class can be promoted to production. It is -then moved to the `physicsnemo/nn` or `physicsnemo/models` folders, based on -whether it's a reusable layer or complete model (see MOD-000a and MOD-000b). +then moved to the `physicsnemo/nn/module` or `physicsnemo/models` folders, +based on whether it's a reusable layer or complete model (see MOD-000a and +MOD-000b). **Note:** Per MOD-008a, MOD-008b, and MOD-008c, it is forbidden to move a model out of the experimental stage/directory without the required CI tests. diff --git a/.cursor/rules/mod-002b-add-deprecation-warnings-to-model.mdc b/.cursor/rules/mod-002b-add-deprecation-warnings-to-model.mdc index 45e7f66883..d4b2303341 100644 --- a/.cursor/rules/mod-002b-add-deprecation-warnings-to-model.mdc +++ b/.cursor/rules/mod-002b-add-deprecation-warnings-to-model.mdc @@ -9,10 +9,11 @@ When deprecating a model class, rule MOD-002b must be followed. Explicitly refer **Description:** -For a model class in the pre-deprecation stage in `physicsnemo/nn` or -`physicsnemo/models`, the developer should start planning its deprecation. This -is done by adding a warning message to the model class, indicating that the -model class is deprecated and will be removed in a future release. +For a model class in the pre-deprecation stage in `physicsnemo/nn/module` +(exposed via `physicsnemo.nn`) or `physicsnemo/models`, the developer should +start planning its deprecation. This is done by adding a warning message to the +model class, indicating that the model class is deprecated and will be removed +in a future release. The warning message should be a clear and concise message that explains why the model class is being deprecated and what the user should do instead. The diff --git a/.cursor/rules/mod-007a-cannot-add-required-parameters-without-defaults.mdc b/.cursor/rules/mod-007a-cannot-add-required-parameters-without-defaults.mdc index 21e0cd8789..36cf4292a3 100644 --- a/.cursor/rules/mod-007a-cannot-add-required-parameters-without-defaults.mdc +++ b/.cursor/rules/mod-007a-cannot-add-required-parameters-without-defaults.mdc @@ -9,9 +9,9 @@ When adding parameters to production models, rule MOD-007a must be strictly foll **Description:** -For any model in `physicsnemo/nn` or `physicsnemo/models`, adding new required -parameters (parameters without default values) to `__init__` or any public -method is strictly forbidden. This breaks backward compatibility. +For any model in `physicsnemo/nn/module` or `physicsnemo/models`, adding new +required parameters (parameters without default values) to `__init__` or any +public method is strictly forbidden. This breaks backward compatibility. New parameters must have default values to ensure existing code and checkpoints continue to work. If a new parameter is truly required, increment the model diff --git a/.cursor/rules/mod-007b-cannot-remove-or-rename-parameters-without-compat-mapper.mdc b/.cursor/rules/mod-007b-cannot-remove-or-rename-parameters-without-compat-mapper.mdc index 9862e0e6ee..405138f4fb 100644 --- a/.cursor/rules/mod-007b-cannot-remove-or-rename-parameters-without-compat-mapper.mdc +++ b/.cursor/rules/mod-007b-cannot-remove-or-rename-parameters-without-compat-mapper.mdc @@ -9,8 +9,9 @@ When removing or renaming parameters in production models, rule MOD-007b must be **Description:** -For any model in `physicsnemo/nn` or `physicsnemo/models`, removing or renaming -parameters is strictly forbidden without proper backward compatibility support. +For any model in `physicsnemo/nn/module` or `physicsnemo/models`, removing or +renaming parameters is strictly forbidden without proper backward compatibility +support. If a parameter must be renamed or removed, the developer must: 1. Increment `__model_checkpoint_version__` diff --git a/.cursor/rules/mod-007c-cannot-change-return-types-of-public-methods.mdc b/.cursor/rules/mod-007c-cannot-change-return-types-of-public-methods.mdc index dd347c0ab6..b898cacfb5 100644 --- a/.cursor/rules/mod-007c-cannot-change-return-types-of-public-methods.mdc +++ b/.cursor/rules/mod-007c-cannot-change-return-types-of-public-methods.mdc @@ -9,9 +9,9 @@ When modifying public method return types, rule MOD-007c must be strictly follow **Description:** -For any model in `physicsnemo/nn` or `physicsnemo/models`, changing the return -type of any public method (including `forward`) is strictly forbidden. This -includes: +For any model in `physicsnemo/nn/module` or `physicsnemo/models`, changing the +return type of any public method (including `forward`) is strictly forbidden. +This includes: - Changing from returning a single value to returning a tuple - Changing from a tuple to a single value - Changing the number of elements in a returned tuple diff --git a/.cursor/rules/mod-008a-model-missing-constructor-attributes-tests.mdc b/.cursor/rules/mod-008a-model-missing-constructor-attributes-tests.mdc index b188c133ae..f8512ceb6c 100644 --- a/.cursor/rules/mod-008a-model-missing-constructor-attributes-tests.mdc +++ b/.cursor/rules/mod-008a-model-missing-constructor-attributes-tests.mdc @@ -9,9 +9,9 @@ When creating tests for models, rule MOD-008a must be followed. Explicitly refer **Description:** -Every model in `physicsnemo/nn` or `physicsnemo/models` must have tests that -verify model instantiation and all public attributes (excluding buffers and -parameters). +Every model in `physicsnemo/nn/module` or `physicsnemo/models` must have tests +that verify model instantiation and all public attributes (excluding buffers +and parameters). These tests should: - Use `pytest` parameterization to test at least 2 configurations